327 lines
11 KiB
Python
327 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
import unittest
|
|
import io
|
|
|
|
import requests
|
|
|
|
import pytest
|
|
from requests_toolbelt.multipart.encoder import (
|
|
CustomBytesIO, MultipartEncoder, FileFromURLWrapper, FileNotSupportedError)
|
|
from requests_toolbelt._compat import filepost
|
|
from . import get_betamax
|
|
|
|
|
|
# Keyword arguments unpacked into ``recorder.use_cassette(...)`` by the
# FileFromURLWrapper tests in this module.
preserve_bytes = dict(preserve_exact_body_bytes=True)
|
|
|
|
|
|
class LargeFileMock(object):
    """Fake readable object that pretends to hold 1 GiB of ``b'a'`` bytes.

    No data is stored: ``read`` fabricates the requested bytes on demand and
    the object only keeps a running count of how many bytes were handed out.
    """

    def __init__(self):
        # Let's keep track of how many bytes we've given
        self.bytes_read = 0
        # Our limit (1GB)
        self.bytes_max = 1024 * 1024 * 1024
        # Fake name
        self.name = 'fake_name.py'
        # Create a fileno attribute. It is deliberately ``None`` rather than
        # a callable.
        # NOTE(review): presumably this steers the encoder's size detection
        # away from ``os.fstat`` - confirm against the encoder.
        self.fileno = None

    def __len__(self):
        """Return the total size the mock pretends to have."""
        return self.bytes_max

    def read(self, size=None):
        """Return up to ``size`` bytes of ``b'a'`` and advance the position.

        :param size: maximum number of bytes to return. ``None`` or any
            negative value means "everything that is left".
        :returns: a bytes object; empty once ``bytes_max`` has been reached.
        """
        remaining = self.bytes_max - self.bytes_read
        if remaining <= 0:
            return b''

        if size is None:
            length = remaining
        else:
            length = int(size)
            if length < 0:
                # Follow the file-object convention that a negative size
                # reads to the end. Without this, ``bytes_read`` would be
                # decremented and ``tell()`` would go backwards.
                length = remaining

        # Never hand out more than what is left before the limit.
        length = min(length, remaining)

        self.bytes_read += length

        return b'a' * length

    def tell(self):
        """Return the number of bytes handed out so far."""
        return self.bytes_read
|
|
|
|
|
|
class TestCustomBytesIO(unittest.TestCase):
    """Exercise the read/write/length behaviour of ``CustomBytesIO``."""

    def setUp(self):
        # Every test starts from a fresh, empty buffer.
        self.instance = CustomBytesIO()

    def test_writable(self):
        assert hasattr(self.instance, 'write')
        assert self.instance.write(b'example') == 7

    def test_readable(self):
        assert hasattr(self.instance, 'read')
        assert self.instance.read() == b''
        assert self.instance.read(10) == b''

    def test_can_read_after_writing_to(self):
        self.instance.write(b'example text')
        # write() leaves the position at the end of the data, so rewind
        # before reading; otherwise read() has nothing left to return.
        self.instance.seek(0, 0)
        assert self.instance.read() == b'example text'

    def test_can_read_some_after_writing_to(self):
        self.instance.write(b'example text')
        # Rewind for the same reason as above, then read only a prefix.
        self.instance.seek(0, 0)
        assert self.instance.read(6) == b'exampl'

    def test_can_get_length(self):
        self.instance.write(b'example')
        # ``len`` is checked from the start of the buffer.
        self.instance.seek(0, 0)
        assert self.instance.len == 7

    def test_truncates_intelligently(self):
        self.instance.write(b'abcdefghijklmnopqrstuvwxyzabcd')  # 30 bytes
        assert self.instance.tell() == 30
        # Position ten bytes before the end, i.e. only the tail is unread.
        self.instance.seek(-10, 2)
        self.instance.smart_truncate()
        assert self.instance.len == 10
        assert self.instance.read() == b'uvwxyzabcd'
        assert self.instance.tell() == 10

    def test_accepts_encoded_strings_with_unicode(self):
        """Accepts a string with encoded unicode characters."""
        s = b'this is a unicode string: \xc3\xa9 \xc3\xa1 \xc7\xab \xc3\xb3'
        self.instance = CustomBytesIO(s)
        assert self.instance.read() == s
|
|
|
|
|
|
class TestFileFromURLWrapper(unittest.TestCase):
    """Tests for ``FileFromURLWrapper`` replayed through betamax cassettes."""

    def setUp(self):
        session = requests.Session()
        self.recorder = get_betamax(session)

    @pytest.mark.xfail
    def test_read_file(self):
        """Reading chunks shrinks ``len`` by the number of bytes consumed."""
        url = 'https://stxnext.com/static/img/logo.830ebe551641.svg'
        # (size to read, expected chunk, expected remaining length)
        steps = (
            (20, b'<svg xmlns="http://w', 5157),
            (0, b'', 5157),
            (10, b'ww.w3.org/', 5147),
        )
        with self.recorder.use_cassette(
                'file_for_download', **preserve_bytes):
            self.instance = FileFromURLWrapper(url)
            wrapper = self.instance
            assert wrapper.len == 5177
            for size, expected_chunk, expected_len in steps:
                chunk = wrapper.read(size)
                assert chunk == expected_chunk
                assert wrapper.len == expected_len

    @pytest.mark.xfail(strict=False)
    def test_no_content_length_header(self):
        """A response without content-length raises FileNotSupportedError."""
        url = (
            'https://api.github.com/repos/sigmavirus24/github3.py/releases/'
            'assets/37944'
        )
        expected_message = (
            'Data from provided URL https://api.github.com/repos/s'
            'igmavirus24/github3.py/releases/assets/37944 is not '
            'supported. Lack of content-length Header in requested'
            ' file response.'
        )
        with self.recorder.use_cassette(
                'stream_response_to_file', **preserve_bytes):
            with self.assertRaises(FileNotSupportedError) as context:
                FileFromURLWrapper(url)

            # Checked after the assertRaises block so the comparison runs.
            assert str(context.exception) == expected_message
