python-oauthlib/tests/oauth1/rfc5849/endpoints/test_base.py

407 lines
16 KiB
Python

from re import sub
from unittest.mock import MagicMock
from oauthlib.common import CaseInsensitiveDict, safe_string_equals
from oauthlib.oauth1 import Client, RequestValidator
from oauthlib.oauth1.rfc5849 import (
SIGNATURE_HMAC, SIGNATURE_PLAINTEXT, SIGNATURE_RSA, errors,
)
from oauthlib.oauth1.rfc5849.endpoints import (
BaseEndpoint, RequestTokenEndpoint,
)
from tests.unittest import TestCase
URLENCODED = {"Content-Type": "application/x-www-form-urlencoded"}
class BaseEndpointTest(TestCase):
def setUp(self):
self.validator = MagicMock(spec=RequestValidator)
self.validator.allowed_signature_methods = ['HMAC-SHA1']
self.validator.timestamp_lifetime = 600
self.endpoint = RequestTokenEndpoint(self.validator)
self.client = Client('foo', callback_uri='https://c.b/cb')
self.uri, self.headers, self.body = self.client.sign(
'https://i.b/request_token')
def test_ssl_enforcement(self):
uri, headers, _ = self.client.sign('http://i.b/request_token')
h, b, s = self.endpoint.create_request_token_response(
uri, headers=headers)
self.assertEqual(s, 400)
self.assertIn('insecure_transport_protocol', b)
def test_missing_parameters(self):
h, b, s = self.endpoint.create_request_token_response(self.uri)
self.assertEqual(s, 400)
self.assertIn('invalid_request', b)
def test_signature_methods(self):
headers = {}
headers['Authorization'] = self.headers['Authorization'].replace(
'HMAC', 'RSA')
h, b, s = self.endpoint.create_request_token_response(
self.uri, headers=headers)
self.assertEqual(s, 400)
self.assertIn('invalid_signature_method', b)
def test_invalid_version(self):
headers = {}
headers['Authorization'] = self.headers['Authorization'].replace(
'1.0', '2.0')
h, b, s = self.endpoint.create_request_token_response(
self.uri, headers=headers)
self.assertEqual(s, 400)
self.assertIn('invalid_request', b)
def test_expired_timestamp(self):
headers = {}
for pattern in ('12345678901', '4567890123', '123456789K'):
headers['Authorization'] = sub(r'timestamp="\d*k?"',
'timestamp="%s"' % pattern,
self.headers['Authorization'])
h, b, s = self.endpoint.create_request_token_response(
self.uri, headers=headers)
self.assertEqual(s, 400)
self.assertIn('invalid_request', b)
def test_client_key_check(self):
self.validator.check_client_key.return_value = False
h, b, s = self.endpoint.create_request_token_response(
self.uri, headers=self.headers)
self.assertEqual(s, 400)
self.assertIn('invalid_request', b)
def test_noncecheck(self):
self.validator.check_nonce.return_value = False
h, b, s = self.endpoint.create_request_token_response(
self.uri, headers=self.headers)
self.assertEqual(s, 400)
self.assertIn('invalid_request', b)
def test_enforce_ssl(self):
"""Ensure SSL is enforced by default."""
v = RequestValidator()
e = BaseEndpoint(v)
c = Client('foo')
u, h, b = c.sign('http://example.com')
r = e._create_request(u, 'GET', b, h)
self.assertRaises(errors.InsecureTransportError,
e._check_transport_security, r)
def test_multiple_source_params(self):
"""Check for duplicate params"""
v = RequestValidator()
e = BaseEndpoint(v)
self.assertRaises(errors.InvalidRequestError, e._create_request,
'https://a.b/?oauth_signature_method=HMAC-SHA1',
'GET', 'oauth_version=foo', URLENCODED)
headers = {'Authorization': 'OAuth oauth_signature="foo"'}
headers.update(URLENCODED)
self.assertRaises(errors.InvalidRequestError, e._create_request,
'https://a.b/?oauth_signature_method=HMAC-SHA1',
'GET',
'oauth_version=foo',
headers)
headers = {'Authorization': 'OAuth oauth_signature_method="foo"'}
headers.update(URLENCODED)
self.assertRaises(errors.InvalidRequestError, e._create_request,
'https://a.b/',
'GET',
'oauth_signature=foo',
headers)
def test_duplicate_params(self):
"""Ensure params are only supplied once"""
v = RequestValidator()
e = BaseEndpoint(v)
self.assertRaises(errors.InvalidRequestError, e._create_request,
'https://a.b/?oauth_version=a&oauth_version=b',
'GET', None, URLENCODED)
self.assertRaises(errors.InvalidRequestError, e._create_request,
'https://a.b/', 'GET', 'oauth_version=a&oauth_version=b',
URLENCODED)
def test_mandated_params(self):
"""Ensure all mandatory params are present."""
v = RequestValidator()
e = BaseEndpoint(v)
r = e._create_request('https://a.b/', 'GET',
'oauth_signature=a&oauth_consumer_key=b&oauth_nonce',
URLENCODED)
self.assertRaises(errors.InvalidRequestError,
e._check_mandatory_parameters, r)
def test_oauth_version(self):
"""OAuth version must be 1.0 if present."""
v = RequestValidator()
e = BaseEndpoint(v)
r = e._create_request('https://a.b/', 'GET',
('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&'
'oauth_version=2.0'),
URLENCODED)
self.assertRaises(errors.InvalidRequestError,
e._check_mandatory_parameters, r)
def test_oauth_timestamp(self):
"""Check for a valid UNIX timestamp."""
v = RequestValidator()
e = BaseEndpoint(v)
# Invalid timestamp length, must be 10
r = e._create_request('https://a.b/', 'GET',
('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
'oauth_timestamp=123456789'),
URLENCODED)
self.assertRaises(errors.InvalidRequestError,
e._check_mandatory_parameters, r)
# Invalid timestamp age, must be younger than 10 minutes
r = e._create_request('https://a.b/', 'GET',
('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
'oauth_timestamp=1234567890'),
URLENCODED)
self.assertRaises(errors.InvalidRequestError,
e._check_mandatory_parameters, r)
# Timestamp must be an integer
r = e._create_request('https://a.b/', 'GET',
('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
'oauth_timestamp=123456789a'),
URLENCODED)
self.assertRaises(errors.InvalidRequestError,
e._check_mandatory_parameters, r)
def test_case_insensitive_headers(self):
"""Ensure headers are case-insensitive"""
v = RequestValidator()
e = BaseEndpoint(v)
r = e._create_request('https://a.b', 'POST',
('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
'oauth_timestamp=123456789a'),
URLENCODED)
self.assertIsInstance(r.headers, CaseInsensitiveDict)
def test_signature_method_validation(self):
"""Ensure valid signature method is used."""
body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
'oauth_version=1.0&oauth_signature_method=%s&'
'oauth_timestamp=1234567890')
uri = 'https://example.com/'
class HMACValidator(RequestValidator):
@property
def allowed_signature_methods(self):
return (SIGNATURE_HMAC,)
v = HMACValidator()
e = BaseEndpoint(v)
r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
class RSAValidator(RequestValidator):
@property
def allowed_signature_methods(self):
return (SIGNATURE_RSA,)
v = RSAValidator()
e = BaseEndpoint(v)
r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
class PlainValidator(RequestValidator):
@property
def allowed_signature_methods(self):
return (SIGNATURE_PLAINTEXT,)
v = PlainValidator()
e = BaseEndpoint(v)
r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
self.assertRaises(errors.InvalidSignatureMethodError,
e._check_mandatory_parameters, r)
class ClientValidator(RequestValidator):
clients = ['foo']
nonces = [('foo', 'once', '1234567891', 'fez')]
owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']}
assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'}
verifiers = {('foo', 'fez'): 'shibboleth'}
@property
def client_key_length(self):
return 1, 30
@property
def request_token_length(self):
return 1, 30
@property
def access_token_length(self):
return 1, 30
@property
def nonce_length(self):
return 2, 30
@property
def verifier_length(self):
return 2, 30
@property
def realms(self):
return ['photos']
@property
def timestamp_lifetime(self):
# Disabled check to allow hardcoded verification signatures
return 1000000000
@property
def dummy_client(self):
return 'dummy'
@property
def dummy_request_token(self):
return 'dumbo'
@property
def dummy_access_token(self):
return 'dumbo'
def validate_timestamp_and_nonce(self, client_key, timestamp, nonce,
request, request_token=None, access_token=None):
resource_owner_key = request_token if request_token else access_token
return not (client_key, nonce, timestamp, resource_owner_key) in self.nonces
def validate_client_key(self, client_key):
return client_key in self.clients
def validate_access_token(self, client_key, access_token, request):
return (self.owners.get(client_key) and
access_token in self.owners.get(client_key))
def validate_request_token(self, client_key, request_token, request):
return (self.owners.get(client_key) and
request_token in self.owners.get(client_key))
def validate_requested_realm(self, client_key, realm, request):
return True
def validate_realm(self, client_key, access_token, request, uri=None,
required_realm=None):
return (client_key, access_token) in self.assigned_realms
def validate_verifier(self, client_key, request_token, verifier,
request):
return ((client_key, request_token) in self.verifiers and
safe_string_equals(verifier, self.verifiers.get(
(client_key, request_token))))
def validate_redirect_uri(self, client_key, redirect_uri, request):
return redirect_uri.startswith('http://client.example.com/')
def get_client_secret(self, client_key, request):
return 'super secret'
def get_access_token_secret(self, client_key, access_token, request):
return 'even more secret'
def get_request_token_secret(self, client_key, request_token, request):
return 'even more secret'
def get_rsa_key(self, client_key, request):
return ("-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNA"
"DCBiQKBgQDVLQCATX8iK+aZuGVdkGb6uiar\nLi/jqFwL1dYj0JLIsdQc"
"KaMWtPC06K0+vI+RRZcjKc6sNB9/7kJcKN9Ekc9BUxyT\n/D09Cz47cmC"
"YsUoiW7G8NSqbE4wPiVpGkJRzFAxaCWwOSSQ+lpC9vwxnvVQfOoZ1\nnp"
"mWbCdA0iTxsMahwQIDAQAB\n-----END PUBLIC KEY-----")
class SignatureVerificationTest(TestCase):
def setUp(self):
v = ClientValidator()
self.e = BaseEndpoint(v)
self.uri = 'https://example.com/'
self.sig = ('oauth_signature=%s&'
'oauth_timestamp=1234567890&'
'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
'oauth_version=1.0&'
'oauth_signature_method=%s&'
'oauth_token=abcdefghijklmnopqrstuvxyz&'
'oauth_consumer_key=foo')
def test_signature_too_short(self):
short_sig = ('oauth_signature=fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY&'
'oauth_timestamp=1234567890&'
'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
'oauth_version=1.0&oauth_signature_method=HMAC-SHA1&'
'oauth_token=abcdefghijklmnopqrstuvxyz&'
'oauth_consumer_key=foo')
r = self.e._create_request(self.uri, 'GET', short_sig, URLENCODED)
self.assertFalse(self.e._check_signature(r))
plain = ('oauth_signature=correctlengthbutthewrongcontent1111&'
'oauth_timestamp=1234567890&'
'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
'oauth_version=1.0&oauth_signature_method=PLAINTEXT&'
'oauth_token=abcdefghijklmnopqrstuvxyz&'
'oauth_consumer_key=foo')
r = self.e._create_request(self.uri, 'GET', plain, URLENCODED)
self.assertFalse(self.e._check_signature(r))
def test_hmac_signature(self):
hmac_sig = "fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY%3D"
sig = self.sig % (hmac_sig, "HMAC-SHA1")
r = self.e._create_request(self.uri, 'GET', sig, URLENCODED)
self.assertTrue(self.e._check_signature(r))
def test_rsa_signature(self):
rsa_sig = ("fxFvCx33oKlR9wDquJ%2FPsndFzJphyBa3RFPPIKi3flqK%2BJ7yIrMVbH"
"YTM%2FLHPc7NChWz4F4%2FzRA%2BDN1k08xgYGSBoWJUOW6VvOQ6fbYhMA"
"FkOGYbuGDbje487XMzsAcv6ZjqZHCROSCk5vofgLk2SN7RZ3OrgrFzf4in"
"xetClqA%3D")
sig = self.sig % (rsa_sig, "RSA-SHA1")
r = self.e._create_request(self.uri, 'GET', sig, URLENCODED)
self.assertTrue(self.e._check_signature(r))
def test_plaintext_signature(self):
plain_sig = "super%252520secret%26even%252520more%252520secret"
sig = self.sig % (plain_sig, "PLAINTEXT")
r = self.e._create_request(self.uri, 'GET', sig, URLENCODED)
self.assertTrue(self.e._check_signature(r))