Adding support for binary embedding type to Cohere service embedding type (#120751)
* Adding support for binary embedding type to Cohere service embedding type * Returning response in separate text_embedding_bits field * Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java Co-authored-by: David Kyle <david.kyle@elastic.co> * Update docs/changelog/120751.yaml * Reverting docs change --------- Co-authored-by: David Kyle <david.kyle@elastic.co>
This commit is contained in:
parent
92d1d31eea
commit
89d71e1f6c
|
@ -0,0 +1,5 @@
|
|||
pr: 120751
|
||||
summary: Adding support for binary embedding type to Cohere service embedding type
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -173,7 +173,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00);
|
||||
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00);
|
||||
public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00);
|
||||
|
||||
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*
|
||||
* this file was contributed to by a generative AI
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.inference.results;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
|
||||
public static final String EMBEDDING = "embedding";
|
||||
|
||||
public InferenceByteEmbedding(StreamInput in) throws IOException {
|
||||
this(in.readByteArray());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeByteArray(values);
|
||||
}
|
||||
|
||||
public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
|
||||
byte[] embeddingValues = new byte[embeddingValuesList.size()];
|
||||
for (int i = 0; i < embeddingValuesList.size(); i++) {
|
||||
embeddingValues[i] = embeddingValuesList.get(i);
|
||||
}
|
||||
return new InferenceByteEmbedding(embeddingValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder.startArray(EMBEDDING);
|
||||
for (byte value : values) {
|
||||
builder.value(value);
|
||||
}
|
||||
builder.endArray();
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
|
||||
float[] toFloatArray() {
|
||||
float[] floatArray = new float[values.length];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
floatArray[i] = ((Byte) values[i]).floatValue();
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
|
||||
double[] toDoubleArray() {
|
||||
double[] doubleArray = new double[values.length];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
doubleArray[i] = ((Byte) values[i]).doubleValue();
|
||||
}
|
||||
return doubleArray;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSize() {
|
||||
return values().length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
|
||||
return Arrays.equals(values, embedding.values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Arrays.hashCode(values);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*
|
||||
* this file was contributed to by a generative AI
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
|
||||
import org.elasticsearch.inference.InferenceResults;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Writes a text embedding result in the following JSON format
|
||||
* {
|
||||
* "text_embedding_bits": [
|
||||
* {
|
||||
* "embedding": [
|
||||
* 23
|
||||
* ]
|
||||
* },
|
||||
* {
|
||||
* "embedding": [
|
||||
* -23
|
||||
* ]
|
||||
* }
|
||||
* ]
|
||||
* }
|
||||
*/
|
||||
public record InferenceTextEmbeddingBitResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
|
||||
public static final String NAME = "text_embedding_service_bit_results";
|
||||
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
|
||||
|
||||
public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {
|
||||
this(in.readCollectionAsList(InferenceByteEmbedding::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getFirstEmbeddingSize() {
|
||||
return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
|
||||
return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddings.iterator());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeCollection(embeddings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<? extends InferenceResults> transformToCoordinationFormat() {
|
||||
return embeddings.stream()
|
||||
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("deprecation")
|
||||
public List<? extends InferenceResults> transformToLegacyFormat() {
|
||||
var legacyEmbedding = new LegacyTextEmbeddingResults(
|
||||
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
|
||||
);
|
||||
|
||||
return List.of(legacyEmbedding);
|
||||
}
|
||||
|
||||
public Map<String, Object> asMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
map.put(TEXT_EMBEDDING_BITS, embeddings);
|
||||
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
InferenceTextEmbeddingBitResults that = (InferenceTextEmbeddingBitResults) o;
|
||||
return Objects.equals(embeddings, that.embeddings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(embeddings);
|
||||
}
|
||||
}
|
|
@ -9,21 +9,16 @@
|
|||
|
||||
package org.elasticsearch.xpack.core.inference.results;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
|
||||
import org.elasticsearch.inference.InferenceResults;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
|
@ -33,7 +28,7 @@ import java.util.Objects;
|
|||
/**
|
||||
* Writes a text embedding result in the following JSON format
|
||||
* {
|
||||
* "text_embedding": [
|
||||
* "text_embedding_bytes": [
|
||||
* {
|
||||
* "embedding": [
|
||||
* 23
|
||||
|
@ -111,78 +106,4 @@ public record InferenceTextEmbeddingByteResults(List<InferenceByteEmbedding> emb
|
|||
public int hashCode() {
|
||||
return Objects.hash(embeddings);
|
||||
}
|
||||
|
||||
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
|
||||
public static final String EMBEDDING = "embedding";
|
||||
|
||||
public InferenceByteEmbedding(StreamInput in) throws IOException {
|
||||
this(in.readByteArray());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeByteArray(values);
|
||||
}
|
||||
|
||||
public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
|
||||
byte[] embeddingValues = new byte[embeddingValuesList.size()];
|
||||
for (int i = 0; i < embeddingValuesList.size(); i++) {
|
||||
embeddingValues[i] = embeddingValuesList.get(i);
|
||||
}
|
||||
return new InferenceByteEmbedding(embeddingValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder.startArray(EMBEDDING);
|
||||
for (byte value : values) {
|
||||
builder.value(value);
|
||||
}
|
||||
builder.endArray();
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
|
||||
private float[] toFloatArray() {
|
||||
float[] floatArray = new float[values.length];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
floatArray[i] = ((Byte) values[i]).floatValue();
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
|
||||
private double[] toDoubleArray() {
|
||||
double[] doubleArray = new double[values.length];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
doubleArray[i] = ((Byte) values[i]).floatValue();
|
||||
}
|
||||
return doubleArray;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSize() {
|
||||
return values().length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
|
||||
return Arrays.equals(values, embedding.values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Arrays.hashCode(values);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingB
|
|||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
||||
|
@ -69,7 +70,7 @@ public class EmbeddingRequestChunker {
|
|||
|
||||
private List<ChunkOffsetsAndInput> chunkedOffsets;
|
||||
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
|
||||
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
|
||||
private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
|
||||
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
|
||||
private AtomicArray<Exception> errors;
|
||||
private ActionListener<List<ChunkedInference>> finalListener;
|
||||
|
@ -389,9 +390,9 @@ public class EmbeddingRequestChunker {
|
|||
|
||||
private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
|
||||
ChunkOffsetsAndInput chunks,
|
||||
AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults
|
||||
AtomicArray<List<InferenceByteEmbedding>> debatchedResults
|
||||
) {
|
||||
var all = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
|
||||
var all = new ArrayList<InferenceByteEmbedding>();
|
||||
for (int i = 0; i < debatchedResults.length(); i++) {
|
||||
var subBatch = debatchedResults.get(i);
|
||||
all.addAll(subBatch);
|
||||
|
|
|
@ -17,6 +17,8 @@ import org.elasticsearch.xcontent.XContentFactory;
|
|||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
|
@ -43,7 +45,9 @@ public class CohereEmbeddingsResponseEntity {
|
|||
toLowerCase(CohereEmbeddingType.FLOAT),
|
||||
CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray,
|
||||
toLowerCase(CohereEmbeddingType.INT8),
|
||||
CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray
|
||||
CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray,
|
||||
toLowerCase(CohereEmbeddingType.BINARY),
|
||||
CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray
|
||||
);
|
||||
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
|
||||
|
||||
|
@ -184,17 +188,24 @@ public class CohereEmbeddingsResponseEntity {
|
|||
);
|
||||
}
|
||||
|
||||
private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException {
|
||||
// Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser
|
||||
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
|
||||
|
||||
return new InferenceTextEmbeddingBitResults(embeddingList);
|
||||
}
|
||||
|
||||
private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
|
||||
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
|
||||
|
||||
return new InferenceTextEmbeddingByteResults(embeddingList);
|
||||
}
|
||||
|
||||
private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
|
||||
private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
|
||||
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
|
||||
List<Byte> embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
|
||||
|
||||
return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList);
|
||||
return InferenceByteEmbedding.of(embeddingValuesList);
|
||||
}
|
||||
|
||||
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
|
||||
|
|
|
@ -36,18 +36,29 @@ public enum CohereEmbeddingType {
|
|||
/**
|
||||
* This is a synonym for INT8
|
||||
*/
|
||||
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8);
|
||||
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
|
||||
/**
|
||||
* Use this when you want to get back binary embeddings. Valid only for v3 models.
|
||||
*/
|
||||
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
|
||||
/**
|
||||
* This is a synonym for BIT
|
||||
*/
|
||||
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
|
||||
|
||||
private static final class RequestConstants {
|
||||
private static final String FLOAT = "float";
|
||||
private static final String INT8 = "int8";
|
||||
private static final String BIT = "binary";
|
||||
}
|
||||
|
||||
private static final Map<DenseVectorFieldMapper.ElementType, CohereEmbeddingType> ELEMENT_TYPE_TO_COHERE_EMBEDDING = Map.of(
|
||||
DenseVectorFieldMapper.ElementType.FLOAT,
|
||||
FLOAT,
|
||||
DenseVectorFieldMapper.ElementType.BYTE,
|
||||
BYTE
|
||||
BYTE,
|
||||
DenseVectorFieldMapper.ElementType.BIT,
|
||||
BIT
|
||||
);
|
||||
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
|
||||
ELEMENT_TYPE_TO_COHERE_EMBEDDING.keySet()
|
||||
|
@ -116,6 +127,10 @@ public enum CohereEmbeddingType {
|
|||
return INT8;
|
||||
}
|
||||
|
||||
if (version.before(TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED) && embeddingType == BIT) {
|
||||
return INT8;
|
||||
}
|
||||
|
||||
return embeddingType;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingB
|
|||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
||||
|
@ -368,16 +369,16 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
|
|||
|
||||
// 4 inputs in 2 batches
|
||||
{
|
||||
var embeddings = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
|
||||
var embeddings = new ArrayList<InferenceByteEmbedding>();
|
||||
for (int i = 0; i < batchSize; i++) {
|
||||
embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() }));
|
||||
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
|
||||
}
|
||||
batches.get(0).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings));
|
||||
}
|
||||
{
|
||||
var embeddings = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
|
||||
var embeddings = new ArrayList<InferenceByteEmbedding>();
|
||||
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
|
||||
embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() }));
|
||||
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
|
||||
}
|
||||
batches.get(1).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings));
|
||||
}
|
||||
|
|
|
@ -72,6 +72,38 @@ public class CohereEmbeddingsRequestEntityTests extends ESTestCase {
|
|||
{"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException {
|
||||
var entity = new CohereEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
|
||||
"model",
|
||||
CohereEmbeddingType.BINARY
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException {
|
||||
var entity = new CohereEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
|
||||
"model",
|
||||
CohereEmbeddingType.BIT
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
|
||||
var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null);
|
||||
|
||||
|
|
|
@ -145,6 +145,53 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
CohereEmbeddingsModelTests.createModel(
|
||||
"url",
|
||||
"secret",
|
||||
new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END),
|
||||
null,
|
||||
null,
|
||||
"model",
|
||||
CohereEmbeddingType.BIT
|
||||
)
|
||||
);
|
||||
|
||||
var httpRequest = request.createHttpRequest();
|
||||
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
|
||||
is(CohereUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(
|
||||
Map.of(
|
||||
"texts",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model",
|
||||
"input_type",
|
||||
"search_query",
|
||||
"embedding_types",
|
||||
List.of("binary"),
|
||||
"truncate",
|
||||
"end"
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testCreateRequest_TruncateNone() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
|
|
|
@ -10,6 +10,8 @@ package org.elasticsearch.xpack.inference.external.response.cohere;
|
|||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
|
@ -182,10 +184,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
parsedResults.embeddings(),
|
||||
is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
|
||||
);
|
||||
MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
|
||||
}
|
||||
|
||||
public void testFromResponse_ParsesBytes() throws IOException {
|
||||
|
@ -220,9 +219,47 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
|
||||
}
|
||||
|
||||
public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOException {
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
||||
"texts": [
|
||||
"hello"
|
||||
],
|
||||
"embeddings": {
|
||||
"binary": [
|
||||
[
|
||||
-55,
|
||||
74,
|
||||
101,
|
||||
67,
|
||||
83
|
||||
]
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "2"
|
||||
},
|
||||
"billed_units": {
|
||||
"input_tokens": 1
|
||||
}
|
||||
},
|
||||
"response_type": "embeddings_by_type"
|
||||
}
|
||||
""";
|
||||
|
||||
InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
|
||||
mock(Request.class),
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
parsedResults.embeddings(),
|
||||
is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
|
||||
is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -318,6 +355,59 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary() throws IOException {
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
||||
"texts": [
|
||||
"hello",
|
||||
"goodbye"
|
||||
],
|
||||
"embeddings": {
|
||||
"binary": [
|
||||
[
|
||||
-55,
|
||||
74,
|
||||
101,
|
||||
67
|
||||
],
|
||||
[
|
||||
34,
|
||||
-64,
|
||||
97,
|
||||
65,
|
||||
-42
|
||||
]
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "2"
|
||||
},
|
||||
"billed_units": {
|
||||
"input_tokens": 1
|
||||
}
|
||||
},
|
||||
"response_type": "embeddings_by_type"
|
||||
}
|
||||
""";
|
||||
|
||||
InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
|
||||
mock(Request.class),
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
parsedResults.embeddings(),
|
||||
is(
|
||||
List.of(
|
||||
new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
|
||||
new InferenceByteEmbedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() {
|
||||
String responseJson = """
|
||||
{
|
||||
|
@ -433,6 +523,82 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|||
MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Negative() {
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
||||
"texts": [
|
||||
"hello"
|
||||
],
|
||||
"embeddings": {
|
||||
"binary": [
|
||||
[
|
||||
-129,
|
||||
127
|
||||
]
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "2"
|
||||
},
|
||||
"billed_units": {
|
||||
"input_tokens": 1
|
||||
}
|
||||
},
|
||||
"response_type": "embeddings_by_type"
|
||||
}
|
||||
""";
|
||||
|
||||
var thrownException = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> CohereEmbeddingsResponseEntity.fromResponse(
|
||||
mock(Request.class),
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("Value [-129] is out of range for a byte"));
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Positive() {
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
||||
"texts": [
|
||||
"hello"
|
||||
],
|
||||
"embeddings": {
|
||||
"binary": [
|
||||
[
|
||||
-128,
|
||||
128
|
||||
]
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "2"
|
||||
},
|
||||
"billed_units": {
|
||||
"input_tokens": 1
|
||||
}
|
||||
},
|
||||
"response_type": "embeddings_by_type"
|
||||
}
|
||||
""";
|
||||
|
||||
var thrownException = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> CohereEmbeddingsResponseEntity.fromResponse(
|
||||
mock(Request.class),
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsToFindAValidEmbeddingType() {
|
||||
String responseJson = """
|
||||
{
|
||||
|
@ -470,7 +636,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|||
|
||||
MatcherAssert.assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]")
|
||||
is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [binary, float, int8]")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.elasticsearch.test.rest.FakeRestRequest;
|
|||
import org.elasticsearch.test.rest.RestActionTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -142,9 +143,7 @@ public class BaseInferenceActionTests extends RestActionTestCase {
|
|||
|
||||
static InferenceAction.Response createResponse() {
|
||||
return new InferenceAction.Response(
|
||||
new InferenceTextEmbeddingByteResults(
|
||||
List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 }))
|
||||
)
|
||||
new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1 })))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.results;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class InferenceTextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase<InferenceTextEmbeddingBitResults> {
|
||||
public static InferenceTextEmbeddingBitResults createRandomResults() {
|
||||
int embeddings = randomIntBetween(1, 10);
|
||||
List<InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
|
||||
|
||||
for (int i = 0; i < embeddings; i++) {
|
||||
embeddingResults.add(createRandomEmbedding());
|
||||
}
|
||||
|
||||
return new InferenceTextEmbeddingBitResults(embeddingResults);
|
||||
}
|
||||
|
||||
private static InferenceByteEmbedding createRandomEmbedding() {
|
||||
int columns = randomIntBetween(1, 10);
|
||||
byte[] bytes = new byte[columns];
|
||||
|
||||
for (int i = 0; i < columns; i++) {
|
||||
bytes[i] = randomByte();
|
||||
}
|
||||
|
||||
return new InferenceByteEmbedding(bytes);
|
||||
}
|
||||
|
||||
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
|
||||
var entity = new InferenceTextEmbeddingBitResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 })));
|
||||
|
||||
String xContentResult = Strings.toString(entity, true, true);
|
||||
assertThat(xContentResult, is("""
|
||||
{
|
||||
"text_embedding_bits" : [
|
||||
{
|
||||
"embedding" : [
|
||||
23
|
||||
]
|
||||
}
|
||||
]
|
||||
}"""));
|
||||
}
|
||||
|
||||
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
|
||||
var entity = new InferenceTextEmbeddingBitResults(
|
||||
List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 }))
|
||||
);
|
||||
|
||||
String xContentResult = Strings.toString(entity, true, true);
|
||||
assertThat(xContentResult, is("""
|
||||
{
|
||||
"text_embedding_bits" : [
|
||||
{
|
||||
"embedding" : [
|
||||
23
|
||||
]
|
||||
},
|
||||
{
|
||||
"embedding" : [
|
||||
24
|
||||
]
|
||||
}
|
||||
]
|
||||
}"""));
|
||||
}
|
||||
|
||||
public void testTransformToCoordinationFormat() {
|
||||
var results = new InferenceTextEmbeddingBitResults(
|
||||
List.of(
|
||||
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
|
||||
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
|
||||
)
|
||||
).transformToCoordinationFormat();
|
||||
|
||||
assertThat(
|
||||
results,
|
||||
is(
|
||||
List.of(
|
||||
new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
|
||||
new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<InferenceTextEmbeddingBitResults> instanceReader() {
|
||||
return InferenceTextEmbeddingBitResults::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected InferenceTextEmbeddingBitResults createTestInstance() {
|
||||
return createRandomResults();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected InferenceTextEmbeddingBitResults mutateInstance(InferenceTextEmbeddingBitResults instance) throws IOException {
|
||||
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
|
||||
if (randomBoolean()) {
|
||||
// -1 to remove at least one item from the list
|
||||
int end = randomInt(instance.embeddings().size() - 1);
|
||||
return new InferenceTextEmbeddingBitResults(instance.embeddings().subList(0, end));
|
||||
} else {
|
||||
List<InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
|
||||
embeddings.add(createRandomEmbedding());
|
||||
return new InferenceTextEmbeddingBitResults(embeddings);
|
||||
}
|
||||
}
|
||||
|
||||
public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) {
|
||||
return Map.of(
|
||||
InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS,
|
||||
embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList()
|
||||
);
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.results;
|
|||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
||||
|
||||
|
@ -23,7 +24,7 @@ import static org.hamcrest.Matchers.is;
|
|||
public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<InferenceTextEmbeddingByteResults> {
|
||||
public static InferenceTextEmbeddingByteResults createRandomResults() {
|
||||
int embeddings = randomIntBetween(1, 10);
|
||||
List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
|
||||
List<InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
|
||||
|
||||
for (int i = 0; i < embeddings; i++) {
|
||||
embeddingResults.add(createRandomEmbedding());
|
||||
|
@ -32,7 +33,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
|
|||
return new InferenceTextEmbeddingByteResults(embeddingResults);
|
||||
}
|
||||
|
||||
private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding createRandomEmbedding() {
|
||||
private static InferenceByteEmbedding createRandomEmbedding() {
|
||||
int columns = randomIntBetween(1, 10);
|
||||
byte[] bytes = new byte[columns];
|
||||
|
||||
|
@ -40,13 +41,11 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
|
|||
bytes[i] = randomByte();
|
||||
}
|
||||
|
||||
return new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(bytes);
|
||||
return new InferenceByteEmbedding(bytes);
|
||||
}
|
||||
|
||||
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
|
||||
var entity = new InferenceTextEmbeddingByteResults(
|
||||
List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }))
|
||||
);
|
||||
var entity = new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 })));
|
||||
|
||||
String xContentResult = Strings.toString(entity, true, true);
|
||||
assertThat(xContentResult, is("""
|
||||
|
@ -63,10 +62,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
|
|||
|
||||
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
|
||||
var entity = new InferenceTextEmbeddingByteResults(
|
||||
List.of(
|
||||
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }),
|
||||
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 24 })
|
||||
)
|
||||
List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 }))
|
||||
);
|
||||
|
||||
String xContentResult = Strings.toString(entity, true, true);
|
||||
|
@ -90,8 +86,8 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
|
|||
public void testTransformToCoordinationFormat() {
|
||||
var results = new InferenceTextEmbeddingByteResults(
|
||||
List.of(
|
||||
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
|
||||
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
|
||||
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
|
||||
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
|
||||
)
|
||||
).transformToCoordinationFormat();
|
||||
|
||||
|
@ -124,7 +120,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
|
|||
int end = randomInt(instance.embeddings().size() - 1);
|
||||
return new InferenceTextEmbeddingByteResults(instance.embeddings().subList(0, end));
|
||||
} else {
|
||||
List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
|
||||
List<InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
|
||||
embeddings.add(createRandomEmbedding());
|
||||
return new InferenceTextEmbeddingByteResults(embeddings);
|
||||
}
|
||||
|
@ -133,9 +129,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
|
|||
public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) {
|
||||
return Map.of(
|
||||
InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
|
||||
embeddings.stream()
|
||||
.map(embedding -> Map.of(InferenceTextEmbeddingByteResults.InferenceByteEmbedding.EMBEDDING, embedding))
|
||||
.toList()
|
||||
embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.results;
|
|||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
||||
|
@ -141,7 +142,7 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<I
|
|||
public static Map<String, Object> buildExpectationByte(List<byte[]> embeddings) {
|
||||
return Map.of(
|
||||
InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
|
||||
embeddings.stream().map(InferenceTextEmbeddingByteResults.InferenceByteEmbedding::new).toList()
|
||||
embeddings.stream().map(InferenceByteEmbedding::new).toList()
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -50,6 +50,27 @@ public class CohereEmbeddingTypeTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
public void testTranslateToVersion_ReturnsInt8_WhenVersionIsBeforeBitEnumAddition_WhenSpecifyingBit() {
|
||||
assertThat(
|
||||
CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, new TransportVersion(8_840_0_00)),
|
||||
is(CohereEmbeddingType.INT8)
|
||||
);
|
||||
}
|
||||
|
||||
public void testTranslateToVersion_ReturnsBit_WhenVersionOnBitEnumAddition_WhenSpecifyingBit() {
|
||||
assertThat(
|
||||
CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED),
|
||||
is(CohereEmbeddingType.BIT)
|
||||
);
|
||||
}
|
||||
|
||||
public void testTranslateToVersion_ReturnsFloat_WhenVersionOnBitEnumAddition_WhenSpecifyingFloat() {
|
||||
assertThat(
|
||||
CohereEmbeddingType.translateToVersion(CohereEmbeddingType.FLOAT, TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED),
|
||||
is(CohereEmbeddingType.FLOAT)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromElementType_CovertsFloatToCohereEmbeddingTypeFloat() {
|
||||
assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.FLOAT), is(CohereEmbeddingType.FLOAT));
|
||||
}
|
||||
|
@ -57,4 +78,8 @@ public class CohereEmbeddingTypeTests extends ESTestCase {
|
|||
public void testFromElementType_CovertsByteToCohereEmbeddingTypeByte() {
|
||||
assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BYTE), is(CohereEmbeddingType.BYTE));
|
||||
}
|
||||
|
||||
public void testFromElementType_ConvertsBitToCohereEmbeddingTypeBinary() {
|
||||
assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BIT), is(CohereEmbeddingType.BIT));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -218,7 +218,7 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
|
|||
is(
|
||||
Strings.format(
|
||||
"Validation Failed: 1: [service_settings] Invalid value [abc] received. "
|
||||
+ "[embedding_type] must be one of [byte, float, int8];"
|
||||
+ "[embedding_type] must be one of [binary, bit, byte, float, int8];"
|
||||
)
|
||||
)
|
||||
);
|
||||
|
@ -238,7 +238,7 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
|
|||
is(
|
||||
Strings.format(
|
||||
"Validation Failed: 1: [service_settings] Invalid value [abc] received. "
|
||||
+ "[embedding_type] must be one of [byte, float];"
|
||||
+ "[embedding_type] must be one of [bit, byte, float];"
|
||||
)
|
||||
)
|
||||
);
|
||||
|
@ -289,6 +289,16 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
|
|||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ConvertsBit_ToCohereEmbeddingTypeBit() {
|
||||
assertThat(
|
||||
CohereEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.BIT.toString())),
|
||||
ConfigurationParseContext.REQUEST
|
||||
),
|
||||
is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BIT))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_PreservesEmbeddingTypeFloat() {
|
||||
assertThat(
|
||||
CohereEmbeddingsServiceSettings.fromMap(
|
||||
|
@ -314,6 +324,8 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
|
|||
assertEquals(CohereEmbeddingType.BYTE, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("byte", validation));
|
||||
assertEquals(CohereEmbeddingType.INT8, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("int8", validation));
|
||||
assertEquals(CohereEmbeddingType.FLOAT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("float", validation));
|
||||
assertEquals(CohereEmbeddingType.BINARY, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("binary", validation));
|
||||
assertEquals(CohereEmbeddingType.BIT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("bit", validation));
|
||||
assertTrue(validation.validationErrors().isEmpty());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue