Add l2_norm normalization support to linear retriever (#128504)
* New l2 normalizer added * L2 score normaliser is registered * test case added to the yaml * Documentation added * Resolved checkstyle issues * Update docs/changelog/128504.yaml * Update docs/reference/elasticsearch/rest-apis/retrievers.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Score 0 test case added to check for corner cases * Edited the markdown doc description * Pruned the comment * Renamed the variable * Added comment to the class * Unit tests added * Spotless and checkstyle fixed * Fixed build failure * Fixed the forbidden test --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
09ccd91b53
commit
81fba27b6b
|
@ -0,0 +1,5 @@
|
|||
pr: 128504
|
||||
summary: Add l2_norm normalization support to linear retriever
|
||||
area: Relevance
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -276,7 +276,7 @@ Each entry specifies the following parameters:
|
|||
`normalizer`
|
||||
: (Optional, String)
|
||||
|
||||
Specifies how we will normalize the retriever’s scores, before applying the specified `weight`. Available values are: `minmax`, and `none`. Defaults to `none`.
|
||||
- Specifies how we will normalize the retriever’s scores, before applying the specified `weight`. Available values are: `minmax`, `l2_norm`, and `none`. Defaults to `none`.
|
||||
|
||||
* `none`
|
||||
* `minmax` : A `MinMaxScoreNormalizer` that normalizes scores based on the following formula
|
||||
|
@ -285,6 +285,7 @@ Each entry specifies the following parameters:
|
|||
score = (score - min) / (max - min)
|
||||
```
|
||||
|
||||
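* `l2_norm` : An `L2ScoreNormalizer` that divides each score by the L2 norm of the retriever's scores, i.e. `score = score / sqrt(sum of score^2)`, so that the normalized score vector has unit length. If all scores are zero or NaN, the scores are left unchanged.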
|
||||
|
||||
See also [this hybrid search example](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-retriever) using a linear retriever on how to independently configure and apply normalizers to retrievers.
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.rank.linear;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
||||
/**
|
||||
* A score normalizer that applies L2 normalization to a set of scores.
|
||||
* <p>
|
||||
* This normalizer scales the scores so that the L2 norm of the score vector is 1,
|
||||
* if possible. If all scores are zero or NaN, normalization is skipped and the original scores are returned.
|
||||
* </p>
|
||||
*/
|
||||
public class L2ScoreNormalizer extends ScoreNormalizer {
|
||||
|
||||
public static final L2ScoreNormalizer INSTANCE = new L2ScoreNormalizer();
|
||||
|
||||
public static final String NAME = "l2_norm";
|
||||
|
||||
private static final float EPSILON = 1e-6f;
|
||||
|
||||
public L2ScoreNormalizer() {}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
|
||||
if (docs.length == 0) {
|
||||
return docs;
|
||||
}
|
||||
double sumOfSquares = 0.0;
|
||||
boolean atLeastOneValidScore = false;
|
||||
for (ScoreDoc doc : docs) {
|
||||
if (Float.isNaN(doc.score) == false) {
|
||||
atLeastOneValidScore = true;
|
||||
sumOfSquares += doc.score * doc.score;
|
||||
}
|
||||
}
|
||||
if (atLeastOneValidScore == false) {
|
||||
// No valid scores to normalize
|
||||
return docs;
|
||||
}
|
||||
double norm = Math.sqrt(sumOfSquares);
|
||||
if (norm < EPSILON) {
|
||||
return docs;
|
||||
}
|
||||
ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
|
||||
for (int i = 0; i < docs.length; i++) {
|
||||
float score = (float) (docs[i].score / norm);
|
||||
scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
|
||||
}
|
||||
return scoreDocs;
|
||||
}
|
||||
}
|
|
@ -17,6 +17,9 @@ public abstract class ScoreNormalizer {
|
|||
public static ScoreNormalizer valueOf(String normalizer) {
|
||||
if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
|
||||
return MinMaxScoreNormalizer.INSTANCE;
|
||||
} else if (L2ScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
|
||||
return L2ScoreNormalizer.INSTANCE;
|
||||
|
||||
} else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
|
||||
return IdentityScoreNormalizer.INSTANCE;
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.rank.linear;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
public class L2ScoreNormalizerTests extends ESTestCase {
|
||||
|
||||
public void testNormalizeTypicalVector() {
|
||||
ScoreDoc[] docs = { new ScoreDoc(1, 3.0f, 0), new ScoreDoc(2, 4.0f, 0) };
|
||||
ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
|
||||
assertEquals(0.6f, normalized[0].score, 1e-5);
|
||||
assertEquals(0.8f, normalized[1].score, 1e-5);
|
||||
}
|
||||
|
||||
public void testAllZeros() {
|
||||
ScoreDoc[] docs = { new ScoreDoc(1, 0.0f, 0), new ScoreDoc(2, 0.0f, 0) };
|
||||
ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
|
||||
assertEquals(0.0f, normalized[0].score, 0.0f);
|
||||
assertEquals(0.0f, normalized[1].score, 0.0f);
|
||||
}
|
||||
|
||||
public void testAllNaN() {
|
||||
ScoreDoc[] docs = { new ScoreDoc(1, Float.NaN, 0), new ScoreDoc(2, Float.NaN, 0) };
|
||||
ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
|
||||
assertTrue(Float.isNaN(normalized[0].score));
|
||||
assertTrue(Float.isNaN(normalized[1].score));
|
||||
}
|
||||
|
||||
public void testMixedZeroAndNaN() {
|
||||
ScoreDoc[] docs = { new ScoreDoc(1, 0.0f, 0), new ScoreDoc(2, Float.NaN, 0) };
|
||||
ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
|
||||
assertEquals(0.0f, normalized[0].score, 0.0f);
|
||||
assertTrue(Float.isNaN(normalized[1].score));
|
||||
}
|
||||
|
||||
public void testSingleElement() {
|
||||
ScoreDoc[] docs = { new ScoreDoc(1, 42.0f, 0) };
|
||||
ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
|
||||
assertEquals(1.0f, normalized[0].score, 1e-5);
|
||||
}
|
||||
|
||||
public void testEmptyArray() {
|
||||
ScoreDoc[] docs = {};
|
||||
ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
|
||||
assertEquals(0, normalized.length);
|
||||
}
|
||||
}
|
|
@ -265,6 +265,93 @@ setup:
|
|||
- match: { hits.hits.3._id: "3" }
|
||||
- close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } }
|
||||
|
||||
---
|
||||
"should normalize initial scores with l2_norm":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
linear:
|
||||
retrievers: [
|
||||
{
|
||||
retriever: {
|
||||
standard: {
|
||||
query: {
|
||||
bool: {
|
||||
should: [
|
||||
{ constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 3.0 } },
|
||||
{ constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 4.0 } }
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
weight: 10.0,
|
||||
normalizer: "l2_norm"
|
||||
},
|
||||
{
|
||||
retriever: {
|
||||
standard: {
|
||||
query: {
|
||||
bool: {
|
||||
should: [
|
||||
{ constant_score: { filter: { term: { keyword: { value: "three" } } }, boost: 6.0 } },
|
||||
{ constant_score: { filter: { term: { keyword: { value: "four" } } }, boost: 8.0 } }
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
weight: 2.0,
|
||||
normalizer: "l2_norm"
|
||||
}
|
||||
]
|
||||
|
||||
- match: { hits.total.value: 4 }
|
||||
- match: { hits.hits.0._id: "2" }
|
||||
- match: { hits.hits.0._score: 8.0 }
|
||||
- match: { hits.hits.1._id: "1" }
|
||||
- match: { hits.hits.1._score: 6.0 }
|
||||
- match: { hits.hits.2._id: "4" }
|
||||
- close_to: { hits.hits.2._score: { value: 1.6, error: 0.001 } }
|
||||
- match: { hits.hits.3._id: "3" }
|
||||
- match: { hits.hits.3._score: 1.2 }
|
||||
|
||||
---
|
||||
"should handle all zero scores in normalization":
|
||||
- do:
|
||||
search:
|
||||
index: test
|
||||
body:
|
||||
retriever:
|
||||
linear:
|
||||
retrievers: [
|
||||
{
|
||||
retriever: {
|
||||
standard: {
|
||||
query: {
|
||||
bool: {
|
||||
should: [
|
||||
{ constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 0.0 } },
|
||||
{ constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 0.0 } },
|
||||
{ constant_score: { filter: { term: { keyword: { value: "three" } } }, boost: 0.0 } },
|
||||
{ constant_score: { filter: { term: { keyword: { value: "four" } } }, boost: 0.0 } }
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
weight: 1.0,
|
||||
normalizer: "l2_norm"
|
||||
}
|
||||
]
|
||||
- match: { hits.total.value: 4 }
|
||||
- close_to: { hits.hits.0._score: { value: 0.0, error: 0.0001 } }
|
||||
- close_to: { hits.hits.1._score: { value: 0.0, error: 0.0001 } }
|
||||
- close_to: { hits.hits.2._score: { value: 0.0, error: 0.0001 } }
|
||||
- close_to: { hits.hits.3._score: { value: 0.0, error: 0.0001 } }
|
||||
|
||||
---
|
||||
"should throw on unknown normalizer":
|
||||
- do:
|
||||
|
|
Loading…
Reference in New Issue