"""Tests for oauthlib OAuth 1 (RFC 5849) endpoint request validation and signature verification."""
from re import sub
|
|
from unittest.mock import MagicMock
|
|
|
|
from oauthlib.common import CaseInsensitiveDict, safe_string_equals
|
|
from oauthlib.oauth1 import Client, RequestValidator
|
|
from oauthlib.oauth1.rfc5849 import (
|
|
SIGNATURE_HMAC, SIGNATURE_PLAINTEXT, SIGNATURE_RSA, errors,
|
|
)
|
|
from oauthlib.oauth1.rfc5849.endpoints import (
|
|
BaseEndpoint, RequestTokenEndpoint,
|
|
)
|
|
|
|
from tests.unittest import TestCase
|
|
|
|
# Request headers declaring a form-encoded body; passed to endpoint helpers
# whenever a test supplies OAuth parameters in the request body.
_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"
URLENCODED = {"Content-Type": _FORM_CONTENT_TYPE}
|
|
|
|
|
|
class BaseEndpointTest(TestCase):
    """Request validation performed by the OAuth 1 endpoints.

    The first group of tests drives a ``RequestTokenEndpoint`` backed by a
    ``MagicMock`` validator and inspects the ``(headers, body, status)``
    response; the remaining tests call ``BaseEndpoint`` helpers directly and
    expect them to raise.
    """

    def setUp(self):
        self.validator = MagicMock(spec=RequestValidator)
        self.validator.allowed_signature_methods = ['HMAC-SHA1']
        self.validator.timestamp_lifetime = 600
        self.endpoint = RequestTokenEndpoint(self.validator)
        self.client = Client('foo', callback_uri='https://c.b/cb')
        # A correctly signed HTTPS request that individual tests tamper with.
        self.uri, self.headers, self.body = self.client.sign(
            'https://i.b/request_token')

    def _assert_rejected(self, error, uri, **kwargs):
        """Assert the request-token endpoint answers 400 mentioning *error*.

        ``uri`` and any keyword arguments are forwarded unchanged to
        ``create_request_token_response``.
        """
        _, body, status = self.endpoint.create_request_token_response(
            uri, **kwargs)
        self.assertEqual(status, 400)
        self.assertIn(error, body)

    @staticmethod
    def _base_endpoint():
        """Return a ``BaseEndpoint`` backed by a default ``RequestValidator``."""
        return BaseEndpoint(RequestValidator())

    @staticmethod
    def _endpoint_allowing(*methods):
        """Return a ``BaseEndpoint`` whose validator allows only *methods*."""

        class _RestrictedValidator(RequestValidator):

            @property
            def allowed_signature_methods(self):
                return methods

        return BaseEndpoint(_RestrictedValidator())

    def test_ssl_enforcement(self):
        uri, headers, _ = self.client.sign('http://i.b/request_token')
        self._assert_rejected('insecure_transport_protocol', uri,
                              headers=headers)

    def test_missing_parameters(self):
        # Called without headers or body, so no OAuth parameters are present.
        self._assert_rejected('invalid_request', self.uri)

    def test_signature_methods(self):
        # HMAC-SHA1 becomes RSA-SHA1, which the validator does not allow.
        headers = {'Authorization': self.headers['Authorization'].replace(
            'HMAC', 'RSA')}
        self._assert_rejected('invalid_signature_method', self.uri,
                              headers=headers)

    def test_invalid_version(self):
        headers = {'Authorization': self.headers['Authorization'].replace(
            '1.0', '2.0')}
        self._assert_rejected('invalid_request', self.uri, headers=headers)

    def test_expired_timestamp(self):
        # An 11-digit value, a 10-digit value far in the future and a
        # non-numeric value must all be rejected.
        for pattern in ('12345678901', '4567890123', '123456789K'):
            with self.subTest(timestamp=pattern):
                # The source is always the untouched signed header, so the
                # regex only needs to match its original numeric timestamp.
                headers = {'Authorization': sub(
                    r'timestamp="\d*"', 'timestamp="%s"' % pattern,
                    self.headers['Authorization'])}
                self._assert_rejected('invalid_request', self.uri,
                                      headers=headers)

    def test_client_key_check(self):
        self.validator.check_client_key.return_value = False
        self._assert_rejected('invalid_request', self.uri,
                              headers=self.headers)

    def test_noncecheck(self):
        self.validator.check_nonce.return_value = False
        self._assert_rejected('invalid_request', self.uri,
                              headers=self.headers)

    def test_enforce_ssl(self):
        """Ensure SSL is enforced by default."""
        e = self._base_endpoint()
        u, h, b = Client('foo').sign('http://example.com')
        r = e._create_request(u, 'GET', b, h)
        self.assertRaises(errors.InsecureTransportError,
                          e._check_transport_security, r)

    def test_multiple_source_params(self):
        """Ensure OAuth params are not accepted from more than one source."""
        e = self._base_endpoint()

        # Query string + body.
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/?oauth_signature_method=HMAC-SHA1',
                          'GET', 'oauth_version=foo', URLENCODED)

        # Query string + body + Authorization header.
        headers = {'Authorization': 'OAuth oauth_signature="foo"'}
        headers.update(URLENCODED)
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/?oauth_signature_method=HMAC-SHA1',
                          'GET', 'oauth_version=foo', headers)

        # Body + Authorization header.
        headers = {'Authorization': 'OAuth oauth_signature_method="foo"'}
        headers.update(URLENCODED)
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/', 'GET', 'oauth_signature=foo',
                          headers)

    def test_duplicate_params(self):
        """Ensure params are only supplied once"""
        e = self._base_endpoint()
        # Repeated in the query string.
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/?oauth_version=a&oauth_version=b',
                          'GET', None, URLENCODED)
        # Repeated in the body.
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/', 'GET',
                          'oauth_version=a&oauth_version=b', URLENCODED)

    def test_mandated_params(self):
        """Ensure all mandatory params are present."""
        e = self._base_endpoint()
        # oauth_nonce has no value and other mandatory params are absent.
        r = e._create_request(
            'https://a.b/', 'GET',
            'oauth_signature=a&oauth_consumer_key=b&oauth_nonce',
            URLENCODED)
        self.assertRaises(errors.InvalidRequestError,
                          e._check_mandatory_parameters, r)

    def test_oauth_version(self):
        """OAuth version must be 1.0 if present."""
        e = self._base_endpoint()
        r = e._create_request(
            'https://a.b/', 'GET',
            ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
             'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&'
             'oauth_version=2.0'),
            URLENCODED)
        self.assertRaises(errors.InvalidRequestError,
                          e._check_mandatory_parameters, r)

    def test_oauth_timestamp(self):
        """Check for a valid UNIX timestamp."""
        e = self._base_endpoint()
        body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
                'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
                'oauth_timestamp=%s')
        cases = (
            # Invalid timestamp length, must be 10.
            '123456789',
            # Invalid timestamp age, must be younger than 10 minutes.
            '1234567890',
            # Timestamp must be an integer.
            '123456789a',
        )
        for timestamp in cases:
            with self.subTest(timestamp=timestamp):
                r = e._create_request('https://a.b/', 'GET',
                                      body % timestamp, URLENCODED)
                self.assertRaises(errors.InvalidRequestError,
                                  e._check_mandatory_parameters, r)

    def test_case_insensitive_headers(self):
        """Ensure headers are case-insensitive"""
        e = self._base_endpoint()
        r = e._create_request(
            'https://a.b', 'POST',
            ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
             'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
             'oauth_timestamp=123456789a'),
            URLENCODED)
        self.assertIsInstance(r.headers, CaseInsensitiveDict)

    def test_signature_method_validation(self):
        """Ensure valid signature method is used."""
        body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
                'oauth_version=1.0&oauth_signature_method=%s&'
                'oauth_timestamp=1234567890')
        uri = 'https://example.com/'

        # (the single allowed method, methods that must then be rejected)
        cases = (
            (SIGNATURE_HMAC, ('RSA-SHA1', 'PLAINTEXT', 'shibboleth')),
            (SIGNATURE_RSA, ('HMAC-SHA1', 'PLAINTEXT', 'shibboleth')),
            (SIGNATURE_PLAINTEXT, ('HMAC-SHA1', 'RSA-SHA1', 'shibboleth')),
        )
        for allowed, rejected in cases:
            e = self._endpoint_allowing(allowed)
            for method in rejected:
                with self.subTest(allowed=allowed, method=method):
                    r = e._create_request(uri, 'GET', body % method,
                                          URLENCODED)
                    self.assertRaises(errors.InvalidSignatureMethodError,
                                      e._check_mandatory_parameters, r)
