Add basic implementations of float-byte script comparisons (#122381)
Add implementations of `cosineSimilarity` and `dotProduct` to query byte vector fields using float vectors
This commit is contained in:
parent
5697f7f016
commit
82668b40f4
|
@ -0,0 +1,6 @@
|
|||
pr: 122381
|
||||
summary: Adds implementations of dotProduct and cosineSimilarity painless methods to operate on float vectors for byte fields
|
||||
area: Vector Search
|
||||
type: enhancement
|
||||
issues:
|
||||
- 117274
|
|
@ -80,6 +80,19 @@ public class ESVectorUtil {
|
|||
return IMPL.ipFloatBit(q, d);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a byte vector.
|
||||
* @param q the query vector
|
||||
* @param d the document vector
|
||||
* @return the inner product of the two vectors
|
||||
*/
|
||||
public static float ipFloatByte(float[] q, byte[] d) {
|
||||
if (q.length != d.length) {
|
||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
|
||||
}
|
||||
return IMPL.ipFloatByte(q, d);
|
||||
}
|
||||
|
||||
/**
|
||||
* AND bit count computed over signed bytes.
|
||||
* Copied from Lucene's XOR implementation
|
||||
|
|
|
@ -39,6 +39,11 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return ipFloatBitImpl(q, d);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float ipFloatByte(float[] q, byte[] d) {
|
||||
return ipFloatByteImpl(q, d);
|
||||
}
|
||||
|
||||
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
||||
assert q.length == d.length * Byte.SIZE;
|
||||
int acc0 = 0;
|
||||
|
@ -101,4 +106,12 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
public static float ipFloatByteImpl(float[] q, byte[] d) {
|
||||
float ret = 0;
|
||||
for (int i = 0; i < q.length; i++) {
|
||||
ret += q[i] * d[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,4 +18,6 @@ public interface ESVectorUtilSupport {
|
|||
int ipByteBit(byte[] q, byte[] d);
|
||||
|
||||
float ipFloatBit(float[] q, byte[] d);
|
||||
|
||||
float ipFloatByte(float[] q, byte[] d);
|
||||
}
|
||||
|
|
|
@ -58,6 +58,11 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float ipFloatByte(float[] q, byte[] d) {
|
||||
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
|
||||
}
|
||||
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||
|
||||
|
|
|
@ -32,11 +32,28 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
|||
public void testIpFloatBit() {
|
||||
float[] q = new float[16];
|
||||
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
|
||||
random().nextFloat();
|
||||
for (int i = 0; i < q.length; i++) {
|
||||
q[i] = random().nextFloat();
|
||||
}
|
||||
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
|
||||
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
|
||||
}
|
||||
|
||||
public void testIpFloatByte() {
|
||||
float[] q = new float[16];
|
||||
byte[] d = new byte[16];
|
||||
for (int i = 0; i < q.length; i++) {
|
||||
q[i] = random().nextFloat();
|
||||
}
|
||||
random().nextBytes(d);
|
||||
|
||||
float expected = 0;
|
||||
for (int i = 0; i < q.length; i++) {
|
||||
expected += q[i] * d[i];
|
||||
}
|
||||
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
|
||||
}
|
||||
|
||||
public void testBitAndCount() {
|
||||
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
|
||||
}
|
||||
|
|
|
@ -107,6 +107,38 @@ setup:
|
|||
- match: {hits.hits.2._id: "1"}
|
||||
- match: {hits.hits.2._score: 1632.0}
|
||||
---
|
||||
"Dot Product float":
|
||||
- requires:
|
||||
capabilities:
|
||||
- path: /_search
|
||||
capabilities: [byte_float_dot_product_capability]
|
||||
test_runner_features: [capabilities]
|
||||
reason: "float vector queries capability added"
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "dotProduct(params.query_vector, 'vector')"
|
||||
params:
|
||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||
|
||||
- match: {hits.total: 3}
|
||||
|
||||
- match: {hits.hits.0._id: "2"}
|
||||
- match: {hits.hits.0._score: 32865.2}
|
||||
|
||||
- match: {hits.hits.1._id: "3"}
|
||||
- match: {hits.hits.1._score: 21413.4}
|
||||
|
||||
- match: {hits.hits.2._id: "1"}
|
||||
- match: {hits.hits.2._score: 1862.3}
|
||||
---
|
||||
"Cosine Similarity":
|
||||
- do:
|
||||
headers:
|
||||
|
@ -198,3 +230,39 @@ setup:
|
|||
- match: {hits.hits.2._id: "1"}
|
||||
- gte: {hits.hits.2._score: 0.509}
|
||||
- lte: {hits.hits.2._score: 0.512}
|
||||
|
||||
---
|
||||
"Cosine Similarity float":
|
||||
- requires:
|
||||
capabilities:
|
||||
- path: /_search
|
||||
capabilities: [byte_float_dot_product_capability]
|
||||
test_runner_features: [capabilities]
|
||||
reason: "float vector queries capability added"
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "cosineSimilarity(params.query_vector, 'vector')"
|
||||
params:
|
||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||
|
||||
- match: {hits.total: 3}
|
||||
|
||||
- match: {hits.hits.0._id: "2"}
|
||||
- gte: {hits.hits.0._score: 0.989}
|
||||
- lte: {hits.hits.0._score: 0.992}
|
||||
|
||||
- match: {hits.hits.1._id: "3"}
|
||||
- gte: {hits.hits.1._score: 0.885}
|
||||
- lte: {hits.hits.1._score: 0.888}
|
||||
|
||||
- match: {hits.hits.2._id: "1"}
|
||||
- gte: {hits.hits.2._score: 0.505}
|
||||
- lte: {hits.hits.2._score: 0.508}
|
||||
|
|
|
@ -346,16 +346,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void checkVectorBounds(float[] vector) {
|
||||
checkNanAndInfinite(vector);
|
||||
|
||||
StringBuilder errorBuilder = null;
|
||||
StringBuilder checkVectorErrors(float[] vector) {
|
||||
StringBuilder errors = checkNanAndInfinite(vector);
|
||||
if (errors != null) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
for (int index = 0; index < vector.length; ++index) {
|
||||
float value = vector[index];
|
||||
|
||||
if (value % 1.0f != 0.0f) {
|
||||
errorBuilder = new StringBuilder(
|
||||
errors = new StringBuilder(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support non-decimal values but found decimal value ["
|
||||
|
@ -368,7 +369,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
|
||||
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
|
||||
errorBuilder = new StringBuilder(
|
||||
errors = new StringBuilder(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support integers between ["
|
||||
|
@ -385,9 +386,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
}
|
||||
|
||||
if (errorBuilder != null) {
|
||||
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
|
||||
}
|
||||
return errors;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -614,8 +613,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void checkVectorBounds(float[] vector) {
|
||||
checkNanAndInfinite(vector);
|
||||
StringBuilder checkVectorErrors(float[] vector) {
|
||||
return checkNanAndInfinite(vector);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -768,16 +767,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void checkVectorBounds(float[] vector) {
|
||||
checkNanAndInfinite(vector);
|
||||
|
||||
StringBuilder errorBuilder = null;
|
||||
StringBuilder checkVectorErrors(float[] vector) {
|
||||
StringBuilder errors = checkNanAndInfinite(vector);
|
||||
if (errors != null) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
for (int index = 0; index < vector.length; ++index) {
|
||||
float value = vector[index];
|
||||
|
||||
if (value % 1.0f != 0.0f) {
|
||||
errorBuilder = new StringBuilder(
|
||||
errors = new StringBuilder(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support non-decimal values but found decimal value ["
|
||||
|
@ -790,7 +790,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
|
||||
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
|
||||
errorBuilder = new StringBuilder(
|
||||
errors = new StringBuilder(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support integers between ["
|
||||
|
@ -807,9 +807,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
}
|
||||
|
||||
if (errorBuilder != null) {
|
||||
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
|
||||
}
|
||||
return errors;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -993,7 +991,44 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
|
||||
public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
|
||||
|
||||
public abstract void checkVectorBounds(float[] vector);
|
||||
/**
|
||||
* Checks the input {@code vector} is one of the {@code possibleTypes},
|
||||
* and returns the first type that it matches
|
||||
*/
|
||||
public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
|
||||
assert possibleTypes.length != 0;
|
||||
// we're looking for one valid allowed type
|
||||
// assume the types are in order of specificity
|
||||
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
|
||||
for (int i = 0; i < possibleTypes.length; i++) {
|
||||
StringBuilder error = possibleTypes[i].checkVectorErrors(vector);
|
||||
if (error == null) {
|
||||
// this one works - use it
|
||||
return possibleTypes[i];
|
||||
} else {
|
||||
errors[i] = error;
|
||||
}
|
||||
}
|
||||
|
||||
// oh dear, none of the possible types work with this vector. Generate the error message and throw.
|
||||
StringBuilder message = new StringBuilder();
|
||||
for (int i = 0; i < possibleTypes.length; i++) {
|
||||
if (i > 0) {
|
||||
message.append(" ");
|
||||
}
|
||||
message.append("Vector is not a ").append(possibleTypes[i]).append(" vector: ").append(errors[i]);
|
||||
}
|
||||
throw new IllegalArgumentException(appendErrorElements(message, vector).toString());
|
||||
}
|
||||
|
||||
public void checkVectorBounds(float[] vector) {
|
||||
StringBuilder errors = checkVectorErrors(vector);
|
||||
if (errors != null) {
|
||||
throw new IllegalArgumentException(appendErrorElements(errors, vector).toString());
|
||||
}
|
||||
}
|
||||
|
||||
abstract StringBuilder checkVectorErrors(float[] vector);
|
||||
|
||||
abstract void checkVectorMagnitude(
|
||||
VectorSimilarity similarity,
|
||||
|
@ -1017,7 +1052,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
return index;
|
||||
}
|
||||
|
||||
void checkNanAndInfinite(float[] vector) {
|
||||
StringBuilder checkNanAndInfinite(float[] vector) {
|
||||
StringBuilder errorBuilder = null;
|
||||
|
||||
for (int index = 0; index < vector.length; ++index) {
|
||||
|
@ -1044,9 +1079,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
}
|
||||
}
|
||||
|
||||
if (errorBuilder != null) {
|
||||
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
|
||||
}
|
||||
return errorBuilder;
|
||||
}
|
||||
|
||||
static StringBuilder appendErrorElements(StringBuilder errorBuilder, float[] vector) {
|
||||
|
|
|
@ -25,6 +25,8 @@ public final class SearchCapabilities {
|
|||
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
|
||||
/** Support Byte and Float with Bit dot product. */
|
||||
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
|
||||
/** Support float query vectors on byte vectors */
|
||||
private static final String BYTE_FLOAT_DOT_PRODUCT_CAPABILITY = "byte_float_dot_product_capability";
|
||||
/** Support docvalue_fields parameter for `dense_vector` field. */
|
||||
private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
|
||||
/** Support transforming rank rrf queries to the corresponding rrf retriever. */
|
||||
|
@ -50,6 +52,7 @@ public final class SearchCapabilities {
|
|||
capabilities.add(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY);
|
||||
capabilities.add(BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY);
|
||||
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
|
||||
capabilities.add(BYTE_FLOAT_DOT_PRODUCT_CAPABILITY);
|
||||
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
|
||||
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
|
||||
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
|
||||
|
|
|
@ -11,6 +11,7 @@ package org.elasticsearch.script;
|
|||
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
|
||||
import org.elasticsearch.script.field.vectors.DenseVector;
|
||||
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
|
||||
|
||||
|
@ -42,7 +43,10 @@ public class VectorScoreScriptUtils {
|
|||
}
|
||||
|
||||
public static class ByteDenseVectorFunction extends DenseVectorFunction {
|
||||
protected final byte[] queryVector;
|
||||
// either byteQueryVector or floatQueryVector will be non-null
|
||||
protected final byte[] byteQueryVector;
|
||||
protected final float[] floatQueryVector;
|
||||
// only valid if byteQueryVector is used
|
||||
protected final float qvMagnitude;
|
||||
|
||||
/**
|
||||
|
@ -51,22 +55,51 @@ public class VectorScoreScriptUtils {
|
|||
* @param scoreScript The script in which this function was referenced.
|
||||
* @param field The vector field.
|
||||
* @param queryVector The query vector.
|
||||
* @param normalizeFloatQuery {@code true} if the query vector is a float vector, then normalize it.
|
||||
* @param allowedTypes The types the vector is allowed to be.
|
||||
*/
|
||||
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
public ByteDenseVectorFunction(
|
||||
ScoreScript scoreScript,
|
||||
DenseVectorDocValuesField field,
|
||||
List<Number> queryVector,
|
||||
boolean normalizeFloatQuery,
|
||||
ElementType... allowedTypes
|
||||
) {
|
||||
super(scoreScript, field);
|
||||
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
|
||||
this.queryVector = new byte[queryVector.size()];
|
||||
float[] validateValues = new float[queryVector.size()];
|
||||
int queryMagnitude = 0;
|
||||
float[] floatValues = new float[queryVector.size()];
|
||||
double queryMagnitude = 0;
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
final Number number = queryVector.get(i);
|
||||
byte value = number.byteValue();
|
||||
this.queryVector[i] = value;
|
||||
float value = queryVector.get(i).floatValue();
|
||||
floatValues[i] = value;
|
||||
queryMagnitude += value * value;
|
||||
validateValues[i] = number.floatValue();
|
||||
}
|
||||
this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
|
||||
field.getElementType().checkVectorBounds(validateValues);
|
||||
queryMagnitude = Math.sqrt(queryMagnitude);
|
||||
|
||||
switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
|
||||
case FLOAT:
|
||||
byteQueryVector = null;
|
||||
floatQueryVector = floatValues;
|
||||
qvMagnitude = -1; // invalid valid, not used for float vectors
|
||||
|
||||
if (normalizeFloatQuery) {
|
||||
for (int i = 0; i < floatQueryVector.length; i++) {
|
||||
floatQueryVector[i] /= (float) queryMagnitude;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case BYTE:
|
||||
floatQueryVector = null;
|
||||
byteQueryVector = new byte[floatValues.length];
|
||||
for (int i = 0; i < floatValues.length; i++) {
|
||||
byteQueryVector[i] = (byte) floatValues[i];
|
||||
}
|
||||
this.qvMagnitude = (float) queryMagnitude;
|
||||
break;
|
||||
default:
|
||||
throw new AssertionError("Unexpected element type");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -78,8 +111,9 @@ public class VectorScoreScriptUtils {
|
|||
*/
|
||||
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
super(scoreScript, field);
|
||||
this.queryVector = queryVector;
|
||||
float queryMagnitude = 0.0f;
|
||||
byteQueryVector = queryVector;
|
||||
floatQueryVector = null;
|
||||
double queryMagnitude = 0.0f;
|
||||
for (byte value : queryVector) {
|
||||
queryMagnitude += value * value;
|
||||
}
|
||||
|
@ -133,7 +167,7 @@ public class VectorScoreScriptUtils {
|
|||
public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {
|
||||
|
||||
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
super(scoreScript, field, queryVector, false, ElementType.BYTE);
|
||||
}
|
||||
|
||||
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
|
@ -142,7 +176,7 @@ public class VectorScoreScriptUtils {
|
|||
|
||||
public double l1norm() {
|
||||
setNextVector();
|
||||
return field.get().l1Norm(queryVector);
|
||||
return field.get().l1Norm(byteQueryVector);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,7 +231,7 @@ public class VectorScoreScriptUtils {
|
|||
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
|
||||
|
||||
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
super(scoreScript, field, queryVector, false, ElementType.BYTE);
|
||||
}
|
||||
|
||||
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
|
@ -206,7 +240,7 @@ public class VectorScoreScriptUtils {
|
|||
|
||||
public int hamming() {
|
||||
setNextVector();
|
||||
return field.get().hamming(queryVector);
|
||||
return field.get().hamming(byteQueryVector);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,7 +277,7 @@ public class VectorScoreScriptUtils {
|
|||
public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {
|
||||
|
||||
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
super(scoreScript, field, queryVector, false, ElementType.BYTE);
|
||||
}
|
||||
|
||||
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
|
@ -252,7 +286,7 @@ public class VectorScoreScriptUtils {
|
|||
|
||||
public double l2norm() {
|
||||
setNextVector();
|
||||
return field.get().l2Norm(queryVector);
|
||||
return field.get().l2Norm(byteQueryVector);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -388,7 +422,7 @@ public class VectorScoreScriptUtils {
|
|||
public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
|
||||
|
||||
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
super(scoreScript, field, queryVector, false, ElementType.BYTE, ElementType.FLOAT);
|
||||
}
|
||||
|
||||
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
|
@ -397,7 +431,11 @@ public class VectorScoreScriptUtils {
|
|||
|
||||
public double dotProduct() {
|
||||
setNextVector();
|
||||
return field.get().dotProduct(queryVector);
|
||||
if (floatQueryVector != null) {
|
||||
return field.get().dotProduct(floatQueryVector);
|
||||
} else {
|
||||
return field.get().dotProduct(byteQueryVector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -461,7 +499,7 @@ public class VectorScoreScriptUtils {
|
|||
public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {
|
||||
|
||||
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
super(scoreScript, field, queryVector, true, ElementType.BYTE, ElementType.FLOAT);
|
||||
}
|
||||
|
||||
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
|
@ -470,7 +508,12 @@ public class VectorScoreScriptUtils {
|
|||
|
||||
public double cosineSimilarity() {
|
||||
setNextVector();
|
||||
return field.get().cosineSimilarity(queryVector, qvMagnitude);
|
||||
if (floatQueryVector != null) {
|
||||
// float vector is already normalized by the superclass constructor
|
||||
return field.get().cosineSimilarity(floatQueryVector, false);
|
||||
} else {
|
||||
return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ package org.elasticsearch.script.field.vectors;
|
|||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.core.SuppressForbidden;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
|
@ -61,7 +62,7 @@ public class ByteBinaryDenseVector implements DenseVector {
|
|||
|
||||
@Override
|
||||
public double dotProduct(float[] queryVector) {
|
||||
throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
|
||||
return ESVectorUtil.ipFloatByte(queryVector, vectorValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -142,7 +143,11 @@ public class ByteBinaryDenseVector implements DenseVector {
|
|||
|
||||
@Override
|
||||
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
|
||||
throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
|
||||
if (normalizeQueryVector) {
|
||||
return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
|
||||
}
|
||||
|
||||
return dotProduct(queryVector) / getMagnitude();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -11,6 +11,7 @@ package org.elasticsearch.script.field.vectors;
|
|||
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.core.SuppressForbidden;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -51,12 +52,12 @@ public class ByteKnnDenseVector implements DenseVector {
|
|||
|
||||
@Override
|
||||
public int dotProduct(byte[] queryVector) {
|
||||
return VectorUtil.dotProduct(docVector, queryVector);
|
||||
return VectorUtil.dotProduct(queryVector, docVector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double dotProduct(float[] queryVector) {
|
||||
throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
|
||||
return ESVectorUtil.ipFloatByte(queryVector, docVector);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -145,7 +146,11 @@ public class ByteKnnDenseVector implements DenseVector {
|
|||
|
||||
@Override
|
||||
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
|
||||
throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
|
||||
if (normalizeQueryVector) {
|
||||
return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
|
||||
}
|
||||
|
||||
return dotProduct(queryVector) / getMagnitude();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
|
|||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HexFormat;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -43,8 +42,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
String fieldName = "vector";
|
||||
int dims = 5;
|
||||
float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f };
|
||||
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
||||
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
|
||||
List<Number> queryVector = List.of(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
||||
List<Number> invalidQueryVector = List.of(0.5, 111.3);
|
||||
|
||||
List<DenseVectorDocValuesField> fields = List.of(
|
||||
new BinaryDenseVectorDocValuesField(
|
||||
|
@ -141,8 +140,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
String fieldName = "vector";
|
||||
int dims = 5;
|
||||
float[] docVector = new float[] { 1, 127, -128, 5, -10 };
|
||||
List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
||||
List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
|
||||
List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
||||
List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
|
||||
String hexidecimalString = HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 });
|
||||
|
||||
List<DenseVectorDocValuesField> fields = List.of(
|
||||
|
@ -183,11 +182,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
for (int i = 0; i < queryVectorArray.length; i++) {
|
||||
queryVectorArray[i] = queryVector.get(i).floatValue();
|
||||
}
|
||||
UnsupportedOperationException uoe = expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() -> field.getInternal().cosineSimilarity(queryVectorArray, true)
|
||||
assertEquals(
|
||||
"cosineSimilarity result is not equal to the expected value!",
|
||||
cosineSimilarityExpected,
|
||||
field.getInternal().cosineSimilarity(queryVectorArray, true),
|
||||
0.001
|
||||
);
|
||||
assertThat(uoe.getMessage(), containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead"));
|
||||
|
||||
// Check each function rejects query vectors with the wrong dimension
|
||||
IllegalArgumentException e = expectThrows(
|
||||
|
@ -240,9 +240,9 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
int dims = 8;
|
||||
float[] docVector = new float[] { 124 };
|
||||
// 124 in binary is b01111100
|
||||
List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
|
||||
List<Number> floatQueryVector = Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
|
||||
List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
|
||||
List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
|
||||
List<Number> floatQueryVector = List.of(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
|
||||
List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
|
||||
String hexidecimalString = HexFormat.of().formatHex(new byte[] { 124 });
|
||||
|
||||
List<DenseVectorDocValuesField> fields = List.of(
|
||||
|
@ -293,8 +293,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
public void testByteVsFloatSimilarity() throws IOException {
|
||||
int dims = 5;
|
||||
float[] docVector = new float[] { 1f, 127f, -128f, 5f, -10f };
|
||||
List<Number> listFloatVector = Arrays.asList(1f, 125f, -12f, 2f, 4f);
|
||||
List<Number> listByteVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
||||
List<Number> listFloatVector = List.of(1f, 125f, -12f, 2f, 4f);
|
||||
List<Number> listByteVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
|
||||
float[] floatVector = new float[] { 1f, 125f, -12f, 2f, 4f };
|
||||
byte[] byteVector = new byte[] { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 };
|
||||
|
||||
|
@ -342,11 +342,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
switch (field.getElementType()) {
|
||||
case BYTE -> {
|
||||
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(byteVector));
|
||||
UnsupportedOperationException e = expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() -> field.get().dotProduct(floatVector)
|
||||
);
|
||||
assertThat(e.getMessage(), containsString("use [int dotProduct(byte[] queryVector)] instead"));
|
||||
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
|
||||
}
|
||||
case FLOAT -> {
|
||||
assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
|
||||
|
@ -423,14 +419,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
switch (field.getElementType()) {
|
||||
case BYTE -> {
|
||||
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(byteVector), 0.001);
|
||||
UnsupportedOperationException e = expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() -> field.get().cosineSimilarity(floatVector)
|
||||
);
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead")
|
||||
);
|
||||
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
|
||||
}
|
||||
case FLOAT -> {
|
||||
assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
|
||||
|
@ -471,81 +460,55 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
|
||||
|
||||
IllegalArgumentException e;
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, greaterThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]"
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
containsString(
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]"
|
||||
),
|
||||
() -> new L1Norm(scoreScript, greaterThanVector, fieldName)
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, greaterThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, greaterThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, greaterThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]"
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
containsString(
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]"
|
||||
),
|
||||
() -> new L2Norm(scoreScript, greaterThanVector, fieldName)
|
||||
);
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, lessThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
containsString(
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
),
|
||||
() -> new L1Norm(scoreScript, lessThanVector, fieldName)
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, lessThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, lessThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, lessThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
containsString(
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
),
|
||||
() -> new L2Norm(scoreScript, lessThanVector, fieldName)
|
||||
);
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, decimalVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
containsString(
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
),
|
||||
() -> new L1Norm(scoreScript, decimalVector, fieldName)
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, decimalVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, decimalVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, decimalVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
containsString(
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
),
|
||||
() -> new L2Norm(scoreScript, decimalVector, fieldName)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -149,11 +149,6 @@ public class DenseVectorTests extends ESTestCase {
|
|||
ByteKnnDenseVector knn = new ByteKnnDenseVector(docVector);
|
||||
UnsupportedOperationException e;
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.dotProduct(queryVector));
|
||||
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.dotProduct((Object) queryVector));
|
||||
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.l1Norm(queryVector));
|
||||
assertEquals(e.getMessage(), "use [int l1Norm(byte[] queryVector)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.l1Norm((Object) queryVector));
|
||||
|
@ -164,18 +159,8 @@ public class DenseVectorTests extends ESTestCase {
|
|||
e = expectThrows(UnsupportedOperationException.class, () -> knn.l2Norm((Object) queryVector));
|
||||
assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity(queryVector));
|
||||
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity((Object) queryVector));
|
||||
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
|
||||
|
||||
ByteBinaryDenseVector binary = new ByteBinaryDenseVector(docVector, new BytesRef(docVector), dims);
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct(queryVector));
|
||||
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct((Object) queryVector));
|
||||
assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.l1Norm(queryVector));
|
||||
assertEquals(e.getMessage(), "use [int l1Norm(byte[] queryVector)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.l1Norm((Object) queryVector));
|
||||
|
@ -185,11 +170,6 @@ public class DenseVectorTests extends ESTestCase {
|
|||
assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.l2Norm((Object) queryVector));
|
||||
assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.cosineSimilarity(queryVector));
|
||||
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> binary.cosineSimilarity((Object) queryVector));
|
||||
assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
|
||||
}
|
||||
|
||||
public void testFloatUnsupported() {
|
||||
|
|
Loading…
Reference in New Issue