Adding support for binary embedding type to Cohere service embedding type (#120751)

* Adding support for binary embedding type to Cohere service embedding type

* Returning response in separate text_embedding_bits field

* Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java

Co-authored-by: David Kyle <david.kyle@elastic.co>

* Update docs/changelog/120751.yaml

* Reverting docs change

---------

Co-authored-by: David Kyle <david.kyle@elastic.co>
This commit is contained in:
Ying Mao 2025-02-03 13:55:31 -05:00 committed by GitHub
parent 92d1d31eea
commit 89d71e1f6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 690 additions and 121 deletions

View File

@ -0,0 +1,5 @@
pr: 120751
summary: Adding support for binary embedding type to Cohere service embedding type
area: Machine Learning
type: enhancement
issues: []

View File

@ -173,7 +173,7 @@ public class TransportVersions {
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00);
public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View File

@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
/**
 * A single embedding whose components are held in a {@code byte[]}.
 * <p>
 * The embedding can be written to and read from a stream ({@link Writeable}) and is rendered as an
 * object containing one {@value #EMBEDDING} array of numbers. {@code equals} and {@code hashCode}
 * are overridden to compare the array contents, because the implementations a record generates
 * implicitly would compare the array reference only.
 *
 * @param values the embedding components
 */
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
    public static final String EMBEDDING = "embedding";

    /**
     * Reads an embedding from the stream as a single byte array.
     *
     * @param in the stream to read from
     * @throws IOException if reading the byte array fails
     */
    public InferenceByteEmbedding(StreamInput in) throws IOException {
        this(in.readByteArray());
    }

    /**
     * Writes the components to the stream as a single byte array.
     *
     * @param out the stream to write to
     * @throws IOException if writing the byte array fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeByteArray(values);
    }

    /**
     * Creates an embedding from a list of boxed bytes, preserving the list order.
     *
     * @param embeddingValuesList the components to copy into a new array
     * @return an embedding backed by a newly allocated array
     */
    public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
        byte[] embeddingValues = new byte[embeddingValuesList.size()];
        for (int i = 0; i < embeddingValuesList.size(); i++) {
            embeddingValues[i] = embeddingValuesList.get(i);
        }
        return new InferenceByteEmbedding(embeddingValues);
    }

    /**
     * Renders the embedding as an object with a single {@value #EMBEDDING} array holding every component.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.startArray(EMBEDDING);
        for (byte value : values) {
            builder.value(value);
        }
        builder.endArray();
        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    /**
     * Returns a new {@code float[]} holding the components.
     */
    float[] toFloatArray() {
        float[] floatArray = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            // byte -> float is a widening primitive conversion; no boxing is needed
            floatArray[i] = values[i];
        }
        return floatArray;
    }

    /**
     * Returns a new {@code double[]} holding the components.
     */
    double[] toDoubleArray() {
        double[] doubleArray = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            // byte -> double is a widening primitive conversion; no boxing is needed
            doubleArray[i] = values[i];
        }
        return doubleArray;
    }

    /**
     * Returns the number of components in the embedding.
     */
    @Override
    public int getSize() {
        return values().length;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
        return Arrays.equals(values, embedding.values);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(values);
    }
}

View File

@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
 * Text embedding results whose embeddings are {@link InferenceByteEmbedding} instances and which are
 * written under the {@code text_embedding_bits} field, in the following JSON format
 * {
 *     "text_embedding_bits": [
 *         {
 *             "embedding": [
 *                 23
 *             ]
 *         },
 *         {
 *             "embedding": [
 *                 -23
 *             ]
 *         }
 *     ]
 * }
 *
 * @param embeddings the embeddings, in the order they are written
 */
public record InferenceTextEmbeddingBitResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
    public static final String NAME = "text_embedding_service_bit_results";
    public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";

    /**
     * Reads the results from a stream as a list of {@link InferenceByteEmbedding}.
     */
    public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {
        this(in.readCollectionAsList(InferenceByteEmbedding::new));
    }

    /**
     * Delegates to {@link TextEmbeddingUtils#getFirstEmbeddingSize} using a copy of the embeddings list.
     */
    @Override
    public int getFirstEmbeddingSize() {
        return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
    }

    /**
     * Emits the embeddings as an array named {@value #TEXT_EMBEDDING_BITS}.
     */
    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        Iterator<InferenceByteEmbedding> embeddingIterator = embeddings.iterator();
        return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddingIterator);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(embeddings);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Converts every embedding into an {@link MlTextEmbeddingResults} holding the values as doubles.
     */
    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        return embeddings.stream().map(InferenceTextEmbeddingBitResults::toCoordinationResult).toList();
    }

    // Wraps one embedding for the coordination format; the results field is always TEXT_EMBEDDING_BITS.
    private static MlTextEmbeddingResults toCoordinationResult(InferenceByteEmbedding byteEmbedding) {
        return new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, byteEmbedding.toDoubleArray(), false);
    }

    /**
     * Converts all embeddings into a single {@link LegacyTextEmbeddingResults} holding the values as floats.
     */
    @Override
    @SuppressWarnings("deprecation")
    public List<? extends InferenceResults> transformToLegacyFormat() {
        var floatEmbeddings = embeddings.stream()
            .map(byteEmbedding -> new LegacyTextEmbeddingResults.Embedding(byteEmbedding.toFloatArray()))
            .toList();
        return List.of(new LegacyTextEmbeddingResults(floatEmbeddings));
    }

    /**
     * Returns a map with one entry: {@value #TEXT_EMBEDDING_BITS} mapped to the embeddings list.
     */
    public Map<String, Object> asMap() {
        Map<String, Object> result = new LinkedHashMap<>();
        result.put(TEXT_EMBEDDING_BITS, embeddings);
        return result;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        // records are implicitly final, so an instanceof test is equivalent to comparing classes
        if (o instanceof InferenceTextEmbeddingBitResults other) {
            return Objects.equals(embeddings, other.embeddings);
        }
        return false;
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddings);
    }
}

View File

@ -9,21 +9,16 @@
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
@ -33,7 +28,7 @@ import java.util.Objects;
/**
* Writes a text embedding result in the following JSON format
* {
* "text_embedding": [
* "text_embedding_bytes": [
* {
* "embedding": [
* 23
@ -111,78 +106,4 @@ public record InferenceTextEmbeddingByteResults(List<InferenceByteEmbedding> emb
public int hashCode() {
return Objects.hash(embeddings);
}
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
public static final String EMBEDDING = "embedding";
public InferenceByteEmbedding(StreamInput in) throws IOException {
this(in.readByteArray());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeByteArray(values);
}
public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
byte[] embeddingValues = new byte[embeddingValuesList.size()];
for (int i = 0; i < embeddingValuesList.size(); i++) {
embeddingValues[i] = embeddingValuesList.get(i);
}
return new InferenceByteEmbedding(embeddingValues);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startArray(EMBEDDING);
for (byte value : values) {
builder.value(value);
}
builder.endArray();
builder.endObject();
return builder;
}
@Override
public String toString() {
return Strings.toString(this);
}
private float[] toFloatArray() {
float[] floatArray = new float[values.length];
for (int i = 0; i < values.length; i++) {
floatArray[i] = ((Byte) values[i]).floatValue();
}
return floatArray;
}
private double[] toDoubleArray() {
double[] doubleArray = new double[values.length];
for (int i = 0; i < values.length; i++) {
doubleArray[i] = ((Byte) values[i]).floatValue();
}
return doubleArray;
}
@Override
public int getSize() {
return values().length;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
return Arrays.equals(values, embedding.values);
}
@Override
public int hashCode() {
return Arrays.hashCode(values);
}
}
}

View File

@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingB
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@ -69,7 +70,7 @@ public class EmbeddingRequestChunker {
private List<ChunkOffsetsAndInput> chunkedOffsets;
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
private AtomicArray<Exception> errors;
private ActionListener<List<ChunkedInference>> finalListener;
@ -389,9 +390,9 @@ public class EmbeddingRequestChunker {
private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
ChunkOffsetsAndInput chunks,
AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults
AtomicArray<List<InferenceByteEmbedding>> debatchedResults
) {
var all = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
var all = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < debatchedResults.length(); i++) {
var subBatch = debatchedResults.get(i);
all.addAll(subBatch);

View File

@ -17,6 +17,8 @@ import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
@ -43,7 +45,9 @@ public class CohereEmbeddingsResponseEntity {
toLowerCase(CohereEmbeddingType.FLOAT),
CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray,
toLowerCase(CohereEmbeddingType.INT8),
CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray
CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray,
toLowerCase(CohereEmbeddingType.BINARY),
CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray
);
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
@ -184,17 +188,24 @@ public class CohereEmbeddingsResponseEntity {
);
}
/**
 * Parses the embeddings array of a binary response into {@link InferenceTextEmbeddingBitResults}.
 */
private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException {
    // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser
    return new InferenceTextEmbeddingBitResults(parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry));
}
/**
 * Parses the embeddings array of an int8 response into {@link InferenceTextEmbeddingByteResults}.
 */
private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
    return new InferenceTextEmbeddingByteResults(parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry));
}
private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List<Byte> embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList);
return InferenceByteEmbedding.of(embeddingValuesList);
}
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {

View File

@ -36,18 +36,29 @@ public enum CohereEmbeddingType {
/**
* This is a synonym for INT8
*/
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8);
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
/**
* Use this when you want to get back binary embeddings. Valid only for v3 models.
*/
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
/**
* This is a synonym for BIT
*/
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
private static final class RequestConstants {
private static final String FLOAT = "float";
private static final String INT8 = "int8";
private static final String BIT = "binary";
}
private static final Map<DenseVectorFieldMapper.ElementType, CohereEmbeddingType> ELEMENT_TYPE_TO_COHERE_EMBEDDING = Map.of(
DenseVectorFieldMapper.ElementType.FLOAT,
FLOAT,
DenseVectorFieldMapper.ElementType.BYTE,
BYTE
BYTE,
DenseVectorFieldMapper.ElementType.BIT,
BIT
);
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
ELEMENT_TYPE_TO_COHERE_EMBEDDING.keySet()
@ -116,6 +127,10 @@ public enum CohereEmbeddingType {
return INT8;
}
if (version.before(TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED) && embeddingType == BIT) {
return INT8;
}
return embeddingType;
}
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingB
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@ -368,16 +369,16 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
// 4 inputs in 2 batches
{
var embeddings = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
var embeddings = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < batchSize; i++) {
embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() }));
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
}
batches.get(0).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings));
}
{
var embeddings = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
var embeddings = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() }));
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
}
batches.get(1).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings));
}

View File

@ -72,6 +72,38 @@ public class CohereEmbeddingsRequestEntityTests extends ESTestCase {
{"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}"""));
}
/**
 * A search request built with the {@code BINARY} embedding type must serialize
 * {@code "embedding_types":["binary"]} together with {@code "truncate":"none"}.
 */
public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException {
    var taskSettings = new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE);
    var requestEntity = new CohereEmbeddingsRequestEntity(List.of("abc"), taskSettings, "model", CohereEmbeddingType.BINARY);

    XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    MatcherAssert.assertThat(Strings.toString(jsonBuilder), is("""
        {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
}
/**
 * A search request built with the {@code BIT} embedding type must serialize
 * {@code "embedding_types":["binary"]} together with {@code "truncate":"none"}.
 */
public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException {
    var taskSettings = new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE);
    var requestEntity = new CohereEmbeddingsRequestEntity(List.of("abc"), taskSettings, "model", CohereEmbeddingType.BIT);

    XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    MatcherAssert.assertThat(Strings.toString(jsonBuilder), is("""
        {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
}
public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null);

View File

@ -145,6 +145,53 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
);
}
/**
 * Builds an HTTP request for a model configured with the {@code BIT} embedding type and checks the
 * target URL, the headers, and that the body asks for {@code "binary"} embeddings with {@code "end"} truncation.
 */