|
|
|
|
|
|
class ClientValidator(RequestValidator):
    """In-memory ``RequestValidator`` with a fixed client, tokens and secrets.

    Every lookup is answered from the class-level tables below, so hard-coded
    signatures computed against these values keep verifying.
    """

    # Known client keys.
    clients = ['foo']
    # Already-seen (client_key, nonce, timestamp, resource_owner_key) tuples;
    # validate_timestamp_and_nonce rejects any request matching one of them.
    nonces = [('foo', 'once', '1234567891', 'fez')]
    # client_key -> tokens (request or access) owned by that client.
    owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']}
    # (client_key, access_token) pairs that have a realm assigned.
    assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'}
    # (client_key, request_token) -> expected verifier.
    verifiers = {('foo', 'fez'): 'shibboleth'}

    @property
    def client_key_length(self):
        return 1, 30

    @property
    def request_token_length(self):
        return 1, 30

    @property
    def access_token_length(self):
        return 1, 30

    @property
    def nonce_length(self):
        return 2, 30

    @property
    def verifier_length(self):
        return 2, 30

    @property
    def realms(self):
        return ['photos']

    @property
    def timestamp_lifetime(self):
        # A huge lifetime effectively disables the timestamp-age check so the
        # hard-coded verification signatures keep passing.
        return 1000000000

    @property
    def dummy_client(self):
        return 'dummy'

    @property
    def dummy_request_token(self):
        return 'dumbo'

    @property
    def dummy_access_token(self):
        return 'dumbo'

    def _owns_token(self, client_key, token):
        """Return True if *token* is listed under *client_key* in ``owners``."""
        return token in self.owners.get(client_key, ())

    def validate_timestamp_and_nonce(self, client_key, timestamp, nonce,
                                     request, request_token=None,
                                     access_token=None):
        """Return False if this exact nonce/timestamp was already recorded."""
        resource_owner_key = request_token if request_token else access_token
        seen = (client_key, nonce, timestamp, resource_owner_key)
        return seen not in self.nonces

    def validate_client_key(self, client_key):
        return client_key in self.clients

    def validate_access_token(self, client_key, access_token, request):
        return self._owns_token(client_key, access_token)

    def validate_request_token(self, client_key, request_token, request):
        return self._owns_token(client_key, request_token)

    def validate_requested_realm(self, client_key, realm, request):
        return True

    def validate_realm(self, client_key, access_token, request, uri=None,
                       required_realm=None):
        # Only checks that some realm is assigned; uri and required_realm
        # are ignored.
        return (client_key, access_token) in self.assigned_realms

    def validate_verifier(self, client_key, request_token, verifier,
                          request):
        expected = self.verifiers.get((client_key, request_token))
        # Constant-time comparison via oauthlib's safe_string_equals.
        return (expected is not None and
                safe_string_equals(verifier, expected))

    def validate_redirect_uri(self, client_key, redirect_uri, request):
        return redirect_uri.startswith('http://client.example.com/')

    def get_client_secret(self, client_key, request):
        return 'super secret'

    def get_access_token_secret(self, client_key, access_token, request):
        return 'even more secret'

    def get_request_token_secret(self, client_key, request_token, request):
        return 'even more secret'

    def get_rsa_key(self, client_key, request):
        # PEM-encoded public key used to verify RSA-SHA1 signatures.
        return ("-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNA"
                "DCBiQKBgQDVLQCATX8iK+aZuGVdkGb6uiar\nLi/jqFwL1dYj0JLIsdQc"
                "KaMWtPC06K0+vI+RRZcjKc6sNB9/7kJcKN9Ekc9BUxyT\n/D09Cz47cmC"
                "YsUoiW7G8NSqbE4wPiVpGkJRzFAxaCWwOSSQ+lpC9vwxnvVQfOoZ1\nnp"
                "mWbCdA0iTxsMahwQIDAQAB\n-----END PUBLIC KEY-----")
|
|
|
|
|
|
class SignatureVerificationTest(TestCase):
    """Check ``BaseEndpoint._check_signature`` against fixed signatures.

    Each request is built from one form-encoded template for client ``foo``
    and token ``abcdefghijklmnopqrstuvxyz``; the expected signatures are
    checked against the secrets and RSA key that ``ClientValidator`` returns.
    """

    def setUp(self):
        self.e = BaseEndpoint(ClientValidator())
        self.uri = 'https://example.com/'
        # Body template; fill in (oauth_signature, oauth_signature_method).
        self.sig = ('oauth_signature=%s&'
                    'oauth_timestamp=1234567890&'
                    'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
                    'oauth_version=1.0&'
                    'oauth_signature_method=%s&'
                    'oauth_token=abcdefghijklmnopqrstuvxyz&'
                    'oauth_consumer_key=foo')

    def _verify(self, signature, method):
        """Build a GET request carrying *signature*; return the check result."""
        request = self.e._create_request(
            self.uri, 'GET', self.sig % (signature, method), URLENCODED)
        return self.e._check_signature(request)

    def test_signature_too_short(self):
        # The valid HMAC-SHA1 signature with its trailing '%3D' removed.
        self.assertFalse(
            self._verify('fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY', 'HMAC-SHA1'))
        # A PLAINTEXT signature whose content does not match the secrets.
        self.assertFalse(
            self._verify('correctlengthbutthewrongcontent1111', 'PLAINTEXT'))

    def test_hmac_signature(self):
        self.assertTrue(
            self._verify('fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY%3D', 'HMAC-SHA1'))

    def test_rsa_signature(self):
        expected = (
            "fxFvCx33oKlR9wDquJ%2FPsndFzJphyBa3RFPPIKi3flqK%2BJ7yIrMVbH"
            "YTM%2FLHPc7NChWz4F4%2FzRA%2BDN1k08xgYGSBoWJUOW6VvOQ6fbYhMA"
            "FkOGYbuGDbje487XMzsAcv6ZjqZHCROSCk5vofgLk2SN7RZ3OrgrFzf4in"
            "xetClqA%3D")
        self.assertTrue(self._verify(expected, "RSA-SHA1"))

    def test_plaintext_signature(self):
        expected = "super%252520secret%26even%252520more%252520secret"
        self.assertTrue(self._verify(expected, "PLAINTEXT"))
|