Normalize negative scores for text_similarity_reranker retriever (#120930)

This commit is contained in:
Panagiotis Bailis 2025-01-28 16:56:47 +02:00 committed by GitHub
parent c8e8ae6e4b
commit 8e2044de15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 105 additions and 25 deletions

View File

@ -0,0 +1,6 @@
pr: 120930
summary: Normalize negative scores for `text_similarity_reranker` retriever
area: Ranking
type: bug
issues:
- 120201

View File

@ -523,6 +523,23 @@ You have the following options:
** Then set up an <<inference-example-eland,{es} service inference endpoint>> with the `rerank` task type.
** Refer to the <<text-similarity-reranker-retriever-example-eland,example>> on this page for a step-by-step guide.
[IMPORTANT]
====
Scores from the re-ranking process are normalized using the following formula before being returned to the user,
in order to avoid negative scores.
[source,text]
----
score = max(score, 0) + min(exp(score), 1)
----
Using the above, any initially negative scores are projected into (0, 1), while non-negative scores are shifted into [1, infinity).
To revert back to the original scores if needed, one can apply the inverse to the normalized score:
[source, text]
----
score = score - 1, if score >= 1
score = ln(score), if score < 1
----
====
===== Parameters
`retriever`::

View File

@ -57,6 +57,11 @@ public class RankDocsQuery extends Query {
this.queryNames = queryNames;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
for (RankDoc doc : docs) {
if (false == doc.score >= 0) {
throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?");
}
}
}
@Override
@ -160,7 +165,11 @@ public class RankDocsQuery extends Query {
@Override
public float score() {
return docs[upTo].score;
// We could still end up with a valid 0 score for a RankDoc
// so here we want to differentiate between this and all the tailQuery matches
// that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for
// RankDoc matches.
return Math.max(docs[upTo].score, Float.MIN_VALUE);
}
@Override

View File

@ -251,4 +251,16 @@ public class RankDocsQueryBuilderTests extends AbstractQueryTestCase<RankDocsQue
// Intentionally left empty: RankDocsQueryBuilder is an internal-only API, so the
// generic output-validation test inherited from AbstractQueryTestCase does not apply.
public void testValidOutput() throws IOException {
// no-op since RankDocsQueryBuilder is an internal only API
}
// Verifies that building a query from rank docs with a negative score fails fast,
// pointing at a missing normalization step upstream.
// NOTE(review): renamed to carry the `test` prefix — without it (and without a
// @Test annotation) the test runner's method discovery would never execute this test.
public void testShouldThrowForNegativeScores() throws IOException {
    try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
        iw.addDocument(new Document());
        try (IndexReader reader = iw.getReader()) {
            SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
            // -1.0f is an invalid (negative) score and must be rejected by RankDocsQuery's constructor
            RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
            IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
            assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
        }
    }
}
}

View File

@ -56,6 +56,10 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
protected abstract Collection<Class<? extends Plugin>> pluginsNeeded();
// Whether subclasses should assert exact relevance scores in addition to order/rank.
// Suites whose reranker produces nondeterministic scores (e.g. the randomized test
// inference service) override this to return false and only verify ordering.
protected boolean shouldCheckScores() {
return true;
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return pluginsNeeded();
@ -95,9 +99,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
int rank = 1;
for (SearchHit searchHit : response.getHits().getHits()) {
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
assertThat(searchHit, hasRank(rank));
assertNotNull(searchHit.getFields().get(searchField));
if (shouldCheckScores()) {
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
}
rank++;
}
}
@ -140,9 +146,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
int rank = 3;
for (SearchHit searchHit : response.getHits().getHits()) {
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
assertThat(searchHit, hasRank(rank));
assertNotNull(searchHit.getFields().get(searchField));
if (shouldCheckScores()) {
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
}
rank++;
}
}
@ -222,9 +230,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
int rank = 1;
for (SearchHit searchHit : response.getHits().getHits()) {
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
assertThat(searchHit, hasRank(rank));
assertNotNull(searchHit.getFields().get(searchField));
if (shouldCheckScores()) {
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
}
rank++;
}
}

View File

@ -26,9 +26,16 @@ import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
public abstract class AbstractTestInferenceService implements InferenceService {
protected static final Random random = new Random(
System.getProperty("tests.seed") == null
? System.currentTimeMillis()
: Long.parseUnsignedLong(System.getProperty("tests.seed").split(":")[0], 16)
);
protected static int stringWeight(String input, int position) {
int hashCode = input.hashCode();
if (hashCode < 0) {

View File

@ -42,6 +42,7 @@ import java.util.List;
import java.util.Map;
public class TestRerankingServiceExtension implements InferenceServiceExtension {
@Override
public List<Factory> getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
@ -149,9 +150,12 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
private RankedDocsResults makeResults(List<String> input) {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
int totalResults = input.size();
float minScore = random.nextFloat(-1f, 1f);
float resultDiff = 0.2f;
for (int i = 0; i < input.size(); i++) {
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i)));
results.add(
new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i))
);
}
return new RankedDocsResults(results);
}

View File

@ -20,8 +20,8 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
@ -130,10 +130,15 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
*/
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
return Arrays.stream(originalDocs)
.filter(doc -> minScore == null || doc.score >= minScore)
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
.toArray(RankFeatureDoc[]::new);
List<RankFeatureDoc> docs = new ArrayList<>();
for (RankFeatureDoc doc : originalDocs) {
if (minScore == null || doc.score >= minScore) {
doc.score = normalizeScore(doc.score);
docs.add(doc);
}
}
docs.sort(RankFeatureDoc::compareTo);
return docs.toArray(new RankFeatureDoc[0]);
}
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
@ -154,7 +159,15 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
}
return scores;
}
// Rerankers may emit negative relevance scores, but we require non-negative ones,
// so every score is mapped into the positive range using:
//   score = max(score, 0) + min(exp(score), 1)
// Non-negative inputs end up in [1, inf) (shifted up by exactly 1), while negative
// inputs are squashed into (0, 1) via exp(score), preserving their relative order.
private static float normalizeScore(float score) {
    final float positivePart = Math.max(score, 0f);
    final float boundedExp = Math.min((float) Math.exp(score), 1f);
    return positivePart + boundedExp;
}
}

View File

@ -142,6 +142,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
assert scoreDoc.score >= 0;
if (explain) {
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
scoreDoc.doc,

View File

@ -50,4 +50,9 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT {
// Intentionally a no-op: this inherited failure scenario is disabled for this suite —
// presumably it does not apply to the multi-node reranker setup; TODO confirm.
public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception {
// no-op
}
// Exact scores are not asserted in this suite: the test reranking service offsets its
// results by a random minScore, so only ordering and ranks are deterministic here.
@Override
protected boolean shouldCheckScores() {
return false;
}
}

View File

@ -131,11 +131,12 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
// Verify order, rank and score of results
SearchHit[] hits = response.getHits().getHits();
assertEquals(5, hits.length);
assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
assertHitHasRankScoreAndText(hits[3], 4, 1.0f, "1");
assertHitHasRankScoreAndText(hits[4], 5, 0.0f, "0");
// we add + 1 to all expected scores because the default normalization shifts non-negative scores up by 1
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1");
assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0");
}
);
}
@ -150,9 +151,9 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
// Verify order, rank and score of results
SearchHit[] hits = response.getHits().getHits();
assertEquals(3, hits.length);
assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
}
);
}

View File

@ -20,6 +20,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase {
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.systemProperty("tests.seed", System.getProperty("tests.seed"))
.setting("xpack.security.enabled", "false")
.setting("xpack.security.http.ssl.enabled", "false")
.setting("xpack.license.self_generated.type", "trial")

View File

@ -89,10 +89,7 @@ setup:
- length: { hits.hits: 2 }
- match: { hits.hits.0._id: "doc_2" }
- close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } }
- match: { hits.hits.1._id: "doc_1" }
- close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } }
---
"Simple text similarity rank retriever and filtering":
@ -123,8 +120,6 @@ setup:
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: "doc_1" }
- close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } }
---
"Text similarity reranking fails if the inference ID does not exist":
@ -211,7 +206,6 @@ setup:
- contains: { hits.hits: { _id: "doc_2" } }
- contains: { hits.hits: { _id: "doc_1" } }
- close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } }
- match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" }
- match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" }