public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException {
    var model = CohereEmbeddingsModelTests.createModel(
        "url",
        "secret",
        new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END),
        null,
        null,
        "model",
        CohereEmbeddingType.BIT
    );
    var request = createRequest(List.of("abc"), model);

    var httpRequest = request.createHttpRequest();
    MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));

    var post = (HttpPost) httpRequest.httpRequestBase();

    // target and headers
    MatcherAssert.assertThat(post.getURI().toString(), is("url"));
    MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
    MatcherAssert.assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
    MatcherAssert.assertThat(post.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));

    // body
    var body = entityAsMap(post.getEntity().getContent());
    MatcherAssert.assertThat(
        body,
        is(
            Map.of(
                "texts",
                List.of("abc"),
                "model",
                "model",
                "input_type",
                "search_query",
                "embedding_types",
                List.of("binary"),
                "truncate",
                "end"
            )
        )
    );
}
public void testCreateRequest_TruncateNone() throws IOException {
var request = createRequest(
List.of("abc"),

View File

@ -10,6 +10,8 @@ package org.elasticsearch.xpack.inference.external.response.cohere;
import org.apache.http.HttpResponse;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
@ -182,10 +184,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
);
MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
}
public void testFromResponse_ParsesBytes() throws IOException {
@ -220,9 +219,47 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
}
public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOException {
String responseJson = """
{
"id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
"texts": [
"hello"
],
"embeddings": {
"binary": [
[
-55,
74,
101,
67,
83
]
]
},
"meta": {
"api_version": {
"version": "2"
},
"billed_units": {
"input_tokens": 1
}
},
"response_type": "embeddings_by_type"
}
""";
InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@ -318,6 +355,59 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
);
}
/**
 * A response carrying two entries under {@code embeddings.binary} must be parsed into
 * {@link InferenceTextEmbeddingBitResults} with one embedding per entry, in response order.
 */
public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary() throws IOException {
    String responseJson = """
        {
        "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
        "texts": [
        "hello",
        "goodbye"
        ],
        "embeddings": {
        "binary": [
        [
        -55,
        74,
        101,
        67
        ],
        [
        34,
        -64,
        97,
        65,
        -42
        ]
        ]
        },
        "meta": {
        "api_version": {
        "version": "2"
        },
        "billed_units": {
        "input_tokens": 1
        }
        },
        "response_type": "embeddings_by_type"
        }
        """;
    var httpResult = new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8));
    var bitResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), httpResult);

    var firstEmbedding = new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 });
    var secondEmbedding = new InferenceByteEmbedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 });
    MatcherAssert.assertThat(bitResults.embeddings(), is(List.of(firstEmbedding, secondEmbedding)));
}
public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() {
String responseJson = """
{
@ -433,6 +523,82 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
}
/**
 * A binary embedding value of -129 must make parsing fail with an
 * {@link IllegalArgumentException} naming the offending value.
 */
public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Negative() {
    String responseJson = """
        {
        "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
        "texts": [
        "hello"
        ],
        "embeddings": {
        "binary": [
        [
        -129,
        127
        ]
        ]
        },
        "meta": {
        "api_version": {
        "version": "2"
        },
        "billed_units": {
        "input_tokens": 1
        }
        },
        "response_type": "embeddings_by_type"
        }
        """;
    byte[] responseBytes = responseJson.getBytes(StandardCharsets.UTF_8);

    var exception = expectThrows(
        IllegalArgumentException.class,
        () -> CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), new HttpResult(mock(HttpResponse.class), responseBytes))
    );

    MatcherAssert.assertThat(exception.getMessage(), is("Value [-129] is out of range for a byte"));
}
/**
 * A binary embedding value of 128 must make parsing fail with an
 * {@link IllegalArgumentException} naming the offending value.
 */
public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Positive() {
    String responseJson = """
        {
        "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
        "texts": [
        "hello"
        ],
        "embeddings": {
        "binary": [
        [
        -128,
        128
        ]
        ]
        },
        "meta": {
        "api_version": {
        "version": "2"
        },
        "billed_units": {
        "input_tokens": 1
        }
        },
        "response_type": "embeddings_by_type"
        }
        """;
    byte[] responseBytes = responseJson.getBytes(StandardCharsets.UTF_8);

    var exception = expectThrows(
        IllegalArgumentException.class,
        () -> CohereEmbeddingsResponseEntity.fromResponse(mock(Request.class), new HttpResult(mock(HttpResponse.class), responseBytes))
    );

    MatcherAssert.assertThat(exception.getMessage(), is("Value [128] is out of range for a byte"));
}
public void testFromResponse_FailsToFindAValidEmbeddingType() {
String responseJson = """
{
@ -470,7 +636,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
MatcherAssert.assertThat(
thrownException.getMessage(),
is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]")
is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [binary, float, int8]")
);
}
}

View File

@ -21,6 +21,7 @@ import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.junit.Before;
@ -142,9 +143,7 @@ public class BaseInferenceActionTests extends RestActionTestCase {
static InferenceAction.Response createResponse() {
return new InferenceAction.Response(
new InferenceTextEmbeddingByteResults(
List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 }))
)
new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1 })))
);
}
}

View File

