"""Tests for the utils module.""" import io import os import os.path import shutil import tempfile import requests from requests_toolbelt.downloadutils import stream from requests_toolbelt.downloadutils import tee try: from unittest import mock except ImportError: import mock import pytest from . import get_betamax preserve_bytes = {'preserve_exact_body_bytes': True} def test_get_download_file_path_uses_content_disposition(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') filename = 'github3.py-0.7.1-py2.py3-none-any.whl' with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}) path = stream.get_download_file_path(r, None) r.close() assert path == filename def test_get_download_file_path_directory(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') filename = 'github3.py-0.7.1-py2.py3-none-any.whl' with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}) path = stream.get_download_file_path(r, tempfile.tempdir) r.close() assert path == os.path.join(tempfile.tempdir, filename) def test_get_download_file_path_specific_file(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}) path = stream.get_download_file_path(r, '/arbitrary/file.path') r.close() assert path == '/arbitrary/file.path' def test_stream_response_to_file_uses_content_disposition(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') filename = 'github3.py-0.7.1-py2.py3-none-any.whl' with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}, stream=True) stream.stream_response_to_file(r) assert os.path.exists(filename) os.unlink(filename) def test_stream_response_to_specific_filename(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') filename = 'github3.py.whl' with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}, stream=True) stream.stream_response_to_file(r, path=filename) assert os.path.exists(filename) os.unlink(filename) def test_stream_response_to_directory(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') td = tempfile.mkdtemp() try: filename = 'github3.py-0.7.1-py2.py3-none-any.whl' expected_path = os.path.join(td, filename) with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}, stream=True) stream.stream_response_to_file(r, path=td) assert os.path.exists(expected_path) finally: shutil.rmtree(td) def test_stream_response_to_existing_file(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') filename = 'github3.py.whl' with open(filename, 'w') as f_existing: f_existing.write('test') with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}, stream=True) try: stream.stream_response_to_file(r, path=filename) except stream.exc.StreamingError as e: assert str(e).startswith('File already exists:') else: assert False, "Should have raised a FileExistsError" finally: os.unlink(filename) def test_stream_response_to_file_like_object(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') file_obj = io.BytesIO() with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}, stream=True) stream.stream_response_to_file(r, path=file_obj) assert 0 < file_obj.tell() def test_stream_response_to_file_chunksize(): s = requests.Session() recorder = get_betamax(s) url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/' 'assets/37944') class FileWrapper(io.BytesIO): def __init__(self): super(FileWrapper, self).__init__() self.chunk_sizes = [] def write(self, data): self.chunk_sizes.append(len(data)) return super(FileWrapper, self).write(data) file_obj = FileWrapper() chunksize = 1231 with recorder.use_cassette('stream_response_to_file', **preserve_bytes): r = s.get(url, headers={'Accept': 'application/octet-stream'}, stream=True) stream.stream_response_to_file(r, path=file_obj, chunksize=chunksize) assert 0 < file_obj.tell() assert len(file_obj.chunk_sizes) >= 1 assert file_obj.chunk_sizes[0] == chunksize @pytest.fixture def streamed_response(chunks=None): chunks = chunks or [b'chunk'] * 8 response = mock.MagicMock() response.raw.stream.return_value = chunks return response def test_tee(streamed_response): response = streamed_response expected_len = len('chunk') * 8 fileobject = io.BytesIO() assert expected_len == sum(len(c) for c in tee.tee(response, fileobject)) assert fileobject.getvalue() == b'chunkchunkchunkchunkchunkchunkchunkchunk' def test_tee_rejects_StringIO(): fileobject = io.StringIO() with pytest.raises(TypeError): # The generator needs to be iterated over before the exception will be # raised sum(len(c) for c in tee.tee(None, fileobject)) def test_tee_to_file(streamed_response): response = streamed_response expected_len = len('chunk') * 8 assert expected_len == sum( len(c) for c in tee.tee_to_file(response, 'tee.txt') ) assert os.path.exists('tee.txt') os.remove('tee.txt') def test_tee_to_bytearray(streamed_response): response = streamed_response arr = bytearray() expected_arr = bytearray(b'chunk' * 8) expected_len = len(expected_arr) assert expected_len == sum( len(c) for c in tee.tee_to_bytearray(response, arr) ) assert expected_arr == arr def test_tee_to_bytearray_only_accepts_bytearrays(): with pytest.raises(TypeError): tee.tee_to_bytearray(None, object())