Add basic implementations of float-byte script comparisons (#122381)

Add implementations of `cosineSimilarity` and `dotProduct` to query byte vector fields using float vectors
This commit is contained in:
Simon Cooper 2025-03-03 09:38:37 +00:00 committed by GitHub
parent 5697f7f016
commit 82668b40f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 325 additions and 169 deletions

View File

@ -0,0 +1,6 @@
pr: 122381
summary: Adds implementations of dotProduct and cosineSimilarity painless methods to operate on float vectors for byte fields
area: Vector Search
type: enhancement
issues:
- 117274

View File

@ -80,6 +80,19 @@ public class ESVectorUtil {
return IMPL.ipFloatBit(q, d);
}
/**
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a byte vector.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
*/
public static float ipFloatByte(float[] q, byte[] d) {
if (q.length != d.length) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
}
return IMPL.ipFloatByte(q, d);
}
/**
* AND bit count computed over signed bytes.
* Copied from Lucene's XOR implementation

View File

@ -39,6 +39,11 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
return ipFloatBitImpl(q, d);
}
@Override
public float ipFloatByte(float[] q, byte[] d) {
return ipFloatByteImpl(q, d);
}
public static int ipByteBitImpl(byte[] q, byte[] d) {
assert q.length == d.length * Byte.SIZE;
int acc0 = 0;
@ -101,4 +106,12 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
}
return ret;
}
public static float ipFloatByteImpl(float[] q, byte[] d) {
float ret = 0;
for (int i = 0; i < q.length; i++) {
ret += q[i] * d[i];
}
return ret;
}
}

View File

@ -18,4 +18,6 @@ public interface ESVectorUtilSupport {
int ipByteBit(byte[] q, byte[] d);
float ipFloatBit(float[] q, byte[] d);
float ipFloatByte(float[] q, byte[] d);
}

View File

@ -58,6 +58,11 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
}
@Override
public float ipFloatByte(float[] q, byte[] d) {
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
}
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;

View File

@ -32,11 +32,28 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
public void testIpFloatBit() {
float[] q = new float[16];
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
random().nextFloat();
for (int i = 0; i < q.length; i++) {
q[i] = random().nextFloat();
}
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
}
public void testIpFloatByte() {
float[] q = new float[16];
byte[] d = new byte[16];
for (int i = 0; i < q.length; i++) {
q[i] = random().nextFloat();
}
random().nextBytes(d);
float expected = 0;
for (int i = 0; i < q.length; i++) {
expected += q[i] * d[i];
}
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
}
public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}

View File

@ -107,6 +107,38 @@ setup:
- match: {hits.hits.2._id: "1"}
- match: {hits.hits.2._score: 1632.0}
---
"Dot Product float":
- requires:
capabilities:
- path: /_search
capabilities: [byte_float_dot_product_capability]
test_runner_features: [capabilities]
reason: "float vector queries capability added"
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 32865.2}
- match: {hits.hits.1._id: "3"}
- match: {hits.hits.1._score: 21413.4}
- match: {hits.hits.2._id: "1"}
- match: {hits.hits.2._score: 1862.3}
---
"Cosine Similarity":
- do:
headers:
@ -198,3 +230,39 @@ setup:
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.509}
- lte: {hits.hits.2._score: 0.512}
---
"Cosine Similarity float":
- requires:
capabilities:
- path: /_search
capabilities: [byte_float_dot_product_capability]
test_runner_features: [capabilities]
reason: "float vector queries capability added"
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- gte: {hits.hits.0._score: 0.989}
- lte: {hits.hits.0._score: 0.992}
- match: {hits.hits.1._id: "3"}
- gte: {hits.hits.1._score: 0.885}
- lte: {hits.hits.1._score: 0.888}
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.505}
- lte: {hits.hits.2._score: 0.508}

View File

@ -346,16 +346,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);
StringBuilder errorBuilder = null;
StringBuilder checkVectorErrors(float[] vector) {
StringBuilder errors = checkNanAndInfinite(vector);
if (errors != null) {
return errors;
}
for (int index = 0; index < vector.length; ++index) {
float value = vector[index];
if (value % 1.0f != 0.0f) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
@ -368,7 +369,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support integers between ["
@ -385,9 +386,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
}
if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
return errors;
}
@Override
@ -614,8 +613,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);
StringBuilder checkVectorErrors(float[] vector) {
return checkNanAndInfinite(vector);
}
@Override
@ -768,16 +767,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);
StringBuilder errorBuilder = null;
StringBuilder checkVectorErrors(float[] vector) {
StringBuilder errors = checkNanAndInfinite(vector);
if (errors != null) {
return errors;
}
for (int index = 0; index < vector.length; ++index) {
float value = vector[index];
if (value % 1.0f != 0.0f) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
@ -790,7 +790,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support integers between ["
@ -807,9 +807,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
}
if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
return errors;
}
@Override
@ -993,7 +991,44 @@ public class DenseVectorFieldMapper extends FieldMapper {
public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
public abstract void checkVectorBounds(float[] vector);
/**
* Checks the input {@code vector} is one of the {@code possibleTypes},
* and returns the first type that it matches
*/
public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
assert possibleTypes.length != 0;
// we're looking for one valid allowed type
// assume the types are in order of specificity
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
for (int i = 0; i < possibleTypes.length; i++) {
StringBuilder error = possibleTypes[i].checkVectorErrors(vector);
if (error == null) {
// this one works - use it
return possibleTypes[i];
} else {
errors[i] = error;
}
}
// oh dear, none of the possible types work with this vector. Generate the error message and throw.
StringBuilder message = new StringBuilder();
for (int i = 0; i < possibleTypes.length; i++) {
if (i > 0) {
message.append(" ");
}
message.append("Vector is not a ").append(possibleTypes[i]).append(" vector: ").append(errors[i]);
}
throw new IllegalArgumentException(appendErrorElements(message, vector).toString());
}
public void checkVectorBounds(float[] vector) {
StringBuilder errors = checkVectorErrors(vector);
if (errors != null) {
throw new IllegalArgumentException(appendErrorElements(errors, vector).toString());
}
}
abstract StringBuilder checkVectorErrors(float[] vector);
abstract void checkVectorMagnitude(
VectorSimilarity similarity,
@ -1017,7 +1052,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return index;
}
void checkNanAndInfinite(float[] vector) {
StringBuilder checkNanAndInfinite(float[] vector) {
StringBuilder errorBuilder = null;
for (int index = 0; index < vector.length; ++index) {
@ -1044,9 +1079,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
}
if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
return errorBuilder;
}
static StringBuilder appendErrorElements(StringBuilder errorBuilder, float[] vector) {

View File

@ -25,6 +25,8 @@ public final class SearchCapabilities {
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
/** Support float query vectors on byte vectors */
private static final String BYTE_FLOAT_DOT_PRODUCT_CAPABILITY = "byte_float_dot_product_capability";
/** Support docvalue_fields parameter for `dense_vector` field. */
private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
/** Support transforming rank rrf queries to the corresponding rrf retriever. */
@ -50,6 +52,7 @@ public final class SearchCapabilities {
capabilities.add(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY);
capabilities.add(BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY);
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
capabilities.add(BYTE_FLOAT_DOT_PRODUCT_CAPABILITY);
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);

View File

@ -11,6 +11,7 @@ package org.elasticsearch.script;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
@ -42,7 +43,10 @@ public class VectorScoreScriptUtils {
}
public static class ByteDenseVectorFunction extends DenseVectorFunction {
protected final byte[] queryVector;
// either byteQueryVector or floatQueryVector will be non-null
protected final byte[] byteQueryVector;
protected final float[] floatQueryVector;
// only valid if byteQueryVector is used
protected final float qvMagnitude;
/**
@ -51,22 +55,51 @@ public class VectorScoreScriptUtils {
* @param scoreScript The script in which this function was referenced.
* @param field The vector field.
* @param queryVector The query vector.
* @param normalizeFloatQuery {@code true} if the query vector is a float vector, then normalize it.
* @param allowedTypes The types the vector is allowed to be.
*/
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
public ByteDenseVectorFunction(
ScoreScript scoreScript,
DenseVectorDocValuesField field,
List<Number> queryVector,
boolean normalizeFloatQuery,
ElementType... allowedTypes
) {
super(scoreScript, field);
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
this.queryVector = new byte[queryVector.size()];
float[] validateValues = new float[queryVector.size()];
int queryMagnitude = 0;
float[] floatValues = new float[queryVector.size()];
double queryMagnitude = 0;
for (int i = 0; i < queryVector.size(); i++) {
final Number number = queryVector.get(i);
byte value = number.byteValue();
this.queryVector[i] = value;
float value = queryVector.get(i).floatValue();
floatValues[i] = value;
queryMagnitude += value * value;
validateValues[i] = number.floatValue();
}
this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
field.getElementType().checkVectorBounds(validateValues);
queryMagnitude = Math.sqrt(queryMagnitude);
switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
case FLOAT:
byteQueryVector = null;
floatQueryVector = floatValues;
qvMagnitude = -1; // invalid valid, not used for float vectors
if (normalizeFloatQuery) {
for (int i = 0; i < floatQueryVector.length; i++) {
floatQueryVector[i] /= (float) queryMagnitude;
}
}
break;
case BYTE:
floatQueryVector = null;
byteQueryVector = new byte[floatValues.length];
for (int i = 0; i < floatValues.length; i++) {
byteQueryVector[i] = (byte) floatValues[i];
}
this.qvMagnitude = (float) queryMagnitude;
break;
default:
throw new AssertionError("Unexpected element type");
}
}
/**
@ -78,8 +111,9 @@ public class VectorScoreScriptUtils {
*/
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
super(scoreScript, field);
this.queryVector = queryVector;
float queryMagnitude = 0.0f;
byteQueryVector = queryVector;
floatQueryVector = null;
double queryMagnitude = 0.0f;
for (byte value : queryVector) {
queryMagnitude += value * value;
}
@ -133,7 +167,7 @@ public class VectorScoreScriptUtils {
public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
super(scoreScript, field, queryVector, false, ElementType.BYTE);
}
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@ -142,7 +176,7 @@ public class VectorScoreScriptUtils {
public double l1norm() {
setNextVector();
return field.get().l1Norm(queryVector);
return field.get().l1Norm(byteQueryVector);
}
}
@ -197,7 +231,7 @@ public class VectorScoreScriptUtils {
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
super(scoreScript, field, queryVector, false, ElementType.BYTE);
}
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@ -206,7 +240,7 @@ public class VectorScoreScriptUtils {
public int hamming() {
setNextVector();
return field.get().hamming(queryVector);
return field.get().hamming(byteQueryVector);
}
}
@ -243,7 +277,7 @@ public class VectorScoreScriptUtils {
public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
super(scoreScript, field, queryVector, false, ElementType.BYTE);
}
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@ -252,7 +286,7 @@ public class VectorScoreScriptUtils {
public double l2norm() {
setNextVector();
return field.get().l2Norm(queryVector);
return field.get().l2Norm(byteQueryVector);
}
}
@ -388,7 +422,7 @@ public class VectorScoreScriptUtils {
public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
super(scoreScript, field, queryVector, false, ElementType.BYTE, ElementType.FLOAT);
}
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@ -397,7 +431,11 @@ public class VectorScoreScriptUtils {
public double dotProduct() {
setNextVector();
return field.get().dotProduct(queryVector);
if (floatQueryVector != null) {
return field.get().dotProduct(floatQueryVector);
} else {
return field.get().dotProduct(byteQueryVector);
}
}
}
@ -461,7 +499,7 @@ public class VectorScoreScriptUtils {
public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
super(scoreScript, field, queryVector, true, ElementType.BYTE, ElementType.FLOAT);
}
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@ -470,7 +508,12 @@ public class VectorScoreScriptUtils {
public double cosineSimilarity() {
setNextVector();
return field.get().cosineSimilarity(queryVector, qvMagnitude);
if (floatQueryVector != null) {
// float vector is already normalized by the superclass constructor
return field.get().cosineSimilarity(floatQueryVector, false);
} else {
return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);
}
}
}

View File

@ -12,6 +12,7 @@ package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.nio.ByteBuffer;
import java.util.List;
@ -61,7 +62,7 @@ public class ByteBinaryDenseVector implements DenseVector {
@Override
public double dotProduct(float[] queryVector) {
throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
return ESVectorUtil.ipFloatByte(queryVector, vectorValue);
}
@Override
@ -142,7 +143,11 @@ public class ByteBinaryDenseVector implements DenseVector {
@Override
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
if (normalizeQueryVector) {
return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
}
return dotProduct(queryVector) / getMagnitude();
}
@Override

View File

@ -11,6 +11,7 @@ package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.util.List;
@ -51,12 +52,12 @@ public class ByteKnnDenseVector implements DenseVector {
@Override
public int dotProduct(byte[] queryVector) {
return VectorUtil.dotProduct(docVector, queryVector);
return VectorUtil.dotProduct(queryVector, docVector);
}
@Override
public double dotProduct(float[] queryVector) {
throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
return ESVectorUtil.ipFloatByte(queryVector, docVector);
}
@Override
@ -145,7 +146,11 @@ public class ByteKnnDenseVector implements DenseVector {
@Override
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
if (normalizeQueryVector) {
return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
}
return dotProduct(queryVector) / getMagnitude();
}
@Override

View File

@ -29,7 +29,6 @@ import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
@ -43,8 +42,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
String fieldName = "vector";
int dims = 5;
float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f };
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
List<Number> queryVector = List.of(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
List<Number> invalidQueryVector = List.of(0.5, 111.3);
List<DenseVectorDocValuesField> fields = List.of(
new BinaryDenseVectorDocValuesField(
@ -141,8 +140,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
String fieldName = "vector";
int dims = 5;
float[] docVector = new float[] { 1, 127, -128, 5, -10 };
List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
String hexidecimalString = HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 });
List<DenseVectorDocValuesField> fields = List.of(
@ -183,11 +182,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
for (int i = 0; i < queryVectorArray.length; i++) {
queryVectorArray[i] = queryVector.get(i).floatValue();
}
UnsupportedOperationException uoe = expectThrows(
UnsupportedOperationException.class,
() -> field.getInternal().cosineSimilarity(queryVectorArray, true)
assertEquals(
"cosineSimilarity result is not equal to the expected value!",
cosineSimilarityExpected,
field.getInternal().cosineSimilarity(queryVectorArray, true),
0.001
);
assertThat(uoe.getMessage(), containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead"));
// Check each function rejects query vectors with the wrong dimension
IllegalArgumentException e = expectThrows(
@ -240,9 +240,9 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
int dims = 8;
float[] docVector = new float[] { 124 };
// 124 in binary is b01111100
List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
List<Number> floatQueryVector = Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
List<Number> floatQueryVector = List.of(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
String hexidecimalString = HexFormat.of().formatHex(new byte[] { 124 });
List<DenseVectorDocValuesField> fields = List.of(
@ -293,8 +293,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
public void testByteVsFloatSimilarity() throws IOException {
int dims = 5;
float[] docVector = new float[] { 1f, 127f, -128f, 5f, -10f };
List<Number> listFloatVector = Arrays.asList(1f, 125f, -12f, 2f, 4f);
List<Number> listByteVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
List<Number> listFloatVector = List.of(1f, 125f, -12f, 2f, 4f);
List<Number> listByteVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
float[] floatVector = new float[] { 1f, 125f, -12f, 2f, 4f };
byte[] byteVector = new byte[] { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 };
@ -342,11 +342,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
switch (field.getElementType()) {
case BYTE -> {
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(byteVector));
UnsupportedOperationException e = expectThrows(
UnsupportedOperationException.class,
() -> field.get().dotProduct(floatVector)
);
assertThat(e.getMessage(), containsString("use [int dotProduct(byte[] queryVector)] instead"));
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
}
case FLOAT -> {
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
@ -423,14 +419,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
switch (field.getElementType()) {
case BYTE -> {
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(byteVector), 0.001);
UnsupportedOperationException e = expectThrows(
UnsupportedOperationException.class,
() -> field.get().cosineSimilarity(floatVector)
);
assertThat(
e.getMessage(),
containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead")
);
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
}
case FLOAT -> {
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
@ -471,81 +460,55 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
IllegalArgumentException e;
e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, greaterThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]"
expectThrows(
IllegalArgumentException.class,
containsString(
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]"
),
() -> new L1Norm(scoreScript, greaterThanVector, fieldName)
);
e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, greaterThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]"
);
e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, greaterThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]"
);
e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, greaterThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]"
expectThrows(
IllegalArgumentException.class,
containsString(
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]"
),
() -> new L2Norm(scoreScript, greaterThanVector, fieldName)
);
e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, lessThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
expectThrows(
IllegalArgumentException.class,
containsString(
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
),
() -> new L1Norm(scoreScript, lessThanVector, fieldName)
);
e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, lessThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
);
e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, lessThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
);
e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, lessThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
expectThrows(
IllegalArgumentException.class,
containsString(
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
),
() -> new L2Norm(scoreScript, lessThanVector, fieldName)
);
e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, decimalVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
expectThrows(
IllegalArgumentException.class,
containsString(
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
),
() -> new L1Norm(scoreScript, decimalVector, fieldName)
);
e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, decimalVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
);
e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, decimalVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
);
e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, decimalVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
expectThrows(
IllegalArgumentException.class,
containsString(
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
),
() -> new L2Norm(scoreScript, decimalVector, fieldName)
);
}
}