@ -0,0 +1,135 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.is;
/**
 * Wire-serialization and rendering tests for {@link InferenceTextEmbeddingBitResults}.
 */
public class InferenceTextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase<InferenceTextEmbeddingBitResults> {
/**
 * Creates results holding between 1 and 10 random embeddings.
 */
public static InferenceTextEmbeddingBitResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
List<InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
return new InferenceTextEmbeddingBitResults(embeddingResults);
}
/**
 * Creates an embedding of between 1 and 10 random bytes.
 */
private static InferenceByteEmbedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
byte[] bytes = new byte[columns];
for (int i = 0; i < columns; i++) {
bytes[i] = randomByte();
}
return new InferenceByteEmbedding(bytes);
}
// A single embedding is rendered as a one-element "text_embedding_bits" array.
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
var entity = new InferenceTextEmbeddingBitResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
{
"text_embedding_bits" : [
{
"embedding" : [
23
]
}
]
}"""));
}
// Multiple embeddings are rendered in list order inside the same "text_embedding_bits" array.
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
var entity = new InferenceTextEmbeddingBitResults(
List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 }))
);
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
{
"text_embedding_bits" : [
{
"embedding" : [
23
]
},
{
"embedding" : [
24
]
}
]
}"""));
}
// Each embedding becomes one MlTextEmbeddingResults whose values are the bytes widened to doubles.
public void testTransformToCoordinationFormat() {
var results = new InferenceTextEmbeddingBitResults(
List.of(
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
)
).transformToCoordinationFormat();
assertThat(
results,
is(
List.of(
new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
)
)
);
}
// Instances are read back through the StreamInput constructor.
@Override
protected Writeable.Reader<InferenceTextEmbeddingBitResults> instanceReader() {
return InferenceTextEmbeddingBitResults::new;
}
@Override
protected InferenceTextEmbeddingBitResults createTestInstance() {
return createRandomResults();
}
/**
 * Returns a copy of the instance whose embeddings list has a different length: either a strict
 * prefix of the original list or the original list with one extra random embedding appended.
 */
@Override
protected InferenceTextEmbeddingBitResults mutateInstance(InferenceTextEmbeddingBitResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
return new InferenceTextEmbeddingBitResults(instance.embeddings().subList(0, end));
} else {
List<InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
return new InferenceTextEmbeddingBitResults(embeddings);
}
}
/**
 * Builds a map keyed by {@code TEXT_EMBEDDING_BITS} whose value is a list holding one
 * {@code EMBEDDING -> values} map per supplied embedding.
 */
public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) {
return Map.of(
InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS,
embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList()
);
}
}

View File

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@ -23,7 +24,7 @@ import static org.hamcrest.Matchers.is;
public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<InferenceTextEmbeddingByteResults> {
public static InferenceTextEmbeddingByteResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
List<InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
@ -32,7 +33,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
return new InferenceTextEmbeddingByteResults(embeddingResults);
}
private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding createRandomEmbedding() {
private static InferenceByteEmbedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
byte[] bytes = new byte[columns];
@ -40,13 +41,11 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
bytes[i] = randomByte();
}
return new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(bytes);
return new InferenceByteEmbedding(bytes);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
var entity = new InferenceTextEmbeddingByteResults(
List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }))
);
var entity = new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@ -63,10 +62,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
var entity = new InferenceTextEmbeddingByteResults(
List.of(
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }),
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 24 })
)
List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 }))
);
String xContentResult = Strings.toString(entity, true, true);
@ -90,8 +86,8 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
public void testTransformToCoordinationFormat() {
var results = new InferenceTextEmbeddingByteResults(
List.of(
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
)
).transformToCoordinationFormat();
@ -124,7 +120,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
int end = randomInt(instance.embeddings().size() - 1);
return new InferenceTextEmbeddingByteResults(instance.embeddings().subList(0, end));
} else {
List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
List<InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
return new InferenceTextEmbeddingByteResults(embeddings);
}
@ -133,9 +129,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) {
return Map.of(
InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
embeddings.stream()
.map(embedding -> Map.of(InferenceTextEmbeddingByteResults.InferenceByteEmbedding.EMBEDDING, embedding))
.toList()
embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList()
);
}
}

View File

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.results;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@ -141,7 +142,7 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<I
public static Map<String, Object> buildExpectationByte(List<byte[]> embeddings) {
return Map.of(
InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
embeddings.stream().map(InferenceTextEmbeddingByteResults.InferenceByteEmbedding::new).toList()
embeddings.stream().map(InferenceByteEmbedding::new).toList()
);
}

