permit at+jwt typ header value in jwt access tokens (#126687)

* permit at+jwt typ header value in jwt access tokens

* Update docs/changelog/126687.yaml

* address review comments

* [CI] Auto commit changes from spotless

* update Type Validator tests for parser ignoring case

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
Richard Dennehy 2025-04-15 11:08:30 +01:00 committed by GitHub
parent 7de46e9897
commit 9e3476ef99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 92 additions and 27 deletions

View File

@ -0,0 +1,6 @@
pr: 126687
summary: Permit at+jwt typ header value in jwt access tokens
area: Authentication
type: enhancement
issues:
- 119370

View File

@ -456,13 +456,13 @@ public class JwtRestIT extends ESRestTestCase {
{
// This is the correct HMAC passphrase (from build.gradle)
final SignedJWT jwt = signHmacJwt(claimsSet, HMAC_PASSPHRASE);
final SignedJWT jwt = signHmacJwt(claimsSet, HMAC_PASSPHRASE, false);
final TestSecurityClient client = getSecurityClient(jwt, Optional.of(VALID_SHARED_SECRET));
assertThat(client.authenticate(), hasEntry(User.Fields.USERNAME.getPreferredName(), username));
}
{
// This is not the correct HMAC passphrase
final SignedJWT invalidJwt = signHmacJwt(claimsSet, "invalid-HMAC-passphrase-" + randomAlphaOfLength(12));
final SignedJWT invalidJwt = signHmacJwt(claimsSet, "invalid-HMAC-passphrase-" + randomAlphaOfLength(12), false);
final TestSecurityClient client = getSecurityClient(invalidJwt, Optional.of(VALID_SHARED_SECRET));
// This fails because the HMAC is wrong
final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
@ -487,7 +487,7 @@ public class JwtRestIT extends ESRestTestCase {
data.put("token_use", randomValueOtherThan("access", () -> randomAlphaOfLengthBetween(3, 10)));
}
final JWTClaimsSet claimsSet = buildJwt(data, Instant.now(), false, false);
final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value", false);
final TestSecurityClient client = getSecurityClient(jwt, Optional.of(VALID_SHARED_SECRET));
final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
assertThat(exception.getResponse(), hasStatusCode(RestStatus.UNAUTHORIZED));
@ -747,18 +747,18 @@ public class JwtRestIT extends ESRestTestCase {
private SignedJWT signJwtForRealm1(JWTClaimsSet claimsSet) throws IOException, JOSEException, ParseException {
final RSASSASigner signer = loadRsaSigner();
return signJWT(signer, "RS256", claimsSet);
return signJWT(signer, "RS256", claimsSet, false);
}
private SignedJWT signJwtForRealm2(JWTClaimsSet claimsSet) throws JOSEException, ParseException {
private SignedJWT signJwtForRealm2(JWTClaimsSet claimsSet) throws JOSEException {
// Input string is configured in build.gradle
return signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
return signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value", true);
}
private SignedJWT signJwtForRealm3(JWTClaimsSet claimsSet) throws JOSEException, ParseException, IOException {
final int bitSize = randomFrom(384, 512);
final MACSigner signer = loadHmacSigner("test-hmac-" + bitSize);
return signJWT(signer, "HS" + bitSize, claimsSet);
return signJWT(signer, "HS" + bitSize, claimsSet, false);
}
private RSASSASigner loadRsaSigner() throws IOException, ParseException, JOSEException {
@ -781,10 +781,10 @@ public class JwtRestIT extends ESRestTestCase {
}
}
private SignedJWT signHmacJwt(JWTClaimsSet claimsSet, String hmacPassphrase) throws JOSEException {
private SignedJWT signHmacJwt(JWTClaimsSet claimsSet, String hmacPassphrase, boolean allowAtJwtType) throws JOSEException {
final OctetSequenceKey hmac = JwkValidateUtil.buildHmacKeyFromString(hmacPassphrase);
final JWSSigner signer = new MACSigner(hmac);
return signJWT(signer, "HS256", claimsSet);
return signJWT(signer, "HS256", claimsSet, allowAtJwtType);
}
// JWT construction
@ -822,10 +822,14 @@ public class JwtRestIT extends ESRestTestCase {
return builder.build();
}
static SignedJWT signJWT(JWSSigner signer, String algorithm, JWTClaimsSet claimsSet) throws JOSEException {
static SignedJWT signJWT(JWSSigner signer, String algorithm, JWTClaimsSet claimsSet, boolean allowAtJwtType) throws JOSEException {
final JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(algorithm));
if (randomBoolean()) {
builder.type(JOSEObjectType.JWT);
if (allowAtJwtType && randomBoolean()) {
builder.type(new JOSEObjectType("at+jwt"));
} else {
builder.type(JOSEObjectType.JWT);
}
}
final JWSHeader jwtHeader = builder.build();
final SignedJWT jwt = new SignedJWT(jwtHeader, claimsSet);

View File

@ -279,7 +279,7 @@ public class JwtWithUnavailableSecurityIndexRestIT extends ESRestTestCase {
issueTime
);
final RSASSASigner signer = loadRsaSigner();
return JwtRestIT.signJWT(signer, "RS256", claimsSet);
return JwtRestIT.signJWT(signer, "RS256", claimsSet, false);
}
private RSASSASigner loadRsaSigner() throws IOException, ParseException, JOSEException {

View File

@ -136,7 +136,7 @@ public class JwtAuthenticator implements Releasable {
}
return List.of(
JwtTypeValidator.INSTANCE,
JwtTypeValidator.ID_TOKEN_INSTANCE,
new JwtStringClaimValidator("iss", true, List.of(realmConfig.getSetting(JwtRealmSettings.ALLOWED_ISSUER)), List.of()),
subjectClaimValidator,
new JwtStringClaimValidator("aud", false, realmConfig.getSetting(JwtRealmSettings.ALLOWED_AUDIENCES), List.of()),
@ -157,7 +157,7 @@ public class JwtAuthenticator implements Releasable {
final Clock clock = Clock.systemUTC();
return List.of(
JwtTypeValidator.INSTANCE,
JwtTypeValidator.ACCESS_TOKEN_INSTANCE,
new JwtStringClaimValidator("iss", true, List.of(realmConfig.getSetting(JwtRealmSettings.ALLOWED_ISSUER)), List.of()),
getSubjectClaimValidator(realmConfig, fallbackClaimLookup),
new JwtStringClaimValidator(

View File

@ -17,14 +17,17 @@ import com.nimbusds.jwt.JWTClaimsSet;
public class JwtTypeValidator implements JwtFieldValidator {
private static final JOSEObjectTypeVerifier<SecurityContext> JWT_HEADER_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>(
JOSEObjectType.JWT,
null
);
private final JOSEObjectTypeVerifier<SecurityContext> JWT_HEADER_TYPE_VERIFIER;
private static final JOSEObjectType AT_PLUS_JWT = new JOSEObjectType("at+jwt");
public static final JwtTypeValidator INSTANCE = new JwtTypeValidator();
public static final JwtTypeValidator ID_TOKEN_INSTANCE = new JwtTypeValidator(JOSEObjectType.JWT, null);
private JwtTypeValidator() {}
// strictly speaking, this should only permit `at+jwt`, but removing the other two options is a breaking change
public static final JwtTypeValidator ACCESS_TOKEN_INSTANCE = new JwtTypeValidator(JOSEObjectType.JWT, AT_PLUS_JWT, null);
private JwtTypeValidator(JOSEObjectType... allowedTypes) {
JWT_HEADER_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>(allowedTypes);
}
public void validate(JWSHeader jwsHeader, JWTClaimsSet jwtClaimsSet) {
final JOSEObjectType jwtHeaderType = jwsHeader.getType();

View File

@ -7,11 +7,21 @@
package org.elasticsearch.xpack.security.authc.jwt;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
import java.text.ParseException;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class JwtAuthenticatorIdTokenTypeTests extends JwtAuthenticatorTests {
@ -28,4 +38,23 @@ public class JwtAuthenticatorIdTokenTypeTests extends JwtAuthenticatorTests {
public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator());
}
public void testAccessTokenHeaderTypeIsRejected() throws ParseException {
final JWTClaimsSet claimsSet = JWTClaimsSet.parse(Map.of());
final SignedJWT signedJWT = new SignedJWT(
JWSHeader.parse(Map.of("alg", allowedAlgorithm, "typ", "at+jwt")).toBase64URL(),
claimsSet.toPayload().toBase64URL(),
Base64URL.encode("signature")
);
final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());
final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator();
jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
final Exception e = expectThrows(IllegalArgumentException.class, future::actionGet);
assertThat(e.getMessage(), equalTo("invalid jwt typ header"));
}
}

View File

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.security.authc.jwt;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.openid.connect.sdk.Nonce;
@ -134,7 +133,7 @@ public class JwtRealmAuthenticateAccessTokenTypeTests extends JwtRealmTestCase {
final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
unsignedJwt = JwtTestCase.buildUnsignedJwt(
randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
randomFrom("at+jwt", "JWT", null), // typ
randomBoolean() ? null : jwk.getKeyID(), // kid
algJwkPair.alg(), // alg
randomAlphaOfLengthBetween(10, 20), // jwtID

View File

@ -19,22 +19,46 @@ import static org.hamcrest.Matchers.containsString;
public class JwtTypeValidatorTests extends ESTestCase {
public void testValidType() throws ParseException {
public void testValidIdTokenType() throws ParseException {
final String algorithm = randomAlphaOfLengthBetween(3, 8);
// typ is allowed to be missing
final JWSHeader jwsHeader = JWSHeader.parse(
randomFrom(Map.of("alg", randomAlphaOfLengthBetween(3, 8)), Map.of("typ", "JWT", "alg", randomAlphaOfLengthBetween(3, 8)))
randomFrom(
// typ is allowed to be missing
Map.of("alg", algorithm),
Map.of("typ", "JWT", "alg", algorithm)
)
);
try {
JwtTypeValidator.INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()));
JwtTypeValidator.ID_TOKEN_INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()));
} catch (Exception e) {
throw new AssertionError("validation should have passed without exception", e);
}
}
public void testValidAccessTokenType() throws ParseException {
final String algorithm = randomAlphaOfLengthBetween(3, 8);
final JWSHeader jwsHeader = JWSHeader.parse(
randomFrom(
// typ is allowed to be missing
Map.of("alg", algorithm),
Map.of("typ", "JWT", "alg", algorithm),
Map.of("typ", "at+jwt", "alg", algorithm),
Map.of("typ", "AT+JWT", "alg", algorithm)
)
);
try {
JwtTypeValidator.ACCESS_TOKEN_INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()));
} catch (Exception e) {
throw new AssertionError("validation should have passed without exception", e);
}
}
public void testInvalidType() throws ParseException {
final JwtTypeValidator validator = randomFrom(JwtTypeValidator.ID_TOKEN_INSTANCE, JwtTypeValidator.ACCESS_TOKEN_INSTANCE);
final JWSHeader jwsHeader = JWSHeader.parse(
Map.of("typ", randomAlphaOfLengthBetween(4, 8), "alg", randomAlphaOfLengthBetween(3, 8))
@ -42,7 +66,7 @@ public class JwtTypeValidatorTests extends ESTestCase {
final IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> JwtTypeValidator.INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()))
() -> validator.validate(jwsHeader, JWTClaimsSet.parse(Map.of()))
);
assertThat(e.getMessage(), containsString("invalid jwt typ header"));
}