# Source: python-oauthlib/tests/oauth2/rfc6749/clients/test_base.py
# (repository viewer metadata: 379 lines, 16 KiB, Python)

# -*- coding: utf-8 -*-
import datetime
import json
from unittest.mock import patch
from oauthlib import common
from oauthlib.oauth2 import Client, InsecureTransportError, TokenExpiredError
from oauthlib.oauth2.rfc6749 import utils
from oauthlib.oauth2.rfc6749.clients import AUTH_HEADER, BODY, URI_QUERY
from tests.unittest import TestCase
class ClientTest(TestCase):
    """Exercise the base OAuth 2 ``Client``.

    Covers token placement (Bearer and MAC), request preparation
    (revocation, authorization, token, refresh), PKCE code verifier /
    challenge generation and parsing of ``expires_at`` / ``expires_in``
    from token responses.
    """

    # Shared request fixtures.
    client_id = "someclientid"
    uri = "https://example.com/path?query=world"
    body = "not=empty"
    headers = {}
    access_token = "token"
    mac_key = "secret"

    # Expected results for each bearer token placement.
    bearer_query = uri + "&access_token=" + access_token
    bearer_header = {
        "Authorization": "Bearer " + access_token
    }
    bearer_body = body + "&access_token=" + access_token

    # Expected MAC Authorization headers. They are only reproducible with
    # the fixed timestamp / nonce / age stubs installed in
    # test_add_mac_token.
    mac_00_header = {
        "Authorization": 'MAC id="' + access_token + '", nonce="0:abc123",' +
                         ' bodyhash="Yqyso8r3hR5Nm1ZFv+6AvNHrxjE=",' +
                         ' mac="0X6aACoBY0G6xgGZVJ1IeE8dF9k="'
    }
    mac_01_header = {
        "Authorization": 'MAC id="' + access_token + '", ts="123456789",' +
                         ' nonce="abc123", mac="Xuk+9oqaaKyhitkgh1CD0xrI6+s="'
    }

    def test_add_bearer_token(self):
        """Test a number of bearer token placements"""
        # Invalid token type
        client = Client(self.client_id, token_type="invalid")
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Case-insensitive token type
        client = Client(self.client_id, access_token=self.access_token, token_type="bEAreR")
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers)
        self.assertURLEqual(uri, self.uri)
        self.assertFormBodyEqual(body, self.body)
        self.assertEqual(headers, self.bearer_header)

        # Non-HTTPS
        insecure_uri = 'http://example.com/path?query=world'
        client = Client(self.client_id, access_token=self.access_token, token_type="Bearer")
        self.assertRaises(InsecureTransportError, client.add_token, insecure_uri,
                          body=self.body,
                          headers=self.headers)

        # Missing access token
        client = Client(self.client_id)
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Expired token
        expired = 523549800
        expired_token = {
            'expires_at': expired,
        }
        client = Client(self.client_id, token=expired_token, access_token=self.access_token, token_type="Bearer")
        self.assertRaises(TokenExpiredError, client.add_token, self.uri,
                          body=self.body, headers=self.headers)

        # The default token placement, bearer in auth header
        client = Client(self.client_id, access_token=self.access_token)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers)
        self.assertURLEqual(uri, self.uri)
        self.assertFormBodyEqual(body, self.body)
        self.assertEqual(headers, self.bearer_header)

        # Setting default placements of tokens
        client = Client(self.client_id, access_token=self.access_token,
                        default_token_placement=AUTH_HEADER)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers)
        self.assertURLEqual(uri, self.uri)
        self.assertFormBodyEqual(body, self.body)
        self.assertEqual(headers, self.bearer_header)

        client = Client(self.client_id, access_token=self.access_token,
                        default_token_placement=URI_QUERY)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers)
        self.assertURLEqual(uri, self.bearer_query)
        self.assertFormBodyEqual(body, self.body)
        self.assertEqual(headers, self.headers)

        client = Client(self.client_id, access_token=self.access_token,
                        default_token_placement=BODY)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers)
        self.assertURLEqual(uri, self.uri)
        self.assertFormBodyEqual(body, self.bearer_body)
        self.assertEqual(headers, self.headers)

        # Asking for specific placement in the add_token method
        client = Client(self.client_id, access_token=self.access_token)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, token_placement=AUTH_HEADER)
        self.assertURLEqual(uri, self.uri)
        self.assertFormBodyEqual(body, self.body)
        self.assertEqual(headers, self.bearer_header)

        client = Client(self.client_id, access_token=self.access_token)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, token_placement=URI_QUERY)
        self.assertURLEqual(uri, self.bearer_query)
        self.assertFormBodyEqual(body, self.body)
        self.assertEqual(headers, self.headers)

        client = Client(self.client_id, access_token=self.access_token)
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, token_placement=BODY)
        self.assertURLEqual(uri, self.uri)
        self.assertFormBodyEqual(body, self.bearer_body)
        self.assertEqual(headers, self.headers)

        # Invalid token placement
        client = Client(self.client_id, access_token=self.access_token)
        self.assertRaises(ValueError, client.add_token, self.uri, body=self.body,
                          headers=self.headers, token_placement="invalid")

        client = Client(self.client_id, access_token=self.access_token,
                        default_token_placement="invalid")
        self.assertRaises(ValueError, client.add_token, self.uri, body=self.body,
                          headers=self.headers)

    def test_add_mac_token(self):
        """Test MAC token Authorization headers for draft 00 and draft 01."""
        # Missing access token
        client = Client(self.client_id, token_type="MAC")
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Invalid hash algorithm
        client = Client(self.client_id, token_type="MAC",
                        access_token=self.access_token, mac_key=self.mac_key,
                        mac_algorithm="hmac-sha-2")
        self.assertRaises(ValueError, client.add_token, self.uri)

        # Replace the timestamp / nonce / age generators with fixed stubs so
        # the expected headers are reproducible. The cleanups must name the
        # exact attributes being replaced; a misspelled name would leave the
        # stubs installed for every test that runs afterwards.
        orig_generate_timestamp = common.generate_timestamp
        orig_generate_nonce = common.generate_nonce
        orig_generate_age = utils.generate_age
        self.addCleanup(setattr, common, 'generate_timestamp', orig_generate_timestamp)
        self.addCleanup(setattr, common, 'generate_nonce', orig_generate_nonce)
        self.addCleanup(setattr, utils, 'generate_age', orig_generate_age)
        common.generate_timestamp = lambda: '123456789'
        common.generate_nonce = lambda: 'abc123'
        utils.generate_age = lambda *args: 0

        # Add the Authorization header (draft 00)
        client = Client(self.client_id, token_type="MAC",
                        access_token=self.access_token, mac_key=self.mac_key,
                        mac_algorithm="hmac-sha-1")
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, issue_time=datetime.datetime.now())
        self.assertEqual(uri, self.uri)
        self.assertEqual(body, self.body)
        self.assertEqual(headers, self.mac_00_header)

        # Non-HTTPS
        insecure_uri = 'http://example.com/path?query=world'
        self.assertRaises(InsecureTransportError, client.add_token, insecure_uri,
                          body=self.body,
                          headers=self.headers,
                          issue_time=datetime.datetime.now())

        # Expired Token
        expired = 523549800
        expired_token = {
            'expires_at': expired,
        }
        client = Client(self.client_id, token=expired_token, token_type="MAC",
                        access_token=self.access_token, mac_key=self.mac_key,
                        mac_algorithm="hmac-sha-1")
        self.assertRaises(TokenExpiredError, client.add_token, self.uri,
                          body=self.body,
                          headers=self.headers,
                          issue_time=datetime.datetime.now())

        # Add the Authorization header (draft 01)
        client = Client(self.client_id, token_type="MAC",
                        access_token=self.access_token, mac_key=self.mac_key,
                        mac_algorithm="hmac-sha-1")
        uri, headers, body = client.add_token(self.uri, body=self.body,
                                              headers=self.headers, draft=1)
        self.assertEqual(uri, self.uri)
        self.assertEqual(body, self.body)
        self.assertEqual(headers, self.mac_01_header)

        # Non-HTTPS
        insecure_uri = 'http://example.com/path?query=world'
        self.assertRaises(InsecureTransportError, client.add_token, insecure_uri,
                          body=self.body,
                          headers=self.headers,
                          draft=1)

        # Expired Token
        expired = 523549800
        expired_token = {
            'expires_at': expired,
        }
        client = Client(self.client_id, token=expired_token, token_type="MAC",
                        access_token=self.access_token, mac_key=self.mac_key,
                        mac_algorithm="hmac-sha-1")
        self.assertRaises(TokenExpiredError, client.add_token, self.uri,
                          body=self.body,
                          headers=self.headers,
                          draft=1)

    def test_revocation_request(self):
        """Test preparation of token revocation requests."""
        client = Client(self.client_id)

        url = 'https://example.com/revoke'
        token = 'foobar'

        # Valid request
        u, h, b = client.prepare_token_revocation_request(url, token)
        self.assertEqual(u, url)
        self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
        self.assertEqual(b, 'token=%s&token_type_hint=access_token' % token)

        # Non-HTTPS revocation endpoint
        self.assertRaises(InsecureTransportError,
                          client.prepare_token_revocation_request,
                          'http://example.com/revoke', token)

        # Explicit token type hint
        u, h, b = client.prepare_token_revocation_request(
            url, token, token_type_hint='refresh_token')
        self.assertEqual(u, url)
        self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
        self.assertEqual(b, 'token=%s&token_type_hint=refresh_token' % token)

        # JSONP
        u, h, b = client.prepare_token_revocation_request(
            url, token, callback='hello.world')
        self.assertURLEqual(u, url + '?callback=hello.world&token=%s&token_type_hint=access_token' % token)
        self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
        self.assertEqual(b, '')

    def test_prepare_authorization_request(self):
        """Test transport check and NotImplementedError for authorization requests."""
        redirect_url = 'https://example.com/callback/'
        scopes = 'read'
        auth_url = 'https://example.com/authorize/'
        state = 'fake_state'

        client = Client(self.client_id, redirect_url=redirect_url, scope=scopes, state=state)

        # Non-HTTPS
        self.assertRaises(InsecureTransportError,
                          client.prepare_authorization_request, 'http://example.com/authorize/')

        # NotImplementedError
        self.assertRaises(NotImplementedError, client.prepare_authorization_request, auth_url)

    def test_prepare_token_request(self):
        """Test transport check and NotImplementedError for token requests."""
        scopes = 'read'
        token_url = 'https://example.com/token/'
        state = 'fake_state'

        client = Client(self.client_id, scope=scopes, state=state)

        # Non-HTTPS
        self.assertRaises(InsecureTransportError,
                          client.prepare_token_request, 'http://example.com/token/')

        # NotImplementedError
        self.assertRaises(NotImplementedError, client.prepare_token_request, token_url)

    def test_prepare_refresh_token_request(self):
        """Test preparation of refresh token requests, with and without scope."""
        client = Client(self.client_id)

        url = 'https://example.com/revoke'
        token = 'foobar'
        scope = 'extra_scope'

        u, h, b = client.prepare_refresh_token_request(url, token)
        self.assertEqual(u, url)
        self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
        self.assertFormBodyEqual(b, 'grant_type=refresh_token&refresh_token=%s' % token)

        # Non-HTTPS revocation endpoint
        self.assertRaises(InsecureTransportError,
                          client.prepare_refresh_token_request,
                          'http://example.com/revoke', token)

        # provide extra scope
        u, h, b = client.prepare_refresh_token_request(url, token, scope=scope)
        self.assertEqual(u, url)
        self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
        self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token))

        # provide scope while init
        client = Client(self.client_id, scope=scope)
        u, h, b = client.prepare_refresh_token_request(url, token, scope=scope)
        self.assertEqual(u, url)
        self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
        self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token))

    def test_create_code_verifier_min_length(self):
        """A 43-character code verifier is created and stored on the client."""
        client = Client(self.client_id)
        length = 43
        code_verifier = client.create_code_verifier(length=length)
        self.assertEqual(client.code_verifier, code_verifier)
        # Pin the boundary length the test is named after.
        self.assertEqual(len(code_verifier), length)

    def test_create_code_verifier_max_length(self):
        """A 128-character code verifier is created and stored on the client."""
        client = Client(self.client_id)
        length = 128
        code_verifier = client.create_code_verifier(length=length)
        self.assertEqual(client.code_verifier, code_verifier)
        # Pin the boundary length the test is named after.
        self.assertEqual(len(code_verifier), length)

    def test_create_code_verifier_length(self):
        """The returned code verifier has exactly the requested length."""
        client = Client(self.client_id)
        length = 96
        code_verifier = client.create_code_verifier(length=length)
        self.assertEqual(len(code_verifier), length)

    def test_create_code_challenge_plain(self):
        """Without a challenge method the challenge equals the verifier."""
        client = Client(self.client_id)
        code_verifier = client.create_code_verifier(length=128)
        code_challenge_plain = client.create_code_challenge(code_verifier=code_verifier)

        # if no code_challenge_method specified, code_challenge = code_verifier
        self.assertEqual(code_challenge_plain, client.code_verifier)
        self.assertEqual(client.code_challenge_method, "plain")

    def test_create_code_challenge_s256(self):
        """The S256 challenge returned is the one stored on the client."""
        client = Client(self.client_id)
        code_verifier = client.create_code_verifier(length=128)
        code_challenge_s256 = client.create_code_challenge(code_verifier=code_verifier, code_challenge_method='S256')
        self.assertEqual(code_challenge_s256, client.code_challenge)

    def test_parse_token_response_expires_at_types(self):
        """Test ``expires_at`` given as int, float, and several string forms."""
        # Each case: (subtest title, raw JSON value, expected parsed value,
        # expected client._expires_at or None to skip that check).
        for title, fieldjson, expected, generated in [
                ('int', 1661185148, 1661185148, 1661185148),
                ('float', 1661185148.6437678, 1661185148.6437678, 1661185148.6437678),
                ('str', "\"2006-01-02T15:04:05Z\"", "2006-01-02T15:04:05Z", None),
                ('str-as-int', "\"1661185148\"", 1661185148, 1661185148),
                ('str-as-float', "\"1661185148.42\"", 1661185148.42, 1661185148.42),
        ]:
            with self.subTest(msg=title):
                token_json = ('{{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
                              ' "token_type":"example",'
                              ' "expires_at":{expires_at},'
                              ' "scope":"/profile",'
                              ' "example_parameter":"example_value"}}'.format(expires_at=fieldjson))

                client = Client(self.client_id)

                response = client.parse_request_body_response(token_json, scope=["/profile"])

                self.assertEqual(response['expires_at'], expected, "response attribute wrong")
                self.assertEqual(client.expires_at, expected, "client attribute wrong")
                if generated:
                    self.assertEqual(client._expires_at, generated, "internal expiration wrong")

    @patch('time.time')
    def test_parse_token_response_generated_expires_at_is_int(self, t):
        """``expires_at`` derived from ``expires_in`` uses the rounded current time."""
        # Freeze time.time() at a fractional value so rounding is observable.
        t.return_value = 1661185148.6437678
        expected_expires_at = round(t.return_value) + 3600
        token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
                      ' "token_type":"example",'
                      ' "expires_in":3600,'
                      ' "scope":"/profile",'
                      ' "example_parameter":"example_value"}')

        client = Client(self.client_id)

        response = client.parse_request_body_response(token_json, scope=["/profile"])

        self.assertEqual(response['expires_at'], expected_expires_at)
        self.assertEqual(client._expires_at, expected_expires_at)