379 lines
16 KiB
Python
379 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
|
import datetime
|
|
import json
|
|
from unittest.mock import patch
|
|
|
|
from oauthlib import common
|
|
from oauthlib.oauth2 import Client, InsecureTransportError, TokenExpiredError
|
|
from oauthlib.oauth2.rfc6749 import utils
|
|
from oauthlib.oauth2.rfc6749.clients import AUTH_HEADER, BODY, URI_QUERY
|
|
|
|
from tests.unittest import TestCase
|
|
|
|
|
|
class ClientTest(TestCase):
    """Tests for the generic OAuth 2 ``Client`` base class."""

    client_id = "someclientid"
    uri = "https://example.com/path?query=world"
    # Same resource as ``uri`` but over plain HTTP; add_token must reject it.
    insecure_uri = "http://example.com/path?query=world"
    body = "not=empty"
    headers = {}
    access_token = "token"
    mac_key = "secret"

    # An ``expires_at`` timestamp long in the past, used to build expired tokens.
    expired_at = 523549800

    bearer_query = uri + "&access_token=" + access_token
    bearer_header = {
        "Authorization": "Bearer " + access_token
    }
    bearer_body = body + "&access_token=" + access_token

    mac_00_header = {
        "Authorization": 'MAC id="' + access_token + '", nonce="0:abc123",' +
                         ' bodyhash="Yqyso8r3hR5Nm1ZFv+6AvNHrxjE=",' +
                         ' mac="0X6aACoBY0G6xgGZVJ1IeE8dF9k="'
    }
    mac_01_header = {
        "Authorization": 'MAC id="' + access_token + '", ts="123456789",' +
                         ' nonce="abc123", mac="Xuk+9oqaaKyhitkgh1CD0xrI6+s="'
    }

    # -- helpers -----------------------------------------------------------

    def _expired_token(self):
        """Return a fresh token dict whose ``expires_at`` lies in the past."""
        return {'expires_at': self.expired_at}

    def _mac_client(self, **kwargs):
        """Return a MAC ``Client`` using the fixture access token, key and hmac-sha-1.

        Extra keyword arguments are forwarded to the ``Client`` constructor.
        """
        return Client(self.client_id, token_type="MAC",
                      access_token=self.access_token, mac_key=self.mac_key,
                      mac_algorithm="hmac-sha-1", **kwargs)

    def _assert_bearer_placement(self, client, expected_uri, expected_body,
                                 expected_headers, **kwargs):
        """Call ``client.add_token`` on the fixture request and check uri, body and headers.

        Extra keyword arguments (e.g. ``token_placement``) are forwarded to
        ``add_token``.
        """
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, **kwargs)
        self.assertURLEqual(uri, expected_uri)
        self.assertFormBodyEqual(body, expected_body)
        self.assertEqual(headers, expected_headers)

    # -- tests -------------------------------------------------------------

    def test_add_bearer_token(self):
        """Test a number of bearer token placements"""
        # Invalid token type
        client = Client(self.client_id, token_type="invalid")
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Case-insensitive token type
        client = Client(self.client_id, access_token=self.access_token, token_type="bEAreR")
        self._assert_bearer_placement(client, self.uri, self.body, self.bearer_header)

        # Non-HTTPS
        client = Client(self.client_id, access_token=self.access_token, token_type="Bearer")
        self.assertRaises(InsecureTransportError, client.add_token, self.insecure_uri,
                          body=self.body, headers=self.headers)

        # Missing access token
        client = Client(self.client_id)
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Expired token
        client = Client(self.client_id, token=self._expired_token(),
                        access_token=self.access_token, token_type="Bearer")
        self.assertRaises(TokenExpiredError, client.add_token, self.uri,
                          body=self.body, headers=self.headers)

        # The default token placement, bearer in auth header
        client = Client(self.client_id, access_token=self.access_token)
        self._assert_bearer_placement(client, self.uri, self.body, self.bearer_header)

        # Each placement must give the same result whether it is configured as
        # the client's default or requested explicitly in add_token.
        placements = (
            (AUTH_HEADER, self.uri, self.body, self.bearer_header),
            (URI_QUERY, self.bearer_query, self.body, self.headers),
            (BODY, self.uri, self.bearer_body, self.headers),
        )
        for placement, expected_uri, expected_body, expected_headers in placements:
            with self.subTest(default_token_placement=placement):
                client = Client(self.client_id, access_token=self.access_token,
                                default_token_placement=placement)
                self._assert_bearer_placement(client, expected_uri, expected_body,
                                              expected_headers)
            with self.subTest(token_placement=placement):
                client = Client(self.client_id, access_token=self.access_token)
                self._assert_bearer_placement(client, expected_uri, expected_body,
                                              expected_headers,
                                              token_placement=placement)

        # Invalid token placement
        client = Client(self.client_id, access_token=self.access_token)
        self.assertRaises(ValueError, client.add_token, self.uri, body=self.body,
                          headers=self.headers, token_placement="invalid")

        client = Client(self.client_id, access_token=self.access_token,
                        default_token_placement="invalid")
        self.assertRaises(ValueError, client.add_token, self.uri, body=self.body,
                          headers=self.headers)

    def test_add_mac_token(self):
        """Test MAC token headers for drafts 00 and 01, plus the error cases."""
        # Missing access token
        client = Client(self.client_id, token_type="MAC")
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Invalid hash algorithm
        client = Client(self.client_id, token_type="MAC",
                        access_token=self.access_token, mac_key=self.mac_key,
                        mac_algorithm="hmac-sha-2")
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Pin the timestamp, nonce and age generators so the computed MACs
        # match the fixture headers. patch.object restores the exact attribute
        # it replaced on cleanup, so the fakes cannot leak into other tests.
        fakes = (
            (common, 'generate_timestamp', lambda: '123456789'),
            (common, 'generate_nonce', lambda: 'abc123'),
            (utils, 'generate_age', lambda *args: 0),
        )
        for target, attribute, fake in fakes:
            patcher = patch.object(target, attribute, fake)
            patcher.start()
            self.addCleanup(patcher.stop)

        # Add the Authorization header (draft 00)
        client = self._mac_client()
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers,
                                              issue_time=datetime.datetime.now())
        self.assertEqual(uri, self.uri)
        self.assertEqual(body, self.body)
        self.assertEqual(headers, self.mac_00_header)

        # Non-HTTPS
        self.assertRaises(InsecureTransportError, client.add_token, self.insecure_uri,
                          body=self.body,
                          headers=self.headers,
                          issue_time=datetime.datetime.now())

        # Expired Token
        client = self._mac_client(token=self._expired_token())
        self.assertRaises(TokenExpiredError, client.add_token, self.uri,
                          body=self.body,
                          headers=self.headers,
                          issue_time=datetime.datetime.now())

        # Add the Authorization header (draft 01)
        client = self._mac_client()
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, draft=1)
        self.assertEqual(uri, self.uri)
        self.assertEqual(body, self.body)
        self.assertEqual(headers, self.mac_01_header)

        # Non-HTTPS
        self.assertRaises(InsecureTransportError, client.add_token, self.insecure_uri,
                          body=self.body,
                          headers=self.headers,
                          draft=1)

        # Expired Token
        client = self._mac_client(token=self._expired_token())
        self.assertRaises(TokenExpiredError, client.add_token, self.uri,
                          body=self.body,
                          headers=self.headers,
                          draft=1)

    def test_revocation_request(self):
        """Test prepare_token_revocation_request: hints, transport check and JSONP."""
        client = Client(self.client_id)

        url = 'https://example.com/revoke'
        token = 'foobar'
        form_headers = {'Content-Type': 'application/x-www-form-urlencoded'}

        # Valid request
        u, h, b = client.prepare_token_revocation_request(url, token)
        self.assertEqual(u, url)
        self.assertEqual(h, form_headers)
        self.assertEqual(b, 'token=%s&token_type_hint=access_token' % token)

        # Non-HTTPS revocation endpoint
        self.assertRaises(InsecureTransportError,
                          client.prepare_token_revocation_request,
                          'http://example.com/revoke', token)

        # Explicit refresh_token type hint
        u, h, b = client.prepare_token_revocation_request(
            url, token, token_type_hint='refresh_token')
        self.assertEqual(u, url)
        self.assertEqual(h, form_headers)
        self.assertEqual(b, 'token=%s&token_type_hint=refresh_token' % token)

        # JSONP: parameters move into the query string and the body is empty
        u, h, b = client.prepare_token_revocation_request(
            url, token, callback='hello.world')
        self.assertURLEqual(u, url + '?callback=hello.world&token=%s&token_type_hint=access_token' % token)
        self.assertEqual(h, form_headers)
        self.assertEqual(b, '')

    def test_prepare_authorization_request(self):
        """The base Client rejects HTTP and leaves authorization requests unimplemented."""
        redirect_url = 'https://example.com/callback/'
        scopes = 'read'
        auth_url = 'https://example.com/authorize/'
        state = 'fake_state'

        client = Client(self.client_id, redirect_url=redirect_url, scope=scopes, state=state)

        # Non-HTTPS
        self.assertRaises(InsecureTransportError,
                          client.prepare_authorization_request, 'http://example.com/authorize/')

        # NotImplementedError
        self.assertRaises(NotImplementedError, client.prepare_authorization_request, auth_url)

    def test_prepare_token_request(self):
        """The base Client rejects HTTP and leaves token requests unimplemented."""
        scopes = 'read'
        token_url = 'https://example.com/token/'
        state = 'fake_state'

        client = Client(self.client_id, scope=scopes, state=state)

        # Non-HTTPS
        self.assertRaises(InsecureTransportError,
                          client.prepare_token_request, 'http://example.com/token/')

        # NotImplementedError
        self.assertRaises(NotImplementedError, client.prepare_token_request, token_url)

    def test_prepare_refresh_token_request(self):
        """Test refresh requests with no scope, an explicit scope and an init-time scope."""
        client = Client(self.client_id)

        url = 'https://example.com/token'
        token = 'foobar'
        scope = 'extra_scope'
        form_headers = {'Content-Type': 'application/x-www-form-urlencoded'}

        u, h, b = client.prepare_refresh_token_request(url, token)
        self.assertEqual(u, url)
        self.assertEqual(h, form_headers)
        self.assertFormBodyEqual(b, 'grant_type=refresh_token&refresh_token=%s' % token)

        # Non-HTTPS token endpoint
        self.assertRaises(InsecureTransportError,
                          client.prepare_refresh_token_request,
                          'http://example.com/token', token)

        # provide extra scope
        u, h, b = client.prepare_refresh_token_request(url, token, scope=scope)
        self.assertEqual(u, url)
        self.assertEqual(h, form_headers)
        self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token))

        # provide scope while init: no scope is passed to the call, so the
        # body's scope must come from the Client constructor.
        client = Client(self.client_id, scope=scope)
        u, h, b = client.prepare_refresh_token_request(url, token)
        self.assertEqual(u, url)
        self.assertEqual(h, form_headers)
        self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token))

    def test_create_code_verifier_min_length(self):
        """A 43-character verifier (the minimum) is generated and stored."""
        client = Client(self.client_id)
        length = 43
        code_verifier = client.create_code_verifier(length=length)
        self.assertEqual(client.code_verifier, code_verifier)
        self.assertEqual(len(code_verifier), length)

    def test_create_code_verifier_max_length(self):
        """A 128-character verifier (the maximum) is generated and stored."""
        client = Client(self.client_id)
        length = 128
        code_verifier = client.create_code_verifier(length=length)
        self.assertEqual(client.code_verifier, code_verifier)
        self.assertEqual(len(code_verifier), length)

    def test_create_code_verifier_length(self):
        """The generated verifier has exactly the requested length."""
        client = Client(self.client_id)
        length = 96
        code_verifier = client.create_code_verifier(length=length)
        self.assertEqual(len(code_verifier), length)

    def test_create_code_challenge_plain(self):
        """Without a method, the challenge is the verifier and the method is 'plain'."""
        client = Client(self.client_id)
        code_verifier = client.create_code_verifier(length=128)
        code_challenge_plain = client.create_code_challenge(code_verifier=code_verifier)

        # if no code_challenge_method specified, code_challenge = code_verifier
        self.assertEqual(code_challenge_plain, client.code_verifier)
        self.assertEqual(client.code_challenge_method, "plain")

    def test_create_code_challenge_s256(self):
        """S256 challenge is BASE64URL(SHA256(verifier)) without padding (RFC 7636 4.2)."""
        import base64
        import hashlib

        client = Client(self.client_id)
        code_verifier = client.create_code_verifier(length=128)
        code_challenge_s256 = client.create_code_challenge(code_verifier=code_verifier, code_challenge_method='S256')
        self.assertEqual(code_challenge_s256, client.code_challenge)

        digest = hashlib.sha256(code_verifier.encode('ascii')).digest()
        expected = base64.urlsafe_b64encode(digest).decode('ascii').rstrip('=')
        self.assertEqual(code_challenge_s256, expected)

    def test_parse_token_response_expires_at_types(self):
        """``expires_at`` is accepted as int, float or string in the token response."""
        for title, fieldjson, expected, generated in [
            ('int', 1661185148, 1661185148, 1661185148),
            ('float', 1661185148.6437678, 1661185148.6437678, 1661185148.6437678),
            ('str', "\"2006-01-02T15:04:05Z\"", "2006-01-02T15:04:05Z", None),
            ('str-as-int', "\"1661185148\"", 1661185148, 1661185148),
            ('str-as-float', "\"1661185148.42\"", 1661185148.42, 1661185148.42),
        ]:
            with self.subTest(msg=title):
                token_json = ('{{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
                              ' "token_type":"example",'
                              ' "expires_at":{expires_at},'
                              ' "scope":"/profile",'
                              ' "example_parameter":"example_value"}}'.format(expires_at=fieldjson))

                client = Client(self.client_id)
                response = client.parse_request_body_response(token_json, scope=["/profile"])

                self.assertEqual(response['expires_at'], expected, "response attribute wrong")
                self.assertEqual(client.expires_at, expected, "client attribute wrong")
                # None marks cases where no internal expiration is expected to be checked.
                if generated is not None:
                    self.assertEqual(client._expires_at, generated, "internal expiration wrong")

    @patch('time.time')
    def test_parse_token_response_generated_expires_at_is_int(self, t):
        """An ``expires_at`` derived from ``expires_in`` is rounded to an int."""
        t.return_value = 1661185148.6437678
        expected_expires_at = round(t.return_value) + 3600
        token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
                      ' "token_type":"example",'
                      ' "expires_in":3600,'
                      ' "scope":"/profile",'
                      ' "example_parameter":"example_value"}')

        client = Client(self.client_id)

        response = client.parse_request_body_response(token_json, scope=["/profile"])

        self.assertEqual(response['expires_at'], expected_expires_at)
        self.assertEqual(client._expires_at, expected_expires_at)
|