Add l2_norm normalization support to linear retriever (#128504)

* New l2 normalizer added

* L2 score normalizer is registered

* Test case added to the YAML tests

* Documentation added

* Resolved checkstyle issues

* Update docs/changelog/128504.yaml

* Update docs/reference/elasticsearch/rest-apis/retrievers.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Score 0 test case added to check for corner cases

* Edited the markdown doc description

* Pruned the comment

* Renamed the variable

* Added comment to the class

* Unit tests added

* Spotless and checkstyle fixed

* Fixed build failure

* Fixed the forbidden test

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Authored by Mridula on 2025-06-02 14:59:03 +01:00; committed by GitHub
parent 09ccd91b53
commit 81fba27b6b
6 changed files with 214 additions and 1 deletion

docs/changelog/128504.yaml

@@ -0,0 +1,5 @@
pr: 128504
summary: Add l2_norm normalization support to linear retriever
area: Relevance
type: enhancement
issues: []

docs/reference/elasticsearch/rest-apis/retrievers.md

@@ -276,7 +276,7 @@ Each entry specifies the following parameters:
`normalizer`
: (Optional, String)
- Specifies how we will normalize the retrievers scores, before applying the specified `weight`. Available values are: `minmax`, and `none`. Defaults to `none`.
+ Specifies how we will normalize the retrievers scores, before applying the specified `weight`. Available values are: `minmax`, `l2_norm`, and `none`. Defaults to `none`.
* `none`
* `minmax` : A `MinMaxScoreNormalizer` that normalizes scores based on the following formula
@@ -285,6 +285,7 @@ Each entry specifies the following parameters:
score = (score - min) / (max - min)
```
+ * `l2_norm` : An `L2ScoreNormalizer` that normalizes scores using the L2 norm of the score values.
See also [this hybrid search example](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-retriever) using a linear retriever on how to independently configure and apply normalizers to retrievers.
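For reference (not part of the PR diff): the new `l2_norm` option divides each of a retriever's scores by the L2 (Euclidean) norm of that retriever's score vector, so the normalized scores form a unit-length vector. A sketch of the formula:

```latex
% l2_norm sketch: each score is divided by the L2 norm of all scores from the same retriever
\mathrm{score}_i' = \frac{\mathrm{score}_i}{\sqrt{\sum_{j=1}^{n} \mathrm{score}_j^{2}}}
```

For example, scores (3.0, 4.0) have an L2 norm of 5.0 and normalize to (0.6, 0.8).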

org/elasticsearch/xpack/rank/linear/L2ScoreNormalizer.java

@@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

/**
* A score normalizer that applies L2 normalization to a set of scores.
* <p>
* This normalizer scales the scores so that the L2 norm of the score vector is 1,
* if possible. If all scores are zero or NaN, normalization is skipped and the original scores are returned.
* </p>
*/
public class L2ScoreNormalizer extends ScoreNormalizer {

    public static final L2ScoreNormalizer INSTANCE = new L2ScoreNormalizer();

    public static final String NAME = "l2_norm";

    private static final float EPSILON = 1e-6f;

    public L2ScoreNormalizer() {}

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
        if (docs.length == 0) {
            return docs;
        }
        double sumOfSquares = 0.0;
        boolean atLeastOneValidScore = false;
        for (ScoreDoc doc : docs) {
            if (Float.isNaN(doc.score) == false) {
                atLeastOneValidScore = true;
                sumOfSquares += doc.score * doc.score;
            }
        }
        if (atLeastOneValidScore == false) {
            // No valid scores to normalize
            return docs;
        }
        double norm = Math.sqrt(sumOfSquares);
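        // The norm is effectively zero, so dividing by it would be meaningless; return the original scores unchanged.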
        if (norm < EPSILON) {
            return docs;
        }
        ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
        for (int i = 0; i < docs.length; i++) {
            float score = (float) (docs[i].score / norm);
            scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
        }
        return scoreDocs;
    }
}

org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java

@@ -17,6 +17,9 @@ public abstract class ScoreNormalizer {
    public static ScoreNormalizer valueOf(String normalizer) {
        if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
            return MinMaxScoreNormalizer.INSTANCE;
+       } else if (L2ScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
+           return L2ScoreNormalizer.INSTANCE;
        } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
            return IdentityScoreNormalizer.INSTANCE;
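To make the registration concrete, here is a minimal usage sketch (not code from the PR; the helper class, its placement in the same package, and the explicit weighting step are assumptions based on the documented linear-retriever behavior):

```java
package org.elasticsearch.xpack.rank.linear; // assumed placement so ScoreNormalizer and L2ScoreNormalizer resolve without imports

import org.apache.lucene.search.ScoreDoc;

// Hypothetical helper: resolve the normalizer by its registered name, normalize a
// retriever's top docs, then apply the retriever's configured weight.
final class L2NormUsageSketch {

    static float[] normalizeAndWeight(ScoreDoc[] topDocs, float weight) {
        ScoreNormalizer normalizer = ScoreNormalizer.valueOf("l2_norm"); // returns L2ScoreNormalizer.INSTANCE
        ScoreDoc[] normalized = normalizer.normalizeScores(topDocs);
        float[] weighted = new float[normalized.length];
        for (int i = 0; i < normalized.length; i++) {
            weighted[i] = normalized[i].score * weight; // weight applied after normalization
        }
        return weighted;
    }
}
```

With scores (3.0, 4.0) and weight 10.0 this yields (6.0, 8.0), which is exactly what the first YAML test case further down expects.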

org/elasticsearch/xpack/rank/linear/L2ScoreNormalizerTests.java

@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.linear;

import org.apache.lucene.search.ScoreDoc;

import org.elasticsearch.test.ESTestCase;

public class L2ScoreNormalizerTests extends ESTestCase {

    public void testNormalizeTypicalVector() {
        ScoreDoc[] docs = { new ScoreDoc(1, 3.0f, 0), new ScoreDoc(2, 4.0f, 0) };
        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
        assertEquals(0.6f, normalized[0].score, 1e-5);
        assertEquals(0.8f, normalized[1].score, 1e-5);
    }

    public void testAllZeros() {
        ScoreDoc[] docs = { new ScoreDoc(1, 0.0f, 0), new ScoreDoc(2, 0.0f, 0) };
        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
        assertEquals(0.0f, normalized[0].score, 0.0f);
        assertEquals(0.0f, normalized[1].score, 0.0f);
    }

    public void testAllNaN() {
        ScoreDoc[] docs = { new ScoreDoc(1, Float.NaN, 0), new ScoreDoc(2, Float.NaN, 0) };
        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
        assertTrue(Float.isNaN(normalized[0].score));
        assertTrue(Float.isNaN(normalized[1].score));
    }

    public void testMixedZeroAndNaN() {
        ScoreDoc[] docs = { new ScoreDoc(1, 0.0f, 0), new ScoreDoc(2, Float.NaN, 0) };
        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
        assertEquals(0.0f, normalized[0].score, 0.0f);
        assertTrue(Float.isNaN(normalized[1].score));
    }

    public void testSingleElement() {
        ScoreDoc[] docs = { new ScoreDoc(1, 42.0f, 0) };
        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
        assertEquals(1.0f, normalized[0].score, 1e-5);
    }

    public void testEmptyArray() {
        ScoreDoc[] docs = {};
        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
        assertEquals(0, normalized.length);
    }
}

Linear retriever YAML REST tests

@@ -265,6 +265,93 @@ setup:
- match: { hits.hits.3._id: "3" }
- close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } }
---
"should normalize initial scores with l2_norm":
- do:
search:
index: test
body:
retriever:
linear:
retrievers: [
{
retriever: {
standard: {
query: {
bool: {
should: [
{ constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 3.0 } },
{ constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 4.0 } }
]
}
}
}
},
weight: 10.0,
normalizer: "l2_norm"
},
{
retriever: {
standard: {
query: {
bool: {
should: [
{ constant_score: { filter: { term: { keyword: { value: "three" } } }, boost: 6.0 } },
{ constant_score: { filter: { term: { keyword: { value: "four" } } }, boost: 8.0 } }
]
}
}
}
},
weight: 2.0,
normalizer: "l2_norm"
}
]
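# Expected math: retriever 1 scores (3.0, 4.0) -> l2_norm -> (0.6, 0.8) -> weight 10.0 -> (6.0, 8.0);
# retriever 2 scores (6.0, 8.0) -> l2_norm -> (0.6, 0.8) -> weight 2.0 -> (1.2, 1.6).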
- match: { hits.total.value: 4 }
- match: { hits.hits.0._id: "2" }
- match: { hits.hits.0._score: 8.0 }
- match: { hits.hits.1._id: "1" }
- match: { hits.hits.1._score: 6.0 }
- match: { hits.hits.2._id: "4" }
- close_to: { hits.hits.2._score: { value: 1.6, error: 0.001 } }
- match: { hits.hits.3._id: "3" }
- match: { hits.hits.3._score: 1.2 }
---
"should handle all zero scores in normalization":
- do:
search:
index: test
body:
retriever:
linear:
retrievers: [
{
retriever: {
standard: {
query: {
bool: {
should: [
{ constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 0.0 } },
{ constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 0.0 } },
{ constant_score: { filter: { term: { keyword: { value: "three" } } }, boost: 0.0 } },
{ constant_score: { filter: { term: { keyword: { value: "four" } } }, boost: 0.0 } }
]
}
}
}
},
weight: 1.0,
normalizer: "l2_norm"
}
]
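# All input scores are 0.0, so the L2 norm falls below the normalizer's epsilon and the zero scores are returned unchanged.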
- match: { hits.total.value: 4 }
- close_to: { hits.hits.0._score: { value: 0.0, error: 0.0001 } }
- close_to: { hits.hits.1._score: { value: 0.0, error: 0.0001 } }
- close_to: { hits.hits.2._score: { value: 0.0, error: 0.0001 } }
- close_to: { hits.hits.3._score: { value: 0.0, error: 0.0001 } }
---
"should throw on unknown normalizer":
- do: