python-requests-toolbelt/tests/test_downloadutils.py

224 lines
7.2 KiB
Python
Raw Normal View History

2023-02-07 14:29:03 +08:00
"""Tests for the ``requests_toolbelt.downloadutils`` package (stream and tee)."""
import io
import os
import os.path
import shutil
import tempfile

import requests
from requests_toolbelt.downloadutils import stream
from requests_toolbelt.downloadutils import tee

try:
    from unittest import mock
except ImportError:  # Old Pythons without unittest.mock: use the backport.
    import mock

import pytest

from . import get_betamax

# Extra keyword arguments passed to every ``recorder.use_cassette`` call in
# this module. Presumably this makes betamax store the binary wheel body
# byte-for-byte in the cassette -- confirm against betamax's
# ``preserve_exact_body_bytes`` option.
preserve_bytes = {'preserve_exact_body_bytes': True}
def test_get_download_file_path_uses_content_disposition():
    """Without a path, the filename is taken from the recorded response."""
    session = requests.Session()
    recorder = get_betamax(session)
    asset_url = (
        'https://api.github.com/repos/sigmavirus24/github3.py/releases/'
        'assets/37944'
    )
    expected_name = 'github3.py-0.7.1-py2.py3-none-any.whl'
    accept_binary = {'Accept': 'application/octet-stream'}
    with recorder.use_cassette('stream_response_to_file', **preserve_bytes):
        response = session.get(asset_url, headers=accept_binary)
        resolved = stream.get_download_file_path(response, None)
        response.close()
        assert resolved == expected_name
def test_get_download_file_path_directory():
    """A directory path is joined with the filename from the response."""
    s = requests.Session()
    recorder = get_betamax(s)
    url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/'
           'assets/37944')
    filename = 'github3.py-0.7.1-py2.py3-none-any.whl'
    # ``tempfile.tempdir`` is only a cache that stays None until something in
    # the process calls gettempdir()/mkdtemp(), which made this test
    # order-dependent: a None path yields the bare filename and
    # os.path.join(None, ...) raises TypeError. gettempdir() always returns
    # a real directory.
    directory = tempfile.gettempdir()
    with recorder.use_cassette('stream_response_to_file', **preserve_bytes):
        r = s.get(url, headers={'Accept': 'application/octet-stream'})
        try:
            path = stream.get_download_file_path(r, directory)
        finally:
            r.close()
        assert path == os.path.join(directory, filename)
def test_get_download_file_path_specific_file():
    """An explicit file path is returned exactly as given."""
    session = requests.Session()
    recorder = get_betamax(session)
    asset_url = (
        'https://api.github.com/repos/sigmavirus24/github3.py/releases/'
        'assets/37944'
    )
    requested = '/arbitrary/file.path'
    with recorder.use_cassette('stream_response_to_file', **preserve_bytes):
        response = session.get(
            asset_url, headers={'Accept': 'application/octet-stream'})
        resolved = stream.get_download_file_path(response, requested)
        response.close()
        assert resolved == requested
def test_stream_response_to_file_uses_content_disposition():
    """Streaming with no path writes to the filename from the response."""
    s = requests.Session()
    recorder = get_betamax(s)
    url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/'
           'assets/37944')
    filename = 'github3.py-0.7.1-py2.py3-none-any.whl'
    try:
        with recorder.use_cassette('stream_response_to_file',
                                   **preserve_bytes):
            r = s.get(url, headers={'Accept': 'application/octet-stream'},
                      stream=True)
            stream.stream_response_to_file(r)
            assert os.path.exists(filename)
    finally:
        # Always remove the artifact from the working directory: a leftover
        # file would make every rerun fail, because stream_response_to_file
        # refuses to overwrite an existing file.
        if os.path.exists(filename):
            os.unlink(filename)
def test_stream_response_to_specific_filename():
    """Streaming with an explicit filename writes to that file."""
    s = requests.Session()
    recorder = get_betamax(s)
    url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/'
           'assets/37944')
    filename = 'github3.py.whl'
    try:
        with recorder.use_cassette('stream_response_to_file',
                                   **preserve_bytes):
            r = s.get(url, headers={'Accept': 'application/octet-stream'},
                      stream=True)
            stream.stream_response_to_file(r, path=filename)
            assert os.path.exists(filename)
    finally:
        # Clean up even when streaming or the assertion fails; a stale file
        # would make the next run fail because the streamer will not
        # overwrite an existing file.
        if os.path.exists(filename):
            os.unlink(filename)
def test_stream_response_to_directory():
    """Streaming into a directory writes <directory>/<response filename>."""
    session = requests.Session()
    recorder = get_betamax(session)
    asset_url = (
        'https://api.github.com/repos/sigmavirus24/github3.py/releases/'
        'assets/37944'
    )
    target_dir = tempfile.mkdtemp()
    try:
        expected_path = os.path.join(
            target_dir, 'github3.py-0.7.1-py2.py3-none-any.whl')
        with recorder.use_cassette('stream_response_to_file',
                                   **preserve_bytes):
            response = session.get(
                asset_url,
                headers={'Accept': 'application/octet-stream'},
                stream=True,
            )
            stream.stream_response_to_file(response, path=target_dir)
            assert os.path.exists(expected_path)
    finally:
        # The temporary directory and everything streamed into it go away.
        shutil.rmtree(target_dir)