View File

@ -149,11 +149,6 @@ public class DenseVectorTests extends ESTestCase {
ByteKnnDenseVector knn = new ByteKnnDenseVector(docVector);
UnsupportedOperationException e;
e = expectThrows(UnsupportedOperationException.class, () -> knn.dotProduct(queryVector));
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> knn.dotProduct((Object) queryVector));
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> knn.l1Norm(queryVector));
assertEquals(e.getMessage(), "use [int l1Norm(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> knn.l1Norm((Object) queryVector));
@ -164,18 +159,8 @@ public class DenseVectorTests extends ESTestCase {
e = expectThrows(UnsupportedOperationException.class, () -> knn.l2Norm((Object) queryVector));
assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity(queryVector));
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity((Object) queryVector));
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
ByteBinaryDenseVector binary = new ByteBinaryDenseVector(docVector, new BytesRef(docVector), dims);
e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct(queryVector));
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct((Object) queryVector));
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> binary.l1Norm(queryVector));
assertEquals(e.getMessage(), "use [int l1Norm(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> binary.l1Norm((Object) queryVector));
@ -185,11 +170,6 @@ public class DenseVectorTests extends ESTestCase {
assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> binary.l2Norm((Object) queryVector));
assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> binary.cosineSimilarity(queryVector));
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
e = expectThrows(UnsupportedOperationException.class, () -> binary.cosineSimilarity((Object) queryVector));
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
}
public void testFloatUnsupported() {