Remove matched text from chunks (#123607)

Mike Pellegrini 2025-03-05 15:01:46 -05:00 committed by GitHub
parent f534fc3ccf
commit 2fa6651a68
28 changed files with 119 additions and 140 deletions

ChunkedInference.java

@ -21,18 +21,17 @@ public interface ChunkedInference {
* Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields.
*
* @param xcontent provided by the SemanticTextField
* @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding).
* @return an iterator of the serialized {@link Chunk} which includes the offset into the input text and bytes reference
* (output/embedding).
*/
Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException;
Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException;
/**
* A chunk of inference results containing matched text, the substring location
* in the original text and the bytes reference.
* @param matchedText
* A chunk of inference results containing the substring location in the original text and the bytes reference.
* @param textOffset
* @param bytesReference
*/
record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {}
record Chunk(TextOffset textOffset, BytesReference bytesReference) {}
record TextOffset(int start, int end) {}
}
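The net effect of this interface change is that a Chunk no longer carries the matched text itself; a caller that still needs it slices it back out of the original input using the offsets. A minimal sketch of such a caller, assuming chunkedInference, input, and xContent variables that are not part of this diff:

// Illustrative fragment only; "chunkedInference", "input" and "xContent" are assumptions.
Iterator<ChunkedInference.Chunk> it = chunkedInference.chunksAsByteReference(xContent);
while (it.hasNext()) {
    ChunkedInference.Chunk chunk = it.next();
    // The text is no longer stored on the chunk; recover it from the original input on demand.
    String matchedText = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
    BytesReference embedding = chunk.bytesReference();
    // ... use matchedText and embedding ...
}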

ChunkedInferenceEmbedding.java

@ -29,7 +29,6 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> c
List.of(
new SparseEmbeddingResults.Chunk(
sparseEmbeddingResults.embeddings().get(i).tokens(),
inputs.get(i),
new TextOffset(0, inputs.get(i).length())
)
)
@ -41,7 +40,7 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> c
}
@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
var asChunk = new ArrayList<Chunk>();
for (var chunk : chunks()) {
asChunk.add(chunk.toChunk(xcontent));

ChunkedInferenceError.java

@ -7,17 +7,16 @@
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;
import java.util.Collections;
import java.util.Iterator;
import java.util.stream.Stream;
public record ChunkedInferenceError(Exception exception) implements ChunkedInference {
@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator();
public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
return Collections.emptyIterator();
}
}

EmbeddingResults.java

@ -24,13 +24,11 @@ public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends Em
InferenceServiceResults {
/**
* A resulting embedding together with its input text.
* A resulting embedding together with the offset into the input text.
*/
interface Chunk {
ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
String matchedText();
ChunkedInference.TextOffset offset();
}
@ -39,9 +37,9 @@ public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends Em
*/
interface Embedding<C extends Chunk> {
/**
* Combines the resulting embedding with the input into a chunk.
* Combines the resulting embedding with the offset into the input text into a chunk.
*/
C toChunk(String text, ChunkedInference.TextOffset offset);
C toChunk(ChunkedInference.TextOffset offset);
}
/**

SparseEmbeddingResults.java

@ -175,17 +175,15 @@ public record SparseEmbeddingResults(List<Embedding> embeddings)
}
@Override
public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
return new Chunk(tokens, text, offset);
public Chunk toChunk(ChunkedInference.TextOffset offset) {
return new Chunk(tokens, offset);
}
}
public record Chunk(List<WeightedToken> weightedTokens, String matchedText, ChunkedInference.TextOffset offset)
implements
EmbeddingResults.Chunk {
public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, weightedTokens));
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
}
private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {

TextEmbeddingByteResults.java

@ -187,18 +187,18 @@ public record TextEmbeddingByteResults(List<Embedding> embeddings)
}
@Override
public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
return new Chunk(values, text, offset);
public Chunk toChunk(ChunkedInference.TextOffset offset) {
return new Chunk(values, offset);
}
}
/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
public record Chunk(byte[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
}
private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {

TextEmbeddingFloatResults.java

@ -221,15 +221,15 @@ public record TextEmbeddingFloatResults(List<Embedding> embeddings)
}
@Override
public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
return new Chunk(values, text, offset);
public Chunk toChunk(ChunkedInference.TextOffset offset) {
return new Chunk(values, offset);
}
}
public record Chunk(float[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
}
/**

TestDenseInferenceServiceExtension.java

@ -183,7 +183,6 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
List.of(
new TextEmbeddingFloatResults.Chunk(
nonChunkedResults.embeddings().get(i).values(),
input.get(i),
new ChunkedInference.TextOffset(0, input.get(i).length())
)
)

TestSparseInferenceServiceExtension.java

@ -172,13 +172,7 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
}
results.add(
new ChunkedInferenceEmbedding(
List.of(
new SparseEmbeddingResults.Chunk(
tokens,
input.get(i),
new ChunkedInference.TextOffset(0, input.get(i).length())
)
)
List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
)
);
}

ShardBulkInferenceActionFilter.java

@ -606,7 +606,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private static class EmptyChunkedInference implements ChunkedInference {
@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
return Collections.emptyIterator();
}
}

EmbeddingRequestChunker.java

@ -197,10 +197,7 @@ public class EmbeddingRequestChunker {
AtomicReferenceArray<EmbeddingResults.Embedding<?>> result = results.get(index);
for (int i = 0; i < request.size(); i++) {
EmbeddingResults.Chunk chunk = result.get(i)
.toChunk(
request.get(i).chunkText(),
new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end())
);
.toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end()));
chunks.add(chunk);
}
return new ChunkedInferenceEmbedding(chunks);

SemanticTextField.java

@ -275,7 +275,7 @@ public record SemanticTextField(
boolean useLegacyFormat
) throws IOException {
List<Chunk> chunks = new ArrayList<>();
Iterator<ChunkedInference.Chunk> it = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
while (it.hasNext()) {
chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
}

HuggingFaceElserService.java

@ -121,7 +121,6 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
List.of(
new TextEmbeddingFloatResults.Chunk(
textEmbeddingResults.embeddings().get(i).values(),
inputs.getInputs().get(i),
new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
)
)

EmbeddingRequestChunkerTests.java

@ -246,7 +246,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@ -275,7 +275,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedFloatResult.chunks().get(0).offset());
}
{
// this is the large input split in multiple chunks
@ -283,26 +283,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(6));
assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
assertThat(chunkedFloatResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
assertThat(chunkedFloatResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
assertThat(chunkedFloatResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
assertThat(chunkedFloatResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
assertThat(chunkedFloatResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
assertThat(chunkedFloatResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
assertThat(chunkedFloatResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
assertThat(chunkedFloatResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
assertThat(chunkedFloatResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
assertThat(chunkedFloatResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedFloatResult.chunks().get(0).offset());
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedFloatResult.chunks().get(0).offset());
}
}
@ -318,7 +318,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@ -347,7 +347,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
}
{
// this is the large input split in multiple chunks
@ -355,26 +355,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
}
}
@ -390,7 +390,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@ -419,7 +419,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
}
{
// this is the large input split in multiple chunks
@ -427,26 +427,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
}
}
@ -462,7 +462,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString());
List<String> inputs = List.of("a", "bb", "ccc", passageBuilder.toString());
var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@ -498,21 +498,21 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals("1st small", chunkedSparseResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedSparseResult.chunks().get(0).offset());
}
{
var chunkedResult = finalListener.results.get(1);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals("2nd small", chunkedSparseResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedSparseResult.chunks().get(0).offset());
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals("3rd small", chunkedSparseResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedSparseResult.chunks().get(0).offset());
}
{
// this is the large input split in multiple chunks
@ -520,9 +520,9 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each
assertThat(chunkedSparseResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
assertThat(chunkedSparseResult.chunks().get(1).matchedText(), startsWith(" passage_input10 "));
assertThat(chunkedSparseResult.chunks().get(8).matchedText(), startsWith(" passage_input80 "));
assertThat(chunkedSparseResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 149)));
assertThat(chunkedSparseResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(149, 309)));
assertThat(chunkedSparseResult.chunks().get(8).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1350)));
}
}
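With matchedText() gone, these assertions compare offsets instead, and the shorter inputs ("a", "bb", "ccc") keep the expected values easy to read off: an input that fits in a single chunk gets an offset spanning the whole string. A small illustration (variable names are not from the test):

// An input that fits in one chunk spans the whole string, so the expected offset is (0, length).
String input = "bb";
ChunkedInference.TextOffset expected = new ChunkedInference.TextOffset(0, input.length()); // (0, 2)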

SemanticTextFieldTests.java

@ -177,7 +177,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
for (int j = 0; j < values.length; j++) {
values[j] = randomByte();
}
chunks.add(new TextEmbeddingByteResults.Chunk(values, input, new ChunkedInference.TextOffset(0, input.length())));
chunks.add(new TextEmbeddingByteResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
}
return new ChunkedInferenceEmbedding(chunks);
}
@ -189,7 +189,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
for (int j = 0; j < values.length; j++) {
values[j] = randomFloat();
}
chunks.add(new TextEmbeddingFloatResults.Chunk(values, input, new ChunkedInference.TextOffset(0, input.length())));
chunks.add(new TextEmbeddingFloatResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
}
return new ChunkedInferenceEmbedding(chunks);
}
@ -205,7 +205,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
for (var token : input.split("\\s+")) {
tokens.add(new WeightedToken(token, withFloats ? randomFloat() : randomIntBetween(1, 255)));
}
chunks.add(new SparseEmbeddingResults.Chunk(tokens, input, new ChunkedInference.TextOffset(0, input.length())));
chunks.add(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.length())));
}
return new ChunkedInferenceEmbedding(chunks);
}
@ -243,7 +243,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
final List<SemanticTextField.Chunk> chunks = new ArrayList<>(inputs.size());
int offsetAdjustment = 0;
Iterator<String> inputsIt = inputs.iterator();
Iterator<ChunkedInference.Chunk> chunkIt = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
Iterator<ChunkedInference.Chunk> chunkIt = results.chunksAsByteReference(contentType.xContent());
while (inputsIt.hasNext() && chunkIt.hasNext()) {
String input = inputsIt.next();
var chunk = chunkIt.next();
@ -308,7 +308,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
String matchedText = matchedTextIt.next();
ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, chunk, matchedText);
var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType());
chunks.add(new SparseEmbeddingResults.Chunk(tokens, matchedText, offset));
chunks.add(new SparseEmbeddingResults.Chunk(tokens, offset));
}
}
return new ChunkedInferenceEmbedding(chunks);
@ -329,7 +329,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
field.inference().modelSettings().dimensions(),
field.contentType()
);
chunks.add(new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), matchedText, offset));
chunks.add(new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), offset));
}
}
return new ChunkedInferenceEmbedding(chunks);

AmazonBedrockServiceTests.java

@ -1444,7 +1444,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("abc", "xyz"),
List.of("a", "bb"),
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1457,7 +1457,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("abc", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.123F, 0.678F },
@ -1469,7 +1469,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("xyz", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.223F, 0.278F },

AzureAiStudioServiceTests.java

@ -1191,7 +1191,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1204,7 +1204,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("foo", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.0123f, -0.0123f },
@ -1216,7 +1216,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("bar", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 1.0123f, -1.0123f },
@ -1232,7 +1232,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), Matchers.is(2));
assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
assertThat(requestMap.get("user"), Matchers.is("user"));
}
}

AzureOpenAiServiceTests.java

@ -1341,7 +1341,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1354,7 +1354,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("foo", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
@ -1366,7 +1366,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("bar", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 1.123f, -1.123f },
@ -1382,7 +1382,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), Matchers.is(2));
assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
assertThat(requestMap.get("user"), Matchers.is("user"));
}
}

CohereServiceTests.java

@ -1452,7 +1452,7 @@ public class CohereServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1465,7 +1465,7 @@ public class CohereServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("foo", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertArrayEquals(
new float[] { 0.123f, -0.123f },
((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
@ -1476,7 +1476,7 @@ public class CohereServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("bar", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertArrayEquals(
new float[] { 0.223f, -0.223f },
((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
@ -1495,7 +1495,7 @@ public class CohereServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
MatcherAssert.assertThat(
requestMap,
is(Map.of("texts", List.of("foo", "bar"), "model", "model", "embedding_types", List.of("float")))
is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float")))
);
}
}
@ -1551,7 +1551,7 @@ public class CohereServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1564,7 +1564,7 @@ public class CohereServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var byteResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(byteResult.chunks(), hasSize(1));
assertEquals("foo", byteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset());
assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
assertArrayEquals(new byte[] { 23, -23 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
}
@ -1572,7 +1572,7 @@ public class CohereServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var byteResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(byteResult.chunks(), hasSize(1));
assertEquals("bar", byteResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset());
assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
assertArrayEquals(new byte[] { 24, -24 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
}
@ -1588,7 +1588,7 @@ public class CohereServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
MatcherAssert.assertThat(
requestMap,
is(Map.of("texts", List.of("foo", "bar"), "model", "model", "embedding_types", List.of("int8")))
is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8")))
);
}
}

View File

@ -566,7 +566,6 @@ public class ElasticInferenceServiceTests extends ESTestCase {
List.of(
new SparseEmbeddingResults.Chunk(
List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
"input text",
new ChunkedInference.TextOffset(0, "input text".length())
)
)

ElasticsearchInternalServiceTests.java

@ -902,7 +902,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
((TextEmbeddingFloatResults.Chunk) result1.chunks().get(0)).embedding(),
0.0001f
);
assertEquals("foo", result1.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
assertThat(result2.chunks(), hasSize(1));
@ -912,7 +912,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
((TextEmbeddingFloatResults.Chunk) result2.chunks().get(0)).embedding(),
0.0001f
);
assertEquals("bar", result2.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
gotResults.set(true);
}, ESTestCase::fail);
@ -923,7 +923,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
Map.of(),
InputType.SEARCH,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -977,7 +977,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
);
assertEquals("foo", result1.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
@ -985,7 +985,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
);
assertEquals("bar", result2.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
gotResults.set(true);
}, ESTestCase::fail);
@ -995,7 +995,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
Map.of(),
InputType.SEARCH,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1049,7 +1049,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
);
assertEquals("foo", result1.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
@ -1057,7 +1057,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
);
assertEquals("bar", result2.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
gotResults.set(true);
}, ESTestCase::fail);
@ -1067,7 +1067,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
Map.of(),
InputType.SEARCH,
InferenceAction.Request.DEFAULT_TIMEOUT,

GoogleAiStudioServiceTests.java

@ -844,7 +844,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException {
var input = List.of("foo", "bar");
var input = List.of("a", "bb");
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@ -881,7 +881,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertTrue(
Arrays.equals(
@ -896,7 +896,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertTrue(
Arrays.equals(

HuggingFaceElserServiceTests.java

@ -111,7 +111,6 @@ public class HuggingFaceElserServiceTests extends ESTestCase {
List.of(
new SparseEmbeddingResults.Chunk(
List.of(new WeightedToken(".", 0.13315596f)),
"abc",
new ChunkedInference.TextOffset(0, "abc".length())
)
)

HuggingFaceServiceTests.java

@ -787,7 +787,6 @@ public class HuggingFaceServiceTests extends ESTestCase {
assertThat(result, CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var embeddingResult = (ChunkedInferenceEmbedding) result;
assertThat(embeddingResult.chunks(), hasSize(1));
assertThat(embeddingResult.chunks().get(0).matchedText(), is("abc"));
assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
assertThat(embeddingResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
@ -842,7 +841,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("abc", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },

IbmWatsonxServiceTests.java

@ -686,7 +686,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
}
private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException {
var input = List.of("foo", "bar");
var input = List.of("a", "bb");
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@ -733,7 +733,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertTrue(
Arrays.equals(
@ -748,7 +748,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertTrue(
Arrays.equals(
@ -763,7 +763,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("foo", "bar"), "model_id", modelId)));
assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("a", "bb"), "model_id", modelId)));
}
}

JinaAIServiceTests.java

@ -1819,7 +1819,7 @@ public class JinaAIServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1832,7 +1832,7 @@ public class JinaAIServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("foo", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
@ -1844,7 +1844,7 @@ public class JinaAIServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("bar", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.223f, -0.223f },
@ -1864,7 +1864,7 @@ public class JinaAIServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
MatcherAssert.assertThat(
requestMap,
is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2", "embedding_type", "float"))
is(Map.of("input", List.of("a", "bb"), "model", "jina-clip-v2", "embedding_type", "float"))
);
}
}

View File

@ -1857,7 +1857,7 @@ public class OpenAiServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1870,7 +1870,7 @@ public class OpenAiServiceTests extends ESTestCase {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("foo", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertTrue(
Arrays.equals(
@ -1883,7 +1883,7 @@ public class OpenAiServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("bar", floatResult.chunks().get(0).matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertTrue(
Arrays.equals(
@ -1901,7 +1901,7 @@ public class OpenAiServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), Matchers.is(3));
assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
assertThat(requestMap.get("model"), Matchers.is("model"));
assertThat(requestMap.get("user"), Matchers.is("user"));
}

VoyageAIServiceTests.java

@ -1826,7 +1826,7 @@ public class VoyageAIServiceTests extends ESTestCase {
service.chunkedInfer(
model,
null,
List.of("foo", "bar"),
List.of("a", "bb"),
new HashMap<>(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT,
@ -1839,7 +1839,7 @@ public class VoyageAIServiceTests extends ESTestCase {
assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.getFirst();
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("foo", floatResult.chunks().getFirst().matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset());
assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
@ -1851,7 +1851,7 @@ public class VoyageAIServiceTests extends ESTestCase {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals("bar", floatResult.chunks().getFirst().matchedText());
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset());
assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
assertArrayEquals(
new float[] { 0.223f, -0.223f },
@ -1871,7 +1871,7 @@ public class VoyageAIServiceTests extends ESTestCase {
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
MatcherAssert.assertThat(
requestMap,
is(Map.of("input", List.of("foo", "bar"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024))
is(Map.of("input", List.of("a", "bb"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024))
);
}
}