def test_stream_response_to_existing_file():
    """Streaming refuses to overwrite a file that already exists.

    The target is pre-created. ``stream_response_to_file`` must then raise
    ``stream.exc.StreamingError`` with a message starting with
    ``'File already exists:'``.
    """
    s = requests.Session()
    recorder = get_betamax(s)
    url = ('https://api.github.com/repos/sigmavirus24/github3.py/releases/'
           'assets/37944')
    filename = 'github3.py.whl'
    with open(filename, 'w') as f_existing:
        f_existing.write('test')
    # Everything after the file is created sits inside try/finally, so the
    # file is removed even if the request or cassette setup fails.
    try:
        with recorder.use_cassette('stream_response_to_file',
                                   **preserve_bytes):
            r = s.get(url, headers={'Accept': 'application/octet-stream'},
                      stream=True)
            # pytest.raises replaces the try/except/else + ``assert False``
            # pattern, which is stripped under ``python -O`` and whose
            # message wrongly named FileExistsError.
            with pytest.raises(stream.exc.StreamingError) as excinfo:
                stream.stream_response_to_file(r, path=filename)
        assert str(excinfo.value).startswith('File already exists:')
    finally:
        os.unlink(filename)
def test_stream_response_to_file_like_object():
    """A writable file-like object passed as ``path`` receives the body."""
    session = requests.Session()
    recorder = get_betamax(session)
    asset_url = (
        'https://api.github.com/repos/sigmavirus24/github3.py/releases/'
        'assets/37944'
    )
    sink = io.BytesIO()
    with recorder.use_cassette('stream_response_to_file', **preserve_bytes):
        response = session.get(
            asset_url,
            headers={'Accept': 'application/octet-stream'},
            stream=True,
        )
        stream.stream_response_to_file(response, path=sink)
        assert sink.tell() > 0
def test_stream_response_to_file_chunksize():
    """The first write made while streaming uses the requested chunksize."""
    session = requests.Session()
    recorder = get_betamax(session)
    asset_url = (
        'https://api.github.com/repos/sigmavirus24/github3.py/releases/'
        'assets/37944'
    )

    class RecordingBuffer(io.BytesIO):
        """BytesIO that records the length of every ``write`` call."""

        def __init__(self):
            super(RecordingBuffer, self).__init__()
            self.chunk_sizes = []

        def write(self, data):
            self.chunk_sizes.append(len(data))
            return super(RecordingBuffer, self).write(data)

    sink = RecordingBuffer()
    requested_size = 1231
    with recorder.use_cassette('stream_response_to_file', **preserve_bytes):
        response = session.get(
            asset_url,
            headers={'Accept': 'application/octet-stream'},
            stream=True,
        )
        stream.stream_response_to_file(
            response, path=sink, chunksize=requested_size)
        assert sink.tell() > 0
        assert sink.chunk_sizes  # at least one write happened
        assert sink.chunk_sizes[0] == requested_size
@pytest.fixture
def streamed_response(chunks=None):
    """Return a MagicMock response whose ``raw.stream()`` yields byte chunks.

    When used as a fixture, pytest does not inject parameters that have
    defaults, so ``chunks`` normally stays ``None``. The mock then serves
    eight ``b'chunk'`` pieces.
    """
    fake_response = mock.MagicMock()
    fake_response.raw.stream.return_value = chunks or [b'chunk'] * 8
    return fake_response
def test_tee(streamed_response):
    """tee yields every chunk and copies each one into the file object."""
    sink = io.BytesIO()
    yielded_total = sum(
        len(piece) for piece in tee.tee(streamed_response, sink))
    assert yielded_total == len('chunk') * 8
    assert sink.getvalue() == b'chunk' * 8
def test_tee_rejects_StringIO():
    """tee raises TypeError when given a text (StringIO) file object."""
    text_sink = io.StringIO()
    with pytest.raises(TypeError):
        # tee is lazy: the TypeError only surfaces once the generator is
        # actually consumed.
        sum(len(piece) for piece in tee.tee(None, text_sink))
def test_tee_to_file(streamed_response):
    """tee_to_file yields every chunk and creates the named file."""
    response = streamed_response
    filename = 'tee.txt'
    expected_len = len('chunk') * 8
    try:
        assert expected_len == sum(
            len(c) for c in tee.tee_to_file(response, filename)
        )
        assert os.path.exists(filename)
    finally:
        # Remove the file from the working directory even when an assertion
        # fails, so a failed run does not leave it behind.
        if os.path.exists(filename):
            os.remove(filename)
def test_tee_to_bytearray(streamed_response):
    """tee_to_bytearray yields every chunk and fills the given bytearray."""
    collected = bytearray()
    expected = bytearray(b'chunk' * 8)
    yielded_total = sum(
        len(piece)
        for piece in tee.tee_to_bytearray(streamed_response, collected)
    )
    assert yielded_total == len(expected)
    assert collected == expected
def test_tee_to_bytearray_only_accepts_bytearrays():
    """A non-bytearray target is rejected with TypeError at call time.

    The result is never iterated here, so the check happens when the
    function is called.
    """
    not_a_bytearray = object()
    with pytest.raises(TypeError):
        tee.tee_to_bytearray(None, not_a_bytearray)