View File

@ -50,6 +50,27 @@ public class CohereEmbeddingTypeTests extends ESTestCase {
);
}
public void testTranslateToVersion_ReturnsInt8_WhenVersionIsBeforeBitEnumAddition_WhenSpecifyingBit() {
    // 8_840_0_00 is older than COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED, so BIT is expected to come back as INT8
    TransportVersion versionBeforeBitSupport = new TransportVersion(8_840_0_00);
    CohereEmbeddingType translated = CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, versionBeforeBitSupport);
    assertThat(translated, is(CohereEmbeddingType.INT8));
}
public void testTranslateToVersion_ReturnsBit_WhenVersionOnBitEnumAddition_WhenSpecifyingBit() {
    // At exactly the version that added BIT, the value is expected to pass through unchanged
    TransportVersion bitSupportVersion = TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED;
    CohereEmbeddingType translated = CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, bitSupportVersion);
    assertThat(translated, is(CohereEmbeddingType.BIT));
}
public void testTranslateToVersion_ReturnsFloat_WhenVersionOnBitEnumAddition_WhenSpecifyingFloat() {
    // FLOAT is expected to stay FLOAT at the version that added BIT
    TransportVersion bitSupportVersion = TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED;
    CohereEmbeddingType translated = CohereEmbeddingType.translateToVersion(CohereEmbeddingType.FLOAT, bitSupportVersion);
    assertThat(translated, is(CohereEmbeddingType.FLOAT));
}
public void testFromElementType_CovertsFloatToCohereEmbeddingTypeFloat() {
    // The dense vector FLOAT element type should map onto the Cohere FLOAT embedding type
    CohereEmbeddingType converted = CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.FLOAT);
    assertThat(converted, is(CohereEmbeddingType.FLOAT));
}
@ -57,4 +78,8 @@ public class CohereEmbeddingTypeTests extends ESTestCase {
public void testFromElementType_CovertsByteToCohereEmbeddingTypeByte() {
    // The dense vector BYTE element type should map onto the Cohere BYTE embedding type
    CohereEmbeddingType converted = CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BYTE);
    assertThat(converted, is(CohereEmbeddingType.BYTE));
}
public void testFromElementType_ConvertsBitToCohereEmbeddingTypeBinary() {
    // Note: despite "Binary" in the method name, the value asserted here is CohereEmbeddingType.BIT
    CohereEmbeddingType converted = CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BIT);
    assertThat(converted, is(CohereEmbeddingType.BIT));
}
}

View File

@ -218,7 +218,7 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
is(
Strings.format(
"Validation Failed: 1: [service_settings] Invalid value [abc] received. "
+ "[embedding_type] must be one of [byte, float, int8];"
+ "[embedding_type] must be one of [binary, bit, byte, float, int8];"
)
)
);
@ -238,7 +238,7 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
is(
Strings.format(
"Validation Failed: 1: [service_settings] Invalid value [abc] received. "
+ "[embedding_type] must be one of [byte, float];"
+ "[embedding_type] must be one of [bit, byte, float];"
)
)
);
@ -289,6 +289,16 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
);
}
public void testFromMap_ConvertsBit_ToCohereEmbeddingTypeBit() {
    // Parse a settings map whose only entry is the embedding type, given as the string form of BIT
    var parsedSettings = CohereEmbeddingsServiceSettings.fromMap(
        new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.BIT.toString())),
        ConfigurationParseContext.REQUEST
    );

    // The parsed result should equal settings built directly with default common settings and the BIT type
    var expectedSettings = new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BIT);
    assertThat(parsedSettings, is(expectedSettings));
}
public void testFromMap_PreservesEmbeddingTypeFloat() {
assertThat(
CohereEmbeddingsServiceSettings.fromMap(
@ -314,6 +324,8 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
assertEquals(CohereEmbeddingType.BYTE, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("byte", validation));
assertEquals(CohereEmbeddingType.INT8, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("int8", validation));
assertEquals(CohereEmbeddingType.FLOAT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("float", validation));
assertEquals(CohereEmbeddingType.BINARY, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("binary", validation));
assertEquals(CohereEmbeddingType.BIT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("bit", validation));
assertTrue(validation.validationErrors().isEmpty());
}