Remove matched text from chunks (#123607)
This commit is contained in:
parent f534fc3ccf
commit 2fa6651a68
@@ -21,18 +21,17 @@ public interface ChunkedInference {
     * Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields.
     *
     * @param xcontent provided by the SemanticTextField
-    * @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding).
+    * @return an iterator of the serialized {@link Chunk} which includes the offset into the input text and bytes reference
+    * (output/embedding).
     */
-    Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException;
+    Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException;

    /**
-     * A chunk of inference results containing matched text, the substring location
-     * in the original text and the bytes reference.
-     * @param matchedText
+     * A chunk of inference results containing the substring location in the original text and the bytes reference.
     * @param textOffset
     * @param bytesReference
     */
-    record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {}
+    record Chunk(TextOffset textOffset, BytesReference bytesReference) {}

    record TextOffset(int start, int end) {}
}

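Note: the `Chunk` record above no longer carries the matched text; a consumer that still needs it can recover the substring from the original input using the offsets. A minimal, self-contained sketch of that idea, using simplified stand-in types rather than the real Elasticsearch classes:

import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.List;

public class OffsetChunkSketch {
    // Simplified stand-ins for ChunkedInference.TextOffset and Chunk (assumed shapes, not the ES classes).
    record TextOffset(int start, int end) {}

    record Chunk(TextOffset textOffset, byte[] bytesReference) {}

    public static void main(String[] args) {
        String input = "first sentence. second sentence.";
        // The embedding bytes would normally come from the inference service; here they are dummies.
        List<Chunk> chunks = List.of(
            new Chunk(new TextOffset(0, 15), "embedding-1".getBytes(StandardCharsets.UTF_8)),
            new Chunk(new TextOffset(16, 32), "embedding-2".getBytes(StandardCharsets.UTF_8))
        );

        // With matchedText gone, the text of each chunk is re-derived from the offsets on demand.
        Iterator<Chunk> it = chunks.iterator();
        while (it.hasNext()) {
            Chunk chunk = it.next();
            String matched = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
            System.out.println(chunk.textOffset() + " -> \"" + matched + "\"");
        }
    }
}
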
@@ -29,7 +29,6 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> c
            List.of(
                new SparseEmbeddingResults.Chunk(
                    sparseEmbeddingResults.embeddings().get(i).tokens(),
-                   inputs.get(i),
                    new TextOffset(0, inputs.get(i).length())
                )
            )

@@ -41,7 +40,7 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> c
    }

    @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
+    public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
        var asChunk = new ArrayList<Chunk>();
        for (var chunk : chunks()) {
            asChunk.add(chunk.toChunk(xcontent));

@@ -7,17 +7,16 @@

package org.elasticsearch.xpack.core.inference.results;

-import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;

+import java.util.Collections;
import java.util.Iterator;
-import java.util.stream.Stream;

public record ChunkedInferenceError(Exception exception) implements ChunkedInference {

    @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
-        return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator();
+    public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
+        return Collections.emptyIterator();
    }
}

@@ -24,13 +24,11 @@ public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends Em
        InferenceServiceResults {

    /**
-     * A resulting embedding together with its input text.
+     * A resulting embedding together with the offset into the input text.
     */
    interface Chunk {
        ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;

-        String matchedText();
-
        ChunkedInference.TextOffset offset();
    }

@@ -39,9 +37,9 @@ public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends Em
     */
    interface Embedding<C extends Chunk> {
        /**
-         * Combines the resulting embedding with the input into a chunk.
+         * Combines the resulting embedding with the offset into the input text into a chunk.
         */
-        C toChunk(String text, ChunkedInference.TextOffset offset);
+        C toChunk(ChunkedInference.TextOffset offset);
    }

    /**

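With `matchedText()` removed from `EmbeddingResults.Chunk`, an `Embedding` only needs the offset to produce a chunk. A hedged sketch of what an implementation of this shape can look like (standalone toy types, not the actual Elasticsearch records):

public class EmbeddingToChunkSketch {
    record TextOffset(int start, int end) {}

    // Toy analogue of EmbeddingResults.Chunk: embedding values plus the offset, no matched text.
    record FloatChunk(float[] values, TextOffset offset) {}

    // Toy analogue of EmbeddingResults.Embedding: combines its values with an offset into a chunk.
    record FloatEmbedding(float[] values) {
        FloatChunk toChunk(TextOffset offset) {
            return new FloatChunk(values, offset);
        }
    }

    public static void main(String[] args) {
        var embedding = new FloatEmbedding(new float[] { 0.12f, -0.98f });
        // The caller supplies only where the chunk sits in the original input.
        FloatChunk chunk = embedding.toChunk(new TextOffset(0, 42));
        System.out.println(chunk.offset() + " carries " + chunk.values().length + " dims");
    }
}
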
@@ -175,17 +175,15 @@ public record SparseEmbeddingResults(List<Embedding> embeddings)
        }

        @Override
-        public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
-            return new Chunk(tokens, text, offset);
+        public Chunk toChunk(ChunkedInference.TextOffset offset) {
+            return new Chunk(tokens, offset);
        }
    }

-    public record Chunk(List<WeightedToken> weightedTokens, String matchedText, ChunkedInference.TextOffset offset)
-        implements
-            EmbeddingResults.Chunk {
+    public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {

        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, weightedTokens));
+            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
        }

        private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {

@@ -187,18 +187,18 @@ public record TextEmbeddingByteResults(List<Embedding> embeddings)
        }

        @Override
-        public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
-            return new Chunk(values, text, offset);
+        public Chunk toChunk(ChunkedInference.TextOffset offset) {
+            return new Chunk(values, offset);
        }
    }

    /**
     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
     */
-    public record Chunk(byte[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
+    public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {

        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
+            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
        }

        private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {

@@ -221,15 +221,15 @@ public record TextEmbeddingFloatResults(List<Embedding> embeddings)
        }

        @Override
-        public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
-            return new Chunk(values, text, offset);
+        public Chunk toChunk(ChunkedInference.TextOffset offset) {
+            return new Chunk(values, offset);
        }
    }

-    public record Chunk(float[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
+    public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {

        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
+            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
        }

        /**

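The `toChunk(XContent)` implementations above now serialize only the offset plus the raw embedding bytes, so no copy of the input text reaches storage. A rough standalone illustration of that serialization step, using plain JDK classes instead of the real `XContent`/`BytesReference` machinery (assumed simplification, not the ES API):

import java.nio.charset.StandardCharsets;

public class EmbeddingBytesSketch {
    // Simplified stand-in for toBytesReference(XContent, float[]): renders the values as a JSON array of numbers.
    static byte[] toBytes(float[] values) {
        StringBuilder json = new StringBuilder("[");
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                json.append(',');
            }
            json.append(values[i]);
        }
        json.append(']');
        return json.toString().getBytes(StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        // Only the serialized embedding plus a TextOffset would be stored per chunk; no matched text.
        byte[] bytes = toBytes(new float[] { 0.123f, -0.123f });
        System.out.println(new String(bytes, StandardCharsets.UTF_8)); // prints [0.123,-0.123]
    }
}
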
@@ -183,7 +183,6 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
                List.of(
                    new TextEmbeddingFloatResults.Chunk(
                        nonChunkedResults.embeddings().get(i).values(),
-                       input.get(i),
                        new ChunkedInference.TextOffset(0, input.get(i).length())
                    )
                )

@@ -172,13 +172,7 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
            }
            results.add(
                new ChunkedInferenceEmbedding(
-                    List.of(
-                        new SparseEmbeddingResults.Chunk(
-                            tokens,
-                            input.get(i),
-                            new ChunkedInference.TextOffset(0, input.get(i).length())
-                        )
-                    )
+                    List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
                )
            );
        }

@@ -606,7 +606,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {

    private static class EmptyChunkedInference implements ChunkedInference {
        @Override
-        public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
+        public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
            return Collections.emptyIterator();
        }
    }

@@ -197,10 +197,7 @@ public class EmbeddingRequestChunker {
        AtomicReferenceArray<EmbeddingResults.Embedding<?>> result = results.get(index);
        for (int i = 0; i < request.size(); i++) {
            EmbeddingResults.Chunk chunk = result.get(i)
-                .toChunk(
-                    request.get(i).chunkText(),
-                    new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end())
-                );
+                .toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end()));
            chunks.add(chunk);
        }
        return new ChunkedInferenceEmbedding(chunks);

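In the chunker change above, the per-request chunk text is no longer passed into `toChunk`; only the chunk's start/end positions are. A small sketch of why boundaries alone are enough: the matched text can always be re-derived from the original input. This uses a hypothetical whitespace-snapping chunker, not the real `EmbeddingRequestChunker`:

import java.util.ArrayList;
import java.util.List;

public class ChunkOffsetsSketch {
    record TextOffset(int start, int end) {}

    // Hypothetical chunker: splits after a fixed number of characters, snapping back to whitespace.
    static List<TextOffset> chunkOffsets(String input, int maxChars) {
        List<TextOffset> offsets = new ArrayList<>();
        int start = 0;
        while (start < input.length()) {
            int end = Math.min(start + maxChars, input.length());
            // Snap back to the previous space so words are not cut in half.
            if (end < input.length()) {
                int space = input.lastIndexOf(' ', end);
                if (space > start) {
                    end = space;
                }
            }
            offsets.add(new TextOffset(start, end));
            start = end;
        }
        return offsets;
    }

    public static void main(String[] args) {
        String input = "passage_input0 passage_input1 passage_input2 passage_input3";
        for (TextOffset offset : chunkOffsets(input, 25)) {
            // Downstream consumers recover the chunk text on demand instead of storing it with every chunk.
            System.out.println(offset + " -> \"" + input.substring(offset.start(), offset.end()) + "\"");
        }
    }
}
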
@@ -275,7 +275,7 @@ public record SemanticTextField(
        boolean useLegacyFormat
    ) throws IOException {
        List<Chunk> chunks = new ArrayList<>();
-        Iterator<ChunkedInference.Chunk> it = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
+        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
        while (it.hasNext()) {
            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
        }

@@ -121,7 +121,6 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
                List.of(
                    new TextEmbeddingFloatResults.Chunk(
                        textEmbeddingResults.embeddings().get(i).values(),
-                       inputs.getInputs().get(i),
                        new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
                    )
                )

@@ -246,7 +246,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
        for (int i = 0; i < numberOfWordsInPassage; i++) {
            passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
        }
-        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
+        List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");

        var finalListener = testListener();
        var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);

@@ -275,7 +275,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedFloatResult.chunks(), hasSize(1));
-            assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedFloatResult.chunks().get(0).offset());
        }
        {
            // this is the large input split in multiple chunks

@@ -283,26 +283,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedFloatResult.chunks(), hasSize(6));
-            assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
-            assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
-            assertThat(chunkedFloatResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
-            assertThat(chunkedFloatResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
-            assertThat(chunkedFloatResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
-            assertThat(chunkedFloatResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
+            assertThat(chunkedFloatResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
+            assertThat(chunkedFloatResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
+            assertThat(chunkedFloatResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
+            assertThat(chunkedFloatResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
+            assertThat(chunkedFloatResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
+            assertThat(chunkedFloatResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
        }
        {
            var chunkedResult = finalListener.results.get(2);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedFloatResult.chunks(), hasSize(1));
-            assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedFloatResult.chunks().get(0).offset());
        }
        {
            var chunkedResult = finalListener.results.get(3);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedFloatResult.chunks(), hasSize(1));
-            assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedFloatResult.chunks().get(0).offset());
        }
    }

@@ -318,7 +318,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
        for (int i = 0; i < numberOfWordsInPassage; i++) {
            passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
        }
-        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
+        List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");

        var finalListener = testListener();
        var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);

@@ -347,7 +347,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(1));
-            assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
        }
        {
            // this is the large input split in multiple chunks

@@ -355,26 +355,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(6));
-            assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
-            assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
-            assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
-            assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
-            assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
-            assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
+            assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
+            assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
+            assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
+            assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
+            assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
+            assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
        }
        {
            var chunkedResult = finalListener.results.get(2);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(1));
-            assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
        }
        {
            var chunkedResult = finalListener.results.get(3);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(1));
-            assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
        }
    }

@@ -390,7 +390,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
        for (int i = 0; i < numberOfWordsInPassage; i++) {
            passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
        }
-        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
+        List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");

        var finalListener = testListener();
        var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);

@@ -419,7 +419,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(1));
-            assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
        }
        {
            // this is the large input split in multiple chunks

@@ -427,26 +427,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(6));
-            assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
-            assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
-            assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
-            assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
-            assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
-            assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
+            assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
+            assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
+            assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
+            assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
+            assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
+            assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
        }
        {
            var chunkedResult = finalListener.results.get(2);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(1));
-            assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
        }
        {
            var chunkedResult = finalListener.results.get(3);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedByteResult.chunks(), hasSize(1));
-            assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
        }
    }

@@ -462,7 +462,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
        for (int i = 0; i < numberOfWordsInPassage; i++) {
            passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
        }
-        List<String> inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString());
+        List<String> inputs = List.of("a", "bb", "ccc", passageBuilder.toString());

        var finalListener = testListener();
        var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);

@@ -498,21 +498,21 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedSparseResult.chunks(), hasSize(1));
-            assertEquals("1st small", chunkedSparseResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedSparseResult.chunks().get(0).offset());
        }
        {
            var chunkedResult = finalListener.results.get(1);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedSparseResult.chunks(), hasSize(1));
-            assertEquals("2nd small", chunkedSparseResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedSparseResult.chunks().get(0).offset());
        }
        {
            var chunkedResult = finalListener.results.get(2);
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedSparseResult.chunks(), hasSize(1));
-            assertEquals("3rd small", chunkedSparseResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedSparseResult.chunks().get(0).offset());
        }
        {
            // this is the large input split in multiple chunks

@@ -520,9 +520,9 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
            assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
            var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
            assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each
-            assertThat(chunkedSparseResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
-            assertThat(chunkedSparseResult.chunks().get(1).matchedText(), startsWith(" passage_input10 "));
-            assertThat(chunkedSparseResult.chunks().get(8).matchedText(), startsWith(" passage_input80 "));
+            assertThat(chunkedSparseResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 149)));
+            assertThat(chunkedSparseResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(149, 309)));
+            assertThat(chunkedSparseResult.chunks().get(8).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1350)));
        }
    }

@@ -177,7 +177,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
            for (int j = 0; j < values.length; j++) {
                values[j] = randomByte();
            }
-            chunks.add(new TextEmbeddingByteResults.Chunk(values, input, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(new TextEmbeddingByteResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
        }
        return new ChunkedInferenceEmbedding(chunks);
    }

@@ -189,7 +189,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
            for (int j = 0; j < values.length; j++) {
                values[j] = randomFloat();
            }
-            chunks.add(new TextEmbeddingFloatResults.Chunk(values, input, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(new TextEmbeddingFloatResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
        }
        return new ChunkedInferenceEmbedding(chunks);
    }

@@ -205,7 +205,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
            for (var token : input.split("\\s+")) {
                tokens.add(new WeightedToken(token, withFloats ? randomFloat() : randomIntBetween(1, 255)));
            }
-            chunks.add(new SparseEmbeddingResults.Chunk(tokens, input, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.length())));
        }
        return new ChunkedInferenceEmbedding(chunks);
    }

@@ -243,7 +243,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
        final List<SemanticTextField.Chunk> chunks = new ArrayList<>(inputs.size());
        int offsetAdjustment = 0;
        Iterator<String> inputsIt = inputs.iterator();
-        Iterator<ChunkedInference.Chunk> chunkIt = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
+        Iterator<ChunkedInference.Chunk> chunkIt = results.chunksAsByteReference(contentType.xContent());
        while (inputsIt.hasNext() && chunkIt.hasNext()) {
            String input = inputsIt.next();
            var chunk = chunkIt.next();

@@ -308,7 +308,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
                String matchedText = matchedTextIt.next();
                ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, chunk, matchedText);
                var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType());
-                chunks.add(new SparseEmbeddingResults.Chunk(tokens, matchedText, offset));
+                chunks.add(new SparseEmbeddingResults.Chunk(tokens, offset));
            }
        }
        return new ChunkedInferenceEmbedding(chunks);

@@ -329,7 +329,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
                    field.inference().modelSettings().dimensions(),
                    field.contentType()
                );
-                chunks.add(new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), matchedText, offset));
+                chunks.add(new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), offset));
            }
        }
        return new ChunkedInferenceEmbedding(chunks);

@@ -1444,7 +1444,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("abc", "xyz"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.INGEST,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1457,7 +1457,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("abc", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.123F, 0.678F },

@@ -1469,7 +1469,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("xyz", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.223F, 0.278F },

@@ -1191,7 +1191,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.INGEST,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1204,7 +1204,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.0123f, -0.0123f },

@@ -1216,7 +1216,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 1.0123f, -1.0123f },

@@ -1232,7 +1232,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {

            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            assertThat(requestMap.size(), Matchers.is(2));
-            assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
            assertThat(requestMap.get("user"), Matchers.is("user"));
        }
    }

@@ -1341,7 +1341,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.INGEST,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1354,7 +1354,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.123f, -0.123f },

@@ -1366,7 +1366,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 1.123f, -1.123f },

@@ -1382,7 +1382,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {

            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            assertThat(requestMap.size(), Matchers.is(2));
-            assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
            assertThat(requestMap.get("user"), Matchers.is("user"));
        }
    }

@@ -1452,7 +1452,7 @@ public class CohereServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.UNSPECIFIED,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1465,7 +1465,7 @@ public class CohereServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
                assertArrayEquals(
                    new float[] { 0.123f, -0.123f },
                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),

@@ -1476,7 +1476,7 @@ public class CohereServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
                assertArrayEquals(
                    new float[] { 0.223f, -0.223f },
                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),

@@ -1495,7 +1495,7 @@ public class CohereServiceTests extends ESTestCase {
            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            MatcherAssert.assertThat(
                requestMap,
-                is(Map.of("texts", List.of("foo", "bar"), "model", "model", "embedding_types", List.of("float")))
+                is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float")))
            );
        }
    }

@@ -1551,7 +1551,7 @@ public class CohereServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.UNSPECIFIED,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1564,7 +1564,7 @@ public class CohereServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var byteResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(byteResult.chunks(), hasSize(1));
-                assertEquals("foo", byteResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset());
                assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
                assertArrayEquals(new byte[] { 23, -23 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
            }

@@ -1572,7 +1572,7 @@ public class CohereServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var byteResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(byteResult.chunks(), hasSize(1));
-                assertEquals("bar", byteResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset());
                assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
                assertArrayEquals(new byte[] { 24, -24 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
            }

@@ -1588,7 +1588,7 @@ public class CohereServiceTests extends ESTestCase {
            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            MatcherAssert.assertThat(
                requestMap,
-                is(Map.of("texts", List.of("foo", "bar"), "model", "model", "embedding_types", List.of("int8")))
+                is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8")))
            );
        }
    }

@@ -566,7 +566,6 @@ public class ElasticInferenceServiceTests extends ESTestCase {
                List.of(
                    new SparseEmbeddingResults.Chunk(
                        List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
-                        "input text",
                        new ChunkedInference.TextOffset(0, "input text".length())
                    )
                )

@@ -902,7 +902,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                ((TextEmbeddingFloatResults.Chunk) result1.chunks().get(0)).embedding(),
                0.0001f
            );
-            assertEquals("foo", result1.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
            assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
            var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
            assertThat(result2.chunks(), hasSize(1));

@@ -912,7 +912,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                ((TextEmbeddingFloatResults.Chunk) result2.chunks().get(0)).embedding(),
                0.0001f
            );
-            assertEquals("bar", result2.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());

            gotResults.set(true);
        }, ESTestCase::fail);

@@ -923,7 +923,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
        service.chunkedInfer(
            model,
            null,
-            List.of("foo", "bar"),
+            List.of("a", "bb"),
            Map.of(),
            InputType.SEARCH,
            InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -977,7 +977,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
                ((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
            );
-            assertEquals("foo", result1.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
            assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
            var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
            assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));

@@ -985,7 +985,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
                ((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
            );
-            assertEquals("bar", result2.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
            gotResults.set(true);
        }, ESTestCase::fail);

@@ -995,7 +995,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
        service.chunkedInfer(
            model,
            null,
-            List.of("foo", "bar"),
+            List.of("a", "bb"),
            Map.of(),
            InputType.SEARCH,
            InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1049,7 +1049,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
                ((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
            );
-            assertEquals("foo", result1.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
            assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
            var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
            assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));

@@ -1057,7 +1057,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
                ((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
            );
-            assertEquals("bar", result2.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
            gotResults.set(true);
        }, ESTestCase::fail);

@@ -1067,7 +1067,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
        service.chunkedInfer(
            model,
            null,
-            List.of("foo", "bar"),
+            List.of("a", "bb"),
            Map.of(),
            InputType.SEARCH,
            InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -844,7 +844,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {

    private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException {

-        var input = List.of("foo", "bar");
+        var input = List.of("a", "bb");

        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

@@ -881,7 +881,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertTrue(
                    Arrays.equals(

@@ -896,7 +896,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertTrue(
                    Arrays.equals(

@@ -111,7 +111,6 @@ public class HuggingFaceElserServiceTests extends ESTestCase {
                List.of(
                    new SparseEmbeddingResults.Chunk(
                        List.of(new WeightedToken(".", 0.13315596f)),
-                        "abc",
                        new ChunkedInference.TextOffset(0, "abc".length())
                    )
                )

@@ -787,7 +787,6 @@ public class HuggingFaceServiceTests extends ESTestCase {
            assertThat(result, CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
            var embeddingResult = (ChunkedInferenceEmbedding) result;
            assertThat(embeddingResult.chunks(), hasSize(1));
-            assertThat(embeddingResult.chunks().get(0).matchedText(), is("abc"));
            assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
            assertThat(embeddingResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
            assertArrayEquals(

@@ -842,7 +841,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
            var floatResult = (ChunkedInferenceEmbedding) results.get(0);
            assertThat(floatResult.chunks(), hasSize(1));
-            assertEquals("abc", floatResult.chunks().get(0).matchedText());
+            assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset());
            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
            assertArrayEquals(
                new float[] { 0.123f, -0.123f },

@@ -686,7 +686,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
    }

    private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException {
-        var input = List.of("foo", "bar");
+        var input = List.of("a", "bb");

        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

@@ -733,7 +733,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertTrue(
                    Arrays.equals(

@@ -748,7 +748,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertTrue(
                    Arrays.equals(

@@ -763,7 +763,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {

            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            assertThat(requestMap, aMapWithSize(3));
-            assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("foo", "bar"), "model_id", modelId)));
+            assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("a", "bb"), "model_id", modelId)));
        }
    }

@@ -1819,7 +1819,7 @@ public class JinaAIServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.UNSPECIFIED,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1832,7 +1832,7 @@ public class JinaAIServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.123f, -0.123f },

@@ -1844,7 +1844,7 @@ public class JinaAIServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.223f, -0.223f },

@@ -1864,7 +1864,7 @@ public class JinaAIServiceTests extends ESTestCase {
            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            MatcherAssert.assertThat(
                requestMap,
-                is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2", "embedding_type", "float"))
+                is(Map.of("input", List.of("a", "bb"), "model", "jina-clip-v2", "embedding_type", "float"))
            );
        }
    }

@@ -1857,7 +1857,7 @@ public class OpenAiServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.INGEST,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1870,7 +1870,7 @@ public class OpenAiServiceTests extends ESTestCase {
                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(0);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertTrue(
                    Arrays.equals(

@@ -1883,7 +1883,7 @@ public class OpenAiServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
                assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertTrue(
                    Arrays.equals(

@@ -1901,7 +1901,7 @@ public class OpenAiServiceTests extends ESTestCase {

            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
            assertThat(requestMap.size(), Matchers.is(3));
-            assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
            assertThat(requestMap.get("model"), Matchers.is("model"));
            assertThat(requestMap.get("user"), Matchers.is("user"));
        }

@@ -1826,7 +1826,7 @@ public class VoyageAIServiceTests extends ESTestCase {
            service.chunkedInfer(
                model,
                null,
-                List.of("foo", "bar"),
+                List.of("a", "bb"),
                new HashMap<>(),
                InputType.UNSPECIFIED,
                InferenceAction.Request.DEFAULT_TIMEOUT,

@@ -1839,7 +1839,7 @@ public class VoyageAIServiceTests extends ESTestCase {
                assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.getFirst();
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("foo", floatResult.chunks().getFirst().matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset());
                assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.123f, -0.123f },

@@ -1851,7 +1851,7 @@ public class VoyageAIServiceTests extends ESTestCase {
                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
                var floatResult = (ChunkedInferenceEmbedding) results.get(1);
                assertThat(floatResult.chunks(), hasSize(1));
-                assertEquals("bar", floatResult.chunks().getFirst().matchedText());
+                assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset());
                assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
                assertArrayEquals(
                    new float[] { 0.223f, -0.223f },

@@ -1871,7 +1871,7 @@ public class VoyageAIServiceTests extends ESTestCase {
            var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
            MatcherAssert.assertThat(
                requestMap,
-                is(Map.of("input", List.of("foo", "bar"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024))
+                is(Map.of("input", List.of("a", "bb"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024))
            );
        }
    }
