Normalize negative scores for text_similarity_reranker retriever (#120930)
This commit is contained in:
parent
c8e8ae6e4b
commit
8e2044de15
|
@ -0,0 +1,6 @@
|
|||
pr: 120930
|
||||
summary: Normalize negative scores for `text_similarity_reranker` retriever
|
||||
area: Ranking
|
||||
type: bug
|
||||
issues:
|
||||
- 120201
|
|
@ -523,6 +523,23 @@ You have the following options:
|
|||
** Then set up an <<inference-example-eland,{es} service inference endpoint>> with the `rerank` task type.
|
||||
** Refer to the <<text-similarity-reranker-retriever-example-eland,example>> on this page for a step-by-step guide.
|
||||
|
||||
[IMPORTANT]
|
||||
====
|
||||
Scores from the re-ranking process are normalized using the following formula before being returned to the user,
|
||||
to avoid having negative scores.
|
||||
[source,text]
|
||||
----
|
||||
score = max(score, 0) + min(exp(score), 1)
|
||||
----
|
||||
Using the above, any initially negative scores are projected to (0, 1) and non-negative scores to [1, infinity).
|
||||
To revert to the original scores if needed, one can use:
|
||||
[source, text]
|
||||
----
|
||||
score = score - 1, if score >= 1
|
||||
score = ln(score), if score < 1
|
||||
----
|
||||
====
|
||||
|
||||
===== Parameters
|
||||
|
||||
`retriever`::
|
||||
|
|
|
@ -57,6 +57,11 @@ public class RankDocsQuery extends Query {
|
|||
this.queryNames = queryNames;
|
||||
this.segmentStarts = segmentStarts;
|
||||
this.contextIdentity = contextIdentity;
|
||||
for (RankDoc doc : docs) {
|
||||
if (false == doc.score >= 0) {
|
||||
throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -160,7 +165,11 @@ public class RankDocsQuery extends Query {
|
|||
|
||||
@Override
|
||||
public float score() {
|
||||
return docs[upTo].score;
|
||||
// We could still end up with a valid 0 score for a RankDoc
|
||||
// so here we want to differentiate between this and all the tailQuery matches
|
||||
// that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for
|
||||
// RankDoc matches.
|
||||
return Math.max(docs[upTo].score, Float.MIN_VALUE);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -251,4 +251,16 @@ public class RankDocsQueryBuilderTests extends AbstractQueryTestCase<RankDocsQue
|
|||
public void testValidOutput() throws IOException {
|
||||
// no-op since RankDocsQueryBuilder is an internal only API
|
||||
}
|
||||
|
||||
public void shouldThrowForNegativeScores() throws IOException {
|
||||
try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
|
||||
iw.addDocument(new Document());
|
||||
try (IndexReader reader = iw.getReader()) {
|
||||
SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
|
||||
RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
|
||||
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
|
||||
assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,6 +56,10 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
|
|||
|
||||
protected abstract Collection<Class<? extends Plugin>> pluginsNeeded();
|
||||
|
||||
protected boolean shouldCheckScores() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Collection<Class<? extends Plugin>> nodePlugins() {
|
||||
return pluginsNeeded();
|
||||
|
@ -95,9 +99,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
|
|||
int rank = 1;
|
||||
for (SearchHit searchHit : response.getHits().getHits()) {
|
||||
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
|
||||
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
|
||||
assertThat(searchHit, hasRank(rank));
|
||||
assertNotNull(searchHit.getFields().get(searchField));
|
||||
if (shouldCheckScores()) {
|
||||
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
|
||||
}
|
||||
rank++;
|
||||
}
|
||||
}
|
||||
|
@ -140,9 +146,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
|
|||
int rank = 3;
|
||||
for (SearchHit searchHit : response.getHits().getHits()) {
|
||||
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
|
||||
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
|
||||
assertThat(searchHit, hasRank(rank));
|
||||
assertNotNull(searchHit.getFields().get(searchField));
|
||||
if (shouldCheckScores()) {
|
||||
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
|
||||
}
|
||||
rank++;
|
||||
}
|
||||
}
|
||||
|
@ -222,9 +230,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
|
|||
int rank = 1;
|
||||
for (SearchHit searchHit : response.getHits().getHits()) {
|
||||
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
|
||||
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
|
||||
assertThat(searchHit, hasRank(rank));
|
||||
assertNotNull(searchHit.getFields().get(searchField));
|
||||
if (shouldCheckScores()) {
|
||||
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
|
||||
}
|
||||
rank++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,9 +26,16 @@ import org.elasticsearch.xcontent.XContentBuilder;
|
|||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
public abstract class AbstractTestInferenceService implements InferenceService {
|
||||
|
||||
protected static final Random random = new Random(
|
||||
System.getProperty("tests.seed") == null
|
||||
? System.currentTimeMillis()
|
||||
: Long.parseUnsignedLong(System.getProperty("tests.seed").split(":")[0], 16)
|
||||
);
|
||||
|
||||
protected static int stringWeight(String input, int position) {
|
||||
int hashCode = input.hashCode();
|
||||
if (hashCode < 0) {
|
||||
|
|
|
@ -42,6 +42,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
public class TestRerankingServiceExtension implements InferenceServiceExtension {
|
||||
|
||||
@Override
|
||||
public List<Factory> getInferenceServiceFactories() {
|
||||
return List.of(TestInferenceService::new);
|
||||
|
@ -149,9 +150,12 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
|
|||
private RankedDocsResults makeResults(List<String> input) {
|
||||
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
|
||||
int totalResults = input.size();
|
||||
float minScore = random.nextFloat(-1f, 1f);
|
||||
float resultDiff = 0.2f;
|
||||
for (int i = 0; i < input.size(); i++) {
|
||||
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i)));
|
||||
results.add(
|
||||
new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i))
|
||||
);
|
||||
}
|
||||
return new RankedDocsResults(results);
|
||||
}
|
||||
|
|
|
@ -20,8 +20,8 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
|||
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -130,10 +130,15 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
|
|||
*/
|
||||
@Override
|
||||
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
|
||||
return Arrays.stream(originalDocs)
|
||||
.filter(doc -> minScore == null || doc.score >= minScore)
|
||||
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
|
||||
.toArray(RankFeatureDoc[]::new);
|
||||
List<RankFeatureDoc> docs = new ArrayList<>();
|
||||
for (RankFeatureDoc doc : originalDocs) {
|
||||
if (minScore == null || doc.score >= minScore) {
|
||||
doc.score = normalizeScore(doc.score);
|
||||
docs.add(doc);
|
||||
}
|
||||
}
|
||||
docs.sort(RankFeatureDoc::compareTo);
|
||||
return docs.toArray(new RankFeatureDoc[0]);
|
||||
}
|
||||
|
||||
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
|
||||
|
@ -154,7 +159,15 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
|
|||
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
|
||||
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
|
||||
}
|
||||
|
||||
return scores;
|
||||
}
|
||||
|
||||
/**
 * Remaps a raw reranker relevance score onto a strictly positive range.
 * Rerank models may legitimately produce negative scores, but downstream
 * consumers require non-negative document scores.
 */
private static float normalizeScore(float score) {
    // As some models might produce negative scores, we want to ensure that all scores will be positive
    // so we will make use of the following normalization formula:
    // score = max(score, 0) + min(exp(score), 1)
    // this will ensure that all positive scores lie in the [1, inf) range,
    // while negative values (and 0) will be shifted to (0, 1]
    // The mapping is monotonic, so relative document ordering is preserved.
    return Math.max(score, 0) + Math.min((float) Math.exp(score), 1);
}
|
||||
}
|
||||
|
|
|
@ -142,6 +142,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
|
|||
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
|
||||
for (int i = 0; i < scoreDocs.length; i++) {
|
||||
ScoreDoc scoreDoc = scoreDocs[i];
|
||||
assert scoreDoc.score >= 0;
|
||||
if (explain) {
|
||||
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
|
||||
scoreDoc.doc,
|
||||
|
|
|
@ -50,4 +50,9 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT {
|
|||
public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception {
|
||||
// no-op
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean shouldCheckScores() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -131,11 +131,12 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|||
// Verify order, rank and score of results
|
||||
SearchHit[] hits = response.getHits().getHits();
|
||||
assertEquals(5, hits.length);
|
||||
assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
|
||||
assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
|
||||
assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
|
||||
assertHitHasRankScoreAndText(hits[3], 4, 1.0f, "1");
|
||||
assertHitHasRankScoreAndText(hits[4], 5, 0.0f, "0");
|
||||
// we add +1 to all expected scores due to the default normalization being applied, which shifts positive scores up by 1
|
||||
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
|
||||
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
|
||||
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
|
||||
assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1");
|
||||
assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0");
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -150,9 +151,9 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|||
// Verify order, rank and score of results
|
||||
SearchHit[] hits = response.getHits().getHits();
|
||||
assertEquals(3, hits.length);
|
||||
assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
|
||||
assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
|
||||
assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
|
||||
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
|
||||
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
|
||||
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase {
|
|||
|
||||
@ClassRule
|
||||
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
|
||||
.systemProperty("tests.seed", System.getProperty("tests.seed"))
|
||||
.setting("xpack.security.enabled", "false")
|
||||
.setting("xpack.security.http.ssl.enabled", "false")
|
||||
.setting("xpack.license.self_generated.type", "trial")
|
||||
|
|
|
@ -89,10 +89,7 @@ setup:
|
|||
- length: { hits.hits: 2 }
|
||||
|
||||
- match: { hits.hits.0._id: "doc_2" }
|
||||
- close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } }
|
||||
|
||||
- match: { hits.hits.1._id: "doc_1" }
|
||||
- close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } }
|
||||
|
||||
---
|
||||
"Simple text similarity rank retriever and filtering":
|
||||
|
@ -123,8 +120,6 @@ setup:
|
|||
- length: { hits.hits: 1 }
|
||||
|
||||
- match: { hits.hits.0._id: "doc_1" }
|
||||
- close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } }
|
||||
|
||||
|
||||
---
|
||||
"Text similarity reranking fails if the inference ID does not exist":
|
||||
|
@ -211,7 +206,6 @@ setup:
|
|||
- contains: { hits.hits: { _id: "doc_2" } }
|
||||
- contains: { hits.hits: { _id: "doc_1" } }
|
||||
|
||||
- close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } }
|
||||
- match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" }
|
||||
- match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" }
|
||||
|
||||
|
|
Loading…
Reference in New Issue