|
|
|
|
|
|
class TestMultipartEncoder(unittest.TestCase):
    """Tests for ``MultipartEncoder``: encoding, lengths and streaming."""

    def setUp(self):
        # Two plain text fields and a fixed boundary keep the expected
        # output of the encoding tests deterministic.
        self.parts = [('field', 'value'), ('other_field', 'other_value')]
        self.boundary = 'this-is-a-boundary'
        self.instance = MultipartEncoder(self.parts, boundary=self.boundary)

    def test_to_string(self):
        assert self.instance.to_string() == (
            '--this-is-a-boundary\r\n'
            'Content-Disposition: form-data; name="field"\r\n\r\n'
            'value\r\n'
            '--this-is-a-boundary\r\n'
            'Content-Disposition: form-data; name="other_field"\r\n\r\n'
            'other_value\r\n'
            '--this-is-a-boundary--\r\n'
        ).encode()

    def test_content_type(self):
        expected = 'multipart/form-data; boundary=this-is-a-boundary'
        assert self.instance.content_type == expected

    def test_encodes_data_the_same(self):
        """Output matches ``filepost.encode_multipart_formdata``."""
        encoded = filepost.encode_multipart_formdata(self.parts,
                                                     self.boundary)[0]
        assert encoded == self.instance.read()

    def test_streams_its_data(self):
        """Reading in chunks never buffers more than one chunk."""
        large_file = LargeFileMock()
        parts = {'some field': 'value',
                 'some file': large_file,
                 }
        encoder = MultipartEncoder(parts)
        total_size = encoder.len
        read_size = 1024 * 1024 * 128
        already_read = 0
        while True:
            read = encoder.read(read_size)
            already_read += len(read)
            if not read:
                break

        assert encoder._buffer.tell() <= read_size
        assert already_read == total_size

    def test_length_is_correct(self):
        encoded = filepost.encode_multipart_formdata(self.parts,
                                                     self.boundary)[0]
        assert len(encoded) == self.instance.len

    def test_encodes_with_readable_data(self):
        s = io.BytesIO(b'value')
        m = MultipartEncoder([('field', s)], boundary=self.boundary)
        assert m.read() == (
            '--this-is-a-boundary\r\n'
            'Content-Disposition: form-data; name="field"\r\n\r\n'
            'value\r\n'
            '--this-is-a-boundary--\r\n'
        ).encode()

    def test_reads_open_file_objects(self):
        with open('setup.py', 'rb') as fd:
            m = MultipartEncoder([('field', 'foo'), ('file', fd)])
            assert m.read() is not None

    @pytest.mark.xfail
    def test_reads_file_from_url_wrapper(self):
        s = requests.Session()
        recorder = get_betamax(s)
        url = ('https://stxnext.com/static/img/logo.830ebe551641.svg')
        with recorder.use_cassette(
                'file_for_download'):
            m = MultipartEncoder(
                [('field', 'foo'), ('file', FileFromURLWrapper(url))])
            assert m.read() is not None

    def test_reads_open_file_objects_with_a_specified_filename(self):
        with open('setup.py', 'rb') as fd:
            m = MultipartEncoder(
                [('field', 'foo'), ('file', ('filename', fd, 'text/plain'))]
            )
            assert m.read() is not None

    def test_reads_open_file_objects_using_to_string(self):
        with open('setup.py', 'rb') as fd:
            m = MultipartEncoder([('field', 'foo'), ('file', fd)])
            assert m.to_string() is not None

    def test_handles_encoded_unicode_strings(self):
        m = MultipartEncoder([
            ('field',
             b'this is a unicode string: \xc3\xa9 \xc3\xa1 \xc7\xab \xc3\xb3')
        ])
        assert m.read() is not None

    def test_handles_uncode_strings(self):
        s = b'this is a unicode string: \xc3\xa9 \xc3\xa1 \xc7\xab \xc3\xb3'
        m = MultipartEncoder([
            ('field', s.decode('utf-8'))
        ])
        assert m.read() is not None

    def test_regresion_1(self):
        """Ensure issue #31 doesn't ever happen again."""
        fields = {
            "test": "t" * 100
        }

        # Track every handle we open so all of them are closed again, even
        # when an assertion fails part-way through the test.
        opened = []
        try:
            for x in range(30):
                fd = open('tests/test_multipart_encoder.py', 'rb')
                opened.append(fd)
                fields['f%d' % x] = ('test', fd)

            m = MultipartEncoder(fields=fields)
            total_size = m.len

            blocksize = 8192
            read_so_far = 0

            while True:
                data = m.read(blocksize)
                if not data:
                    break
                read_so_far += len(data)

            assert read_so_far == total_size
        finally:
            for fd in opened:
                fd.close()

    def test_regression_2(self):
        """Ensure issue #31 doesn't ever happen again."""
        fields = {
            "test": "t" * 8100
        }

        m = MultipartEncoder(fields=fields)
        total_size = m.len

        blocksize = 8192
        read_so_far = 0

        while True:
            data = m.read(blocksize)
            if not data:
                break
            read_so_far += len(data)

        assert read_so_far == total_size

    def test_handles_empty_unicode_values(self):
        """Verify that the Encoder can handle empty unicode strings.

        See https://github.com/requests/toolbelt/issues/46 for
        more context.
        """
        fields = [(b'test'.decode('utf-8'), b''.decode('utf-8'))]
        m = MultipartEncoder(fields=fields)
        assert len(m.read()) > 0

    def test_accepts_custom_content_type(self):
        """Verify that the Encoder handles custom content-types.

        See https://github.com/requests/toolbelt/issues/52
        """
        fields = [
            (b'test'.decode('utf-8'), (b'filename'.decode('utf-8'),
                                       b'filecontent',
                                       b'application/json'.decode('utf-8')))
        ]
        m = MultipartEncoder(fields=fields)
        output = m.read().decode('utf-8')
        # ``index`` raises ValueError when the header is missing entirely.
        assert output.index('Content-Type: application/json\r\n') > 0

    def test_accepts_custom_headers(self):
        """Verify that the Encoder handles custom headers.

        See https://github.com/requests/toolbelt/issues/52
        """
        fields = [
            (b'test'.decode('utf-8'), (b'filename'.decode('utf-8'),
                                       b'filecontent',
                                       b'application/json'.decode('utf-8'),
                                       {'X-My-Header': 'my-value'}))
        ]
        m = MultipartEncoder(fields=fields)
        output = m.read().decode('utf-8')
        # ``index`` raises ValueError when the header is missing entirely.
        assert output.index('X-My-Header: my-value\r\n') > 0

    def test_no_parts(self):
        """With no fields only the closing boundary line is emitted."""
        fields = []
        boundary = '--90967316f8404798963cce746a4f4ef9'
        m = MultipartEncoder(fields=fields, boundary=boundary)
        output = m.read().decode('utf-8')
        assert output == '----90967316f8404798963cce746a4f4ef9--\r\n'
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|