ESQL: dense_vector cosine similarity function (#130641)

This commit is contained in:
Carlos Delgado 2025-07-15 14:49:25 +02:00 committed by GitHub
parent 730308c689
commit f1ddd4c312
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 842 additions and 15 deletions

View File

@ -0,0 +1 @@
<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="420" height="46" viewbox="0 0 420 46"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m116 0h10m32 0h10m68 0h10m32 0h10m80 0h10m32 0h5"/><rect class="s" x="5" y="5" width="116" height="36"/><text class="k" x="15" y="31">V_COSINE</text><rect class="s" x="131" y="5" width="32" height="36" rx="7"/><text class="syn" x="141" y="31">(</text><rect class="s" x="173" y="5" width="68" height="36" rx="7"/><text class="k" x="183" y="31">left</text><rect class="s" x="251" y="5" width="32" height="36" rx="7"/><text class="syn" x="261" y="31">,</text><rect class="s" x="293" y="5" width="80" height="36" rx="7"/><text class="k" x="303" y="31">right</text><rect class="s" x="383" y="5" width="32" height="36" rx="7"/><text class="syn" x="393" y="31">)</text></svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -0,0 +1,12 @@
{
"comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
"type" : "scalar",
"name" : "v_cosine",
"description" : "Calculates the cosine similarity between two dense_vectors.",
"signatures" : [ ],
"examples" : [
" from colors\n | where color != \"black\"\n | eval similarity = v_cosine(rgb_vector, [0, 255, 255])\n | sort similarity desc, color asc"
],
"preview" : true,
"snapshot_only" : true
}

View File

@ -0,0 +1,11 @@
% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
### V COSINE
Calculates the cosine similarity between two dense_vectors.
```esql
from colors
| where color != "black"
| eval similarity = v_cosine(rgb_vector, [0, 255, 255])
| sort similarity desc, color asc
```

View File

@ -0,0 +1,93 @@
# Tests for cosine similarity function
similarityWithVectorField
required_capability: cosine_vector_similarity_function
// tag::vector-cosine-similarity[]
from colors
| where color != "black"
| eval similarity = v_cosine(rgb_vector, [0, 255, 255])
| sort similarity desc, color asc
// end::vector-cosine-similarity[]
| limit 10
| keep color, similarity
;
// tag::vector-cosine-similarity-result[]
color:text | similarity:double
cyan | 1.0
teal | 1.0
turquoise | 0.9890533685684204
aqua marine | 0.964962363243103
azure | 0.916246771812439
lavender | 0.9136701822280884
mint cream | 0.9122757911682129
honeydew | 0.9122424125671387
gainsboro | 0.9082483053207397
gray | 0.9082483053207397
// end::vector-cosine-similarity-result[]
;
similarityAsPartOfExpression
required_capability: cosine_vector_similarity_function
from colors
| where color != "black"
| eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;
color:text | score:double
cyan | 1.5
teal | 1.5
turquoise | 1.495
aqua marine | 1.482
azure | 1.458
lavender | 1.457
honeydew | 1.456
mint cream | 1.456
gainsboro | 1.454
gray | 1.454
;
similarityWithLiteralVectors
required_capability: cosine_vector_similarity_function
row a = 1
| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3)
| keep similarity
;
similarity:double
0.978
;
similarityWithStats
required_capability: cosine_vector_similarity_function
from colors
| where color != "black"
| eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;
avg:double | min:double | max:double
0.832 | 0.5 | 1.0
;
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
# The query only produces the `vector` and `similarity` columns, so it must not sort by or keep a `color` column.
similarityWithRow-Ignore
required_capability: cosine_vector_similarity_function
row vector = [1, 2, 3]
| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
| keep similarity
;
similarity:double
0.978
;

View File

@ -0,0 +1,208 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.vector;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.esql.EsqlClientException;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
/**
 * Integration tests for ESQL vector similarity functions.
 * <p>
 * Each parameterized run pairs an ESQL function name with the Lucene {@link VectorSimilarityFunction} that is used
 * to compute the expected value for every row returned by the query.
 */
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {

    /**
     * @return one {@code { functionName, similarityFunction }} pair per similarity function under test
     */
    @ParametersFactory
    public static Iterable<Object[]> parameters() throws Exception {
        List<Object[]> params = new ArrayList<>();
        params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
        return params;
    }

    private final String functionName;
    private final VectorSimilarityFunction similarityFunction;
    // Dimension count of the indexed vectors; assigned in setup() before any vector is generated
    private int numDims;

    public VectorSimilarityFunctionsIT(
        @Name("functionName") String functionName,
        @Name("similarityFunction") VectorSimilarityFunction similarityFunction
    ) {
        this.functionName = functionName;
        this.similarityFunction = similarityFunction;
    }

    /**
     * Compares the two indexed vector fields of every document and checks the result against Lucene.
     */
    @SuppressWarnings("unchecked")
    public void testSimilarityBetweenVectors() {
        var query = String.format(Locale.ROOT, """
            FROM test
            | EVAL similarity = %s(left_vector, right_vector)
            | KEEP left_vector, right_vector, similarity
            """, functionName);
        try (var resp = run(query)) {
            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
            valuesList.forEach(values -> {
                float[] left = readVector((List<Float>) values.get(0));
                float[] right = readVector((List<Float>) values.get(1));
                Double similarity = (Double) values.get(2);
                assertNotNull(similarity);
                float expectedSimilarity = similarityFunction.compare(left, right);
                assertEquals(expectedSimilarity, similarity, 0.0001);
            });
        }
    }

    /**
     * Compares an indexed vector field against a literal vector written into the query text.
     */
    @SuppressWarnings("unchecked")
    public void testSimilarityBetweenConstantVectorAndField() {
        var randomVector = randomVectorArray();
        var query = String.format(Locale.ROOT, """
            FROM test
            | EVAL similarity = %s(left_vector, %s)
            | KEEP left_vector, similarity
            """, functionName, Arrays.toString(randomVector));
        try (var resp = run(query)) {
            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
            valuesList.forEach(values -> {
                float[] left = readVector((List<Float>) values.get(0));
                Double similarity = (Double) values.get(1);
                assertNotNull(similarity);
                float expectedSimilarity = similarityFunction.compare(left, randomVector);
                assertEquals(expectedSimilarity, similarity, 0.0001);
            });
        }
    }

    /**
     * A literal vector whose length differs from the indexed vectors must be rejected with a client error.
     */
    public void testDifferentDimensions() {
        var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
        var query = String.format(Locale.ROOT, """
            FROM test
            | EVAL similarity = %s(left_vector, %s)
            | KEEP left_vector, similarity
            """, functionName, Arrays.toString(randomVector));
        EsqlClientException e = expectThrows(EsqlClientException.class, () -> { run(query); });
        assertTrue(e.getMessage().contains("Vectors must have the same dimensions"));
    }

    /**
     * Compares two literal vectors; no index data is involved, so exactly one row is expected.
     */
    @SuppressWarnings("unchecked")
    public void testSimilarityBetweenConstantVectors() {
        var vectorLeft = randomVectorArray();
        var vectorRight = randomVectorArray();
        var query = String.format(Locale.ROOT, """
            ROW a = 1
            | EVAL similarity = %s(%s, %s)
            | KEEP similarity
            """, functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));
        try (var resp = run(query)) {
            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
            assertEquals(1, valuesList.size());
            Double similarity = (Double) valuesList.get(0).get(0);
            assertNotNull(similarity);
            float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
            assertEquals(expectedSimilarity, similarity, 0.0001);
        }
    }

    /**
     * Unboxes a list of floats, as returned in a result row, into a primitive array.
     */
    private static float[] readVector(List<Float> leftVector) {
        float[] leftScratch = new float[leftVector.size()];
        for (int i = 0; i < leftVector.size(); i++) {
            leftScratch[i] = leftVector.get(i);
        }
        return leftScratch;
    }

    /**
     * Creates the {@code test} index and fills it with a random number of documents, each holding two random vectors.
     */
    @Before
    public void setup() throws IOException {
        assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());

        createIndexWithDenseVector("test");

        numDims = randomIntBetween(32, 64) * 2; // min 64, even number
        int numDocs = randomIntBetween(10, 100);
        IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
        for (int i = 0; i < numDocs; i++) {
            List<Float> leftVector = randomVector();
            List<Float> rightVector = randomVector();
            docs[i] = prepareIndex("test").setId("" + i)
                .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
        }

        indexRandom(true, docs);
    }

    /**
     * @return a list of {@code numDims} random floats
     */
    private List<Float> randomVector() {
        assert numDims != 0 : "numDims must be set before calling randomVector()";
        List<Float> vector = new ArrayList<>(numDims);
        for (int j = 0; j < numDims; j++) {
            vector.add(randomFloat());
        }
        return vector;
    }

    /**
     * @return an array of {@code numDims} random floats
     */
    private float[] randomVectorArray() {
        assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
        return randomVectorArray(numDims);
    }

    /**
     * @param dimensions length of the array to create
     * @return an array of {@code dimensions} random floats
     */
    private static float[] randomVectorArray(int dimensions) {
        float[] vector = new float[dimensions];
        for (int j = 0; j < dimensions; j++) {
            vector[j] = randomFloat();
        }
        return vector;
    }

    /**
     * Creates an index with an integer {@code id} field and two {@code dense_vector} fields, using zero replicas
     * and a random number of shards between 1 and 5.
     */
    private void createIndexWithDenseVector(String indexName) throws IOException {
        var client = client().admin().indices();
        XContentBuilder mapping = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject("id")
            .field("type", "integer")
            .endObject();
        createDenseVectorField(mapping, "left_vector");
        createDenseVectorField(mapping, "right_vector");
        mapping.endObject().endObject();
        Settings.Builder settingsBuilder = Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));

        // Settings are applied once; a previous extra setSettings(number_of_shards=1) call was overwritten by this one
        var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
        assertAcked(createRequest);
    }

    /**
     * Appends a {@code dense_vector} field with cosine similarity to the mapping being built.
     */
    private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
        mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
        mapping.endObject();
    }
}

View File

@ -1254,7 +1254,12 @@ public class EsqlCapabilities {
* Forbid usage of brackets in unquoted index and enrich policy names
* https://github.com/elastic/elasticsearch/issues/130378
*/
NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES;
NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES,
/*
* Cosine vector similarity function
*/
COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
private final boolean enabled;

View File

@ -1400,15 +1400,15 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
if (f instanceof In in) {
return processIn(in);
}
if (f instanceof VectorFunction) {
return processVectorFunction(f);
}
if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
return processScalarOrGroupingFunction(f, registry);
}
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
return processBinaryOperator((BinaryOperator) f);
}
if (f instanceof VectorFunction vectorFunction) {
return processVectorFunction(f);
}
return f;
}
@ -1613,6 +1613,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
}
}
@SuppressWarnings("unchecked")
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
List<Expression> args = vectorFunction.arguments();
List<Expression> newArgs = new ArrayList<>();
@ -1620,7 +1621,14 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
if (folded instanceof List) {
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
// Convert to floats so blocks are created accordingly
List<Float> floatVector;
if (arg.dataType() == FLOAT) {
floatVector = (List<Float>) folded;
} else {
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
}
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
newArgs.add(denseVector);
continue;
}

View File

@ -8,7 +8,6 @@
package org.elasticsearch.xpack.esql.expression;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
@ -85,7 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLik
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
@ -259,9 +258,6 @@ public class ExpressionWritables {
}
private static List<NamedWriteableRegistry.Entry> vector() {
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
return List.of(Knn.ENTRY);
}
return List.of();
return VectorWritables.getNamedWritables();
}
}

View File

@ -180,6 +180,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.session.Configuration;
@ -489,7 +490,8 @@ public class EsqlFunctionRegistry {
def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
def(StGeohex.class, StGeohex::new, "st_geohex"),
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } };
}
public EsqlFunctionRegistry snapshotRegistry() {

View File

@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import java.io.IOException;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
/**
 * ESQL {@code v_cosine} function: scores two {@code dense_vector} arguments with Lucene's {@code COSINE}
 * similarity function.
 * <p>
 * NOTE(review): Lucene's {@code COSINE.compare} appears to return a score normalized as {@code (1 + cos) / 2}
 * rather than the raw cosine - confirm that this is the intended output of the function.
 */
public class CosineSimilarity extends VectorSimilarityFunction {

    /** Delegates the actual computation to Lucene's COSINE similarity. */
    static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = (leftVector, rightVector) -> COSINE.compare(leftVector, rightVector);

    /** Serialization entry registered under the name {@code CosineSimilarity}. */
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        Expression.class,
        "CosineSimilarity",
        CosineSimilarity::new
    );

    @FunctionInfo(
        returnType = "double",
        preview = true,
        description = "Calculates the cosine similarity between two dense_vectors.",
        examples = { @Example(file = "vector-cosine-similarity", tag = "vector-cosine-similarity") },
        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
    )
    public CosineSimilarity(
        Source source,
        @Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left,
        @Param(
            name = "right",
            type = { "dense_vector" },
            description = "second dense_vector to calculate cosine similarity"
        ) Expression right
    ) {
        super(source, left, right);
    }

    /** Deserialization constructor referenced by {@link #ENTRY}. */
    private CosineSimilarity(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    @Override
    protected SimilarityEvaluatorFunction getSimilarityFunction() {
        return SIMILARITY_FUNCTION;
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, CosineSimilarity::new, left(), right());
    }

    @Override
    protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
        return new CosineSimilarity(source(), newLeft, newRight);
    }
}

View File

@ -0,0 +1,174 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.EsqlClientException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import java.io.IOException;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
/**
 * Base class for vector similarity functions, which compute a similarity score between two dense vectors.
 * <p>
 * Both arguments must be non-null {@code dense_vector} expressions and the result type is always
 * {@link DataType#DOUBLE}. Subclasses only supply the actual computation through {@link #getSimilarityFunction()}.
 */
public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction {

    protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
        super(source, left, right);
    }

    protected VectorSimilarityFunction(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public DataType dataType() {
        return DataType.DOUBLE;
    }

    @Override
    protected TypeResolution resolveType() {
        if (childrenResolved() == false) {
            return new TypeResolution("Unresolved children");
        }

        return checkDenseVectorParam(left(), FIRST).and(checkDenseVectorParam(right(), SECOND));
    }

    /**
     * Requires the parameter to be non-null and of type {@code dense_vector}.
     */
    private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
        return isNotNull(param, sourceText(), paramOrdinal).and(
            isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector")
        );
    }

    /**
     * Functional interface for evaluating the similarity between two float arrays
     */
    @FunctionalInterface
    public interface SimilarityEvaluatorFunction {
        float calculateSimilarity(float[] leftScratch, float[] rightScratch);
    }

    @Override
    public Object fold(FoldContext ctx) {
        return EvaluatorMapper.super.fold(source(), ctx);
    }

    @Override
    public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        return new SimilarityEvaluatorFactory(
            toEvaluator.apply(left()),
            toEvaluator.apply(right()),
            getSimilarityFunction(),
            getClass().getSimpleName() + "Evaluator"
        );
    }

    /**
     * Returns the similarity function to be used for evaluating the similarity between two vectors.
     */
    protected abstract SimilarityEvaluatorFunction getSimilarityFunction();

    /**
     * Factory for the evaluator that applies a {@link SimilarityEvaluatorFunction} position by position.
     */
    private record SimilarityEvaluatorFactory(
        EvalOperator.ExpressionEvaluator.Factory left,
        EvalOperator.ExpressionEvaluator.Factory right,
        SimilarityEvaluatorFunction similarityFunction,
        String evaluatorName
    ) implements EvalOperator.ExpressionEvaluator.Factory {

        @Override
        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
            // TODO check whether to use this custom evaluator or reuse / define an existing one
            return new EvalOperator.ExpressionEvaluator() {
                // Child evaluators are created once per evaluator and released in close(), instead of being
                // re-created (and never closed) on every page
                private final EvalOperator.ExpressionEvaluator leftEvaluator = left.get(context);
                private final EvalOperator.ExpressionEvaluator rightEvaluator = right.get(context);

                @Override
                public Block eval(Page page) {
                    try (
                        FloatBlock leftBlock = (FloatBlock) leftEvaluator.eval(page);
                        FloatBlock rightBlock = (FloatBlock) rightEvaluator.eval(page)
                    ) {
                        int positionCount = page.getPositionCount();
                        // Scratch arrays are sized lazily from the vectors actually read, so a page made only of
                        // null vectors needs no special case and still yields one value per position
                        float[] leftScratch = new float[0];
                        float[] rightScratch = new float[0];
                        // Exactly one double is appended per position
                        try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount)) {
                            for (int p = 0; p < positionCount; p++) {
                                int dimsLeft = leftBlock.getValueCount(p);
                                int dimsRight = rightBlock.getValueCount(p);

                                if (dimsLeft == 0 || dimsRight == 0) {
                                    // A null value on the left or right vector. Similarity is 0
                                    builder.appendDouble(0.0);
                                    continue;
                                } else if (dimsLeft != dimsRight) {
                                    throw new EsqlClientException(
                                        "Vectors must have the same dimensions; first vector has {}, and second has {}",
                                        dimsLeft,
                                        dimsRight
                                    );
                                }
                                if (leftScratch.length != dimsLeft) {
                                    leftScratch = new float[dimsLeft];
                                    rightScratch = new float[dimsLeft];
                                }
                                readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimsLeft, leftScratch);
                                readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimsLeft, rightScratch);
                                float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
                                builder.appendDouble(result);
                            }
                            return builder.build().asBlock();
                        }
                    }
                }

                @Override
                public String toString() {
                    return evaluatorName() + "[left=" + left + ", right=" + right + "]";
                }

                @Override
                public void close() {
                    org.elasticsearch.core.Releasables.closeExpectNoException(leftEvaluator, rightEvaluator);
                }
            };
        }

        /**
         * Copies {@code dimensions} consecutive values of {@code block}, starting at value index {@code position},
         * into {@code scratch}.
         */
        private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
            for (int i = 0; i < dimensions; i++) {
                scratch[i] = block.getFloat(position + i);
            }
        }

        @Override
        public String toString() {
            return evaluatorName() + "[left=" + left + ", right=" + right + "]";
        }
    }
}

View File

@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Collects the {@link NamedWriteableRegistry.Entry} instances of the ESQL vector functions whose capability is
 * currently enabled.
 */
public final class VectorWritables {

    /** Static holder only; instantiation is always rejected. */
    private VectorWritables() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return an unmodifiable list with the entries of all enabled vector functions, possibly empty
     */
    public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
        final boolean knnEnabled = EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled();
        final List<NamedWriteableRegistry.Entry> enabledEntries = new ArrayList<>();

        if (knnEnabled) {
            enabledEntries.add(Knn.ENTRY);
        }

        final boolean cosineEnabled = EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled();
        if (cosineEnabled) {
            enabledEntries.add(CosineSimilarity.ENTRY);
        }

        return Collections.unmodifiableList(enabledEntries);
    }
}

View File

@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
@ -92,6 +93,7 @@ import java.io.IOException;
import java.time.Period;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
@ -123,6 +125,7 @@ import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
@ -2337,7 +2340,7 @@ public class AnalyzerTests extends ESTestCase {
assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
}
public void testDenseVectorImplicitCasting() {
public void testDenseVectorImplicitCastingKnn() {
assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
@ -2351,7 +2354,46 @@ public class AnalyzerTests extends ESTestCase {
var field = knn.field();
var queryVector = as(knn.query(), Literal.class);
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f)));
}
/**
 * Numeric array literals passed to a similarity function are expected to be analyzed as dense_vector literals of
 * floats, both for fractional and for integer elements. Does nothing when the cosine capability is disabled.
 */
public void testDenseVectorImplicitCastingSimilarityFunctions() {
    if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled() == false) {
        return;
    }
    List<Number> fractionalElems = List.of(0.342f, 0.164f, 0.234f);
    checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", fractionalElems);

    List<Number> integerElems = List.of(1f, 2f, 3f);
    checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", integerElems);
}
/**
 * Analyzes {@code from test | eval similarity = <similarityFunction>} and asserts that the right-hand argument of
 * the resulting {@link VectorSimilarityFunction} is a dense_vector literal holding {@code expectedElems}.
 *
 * @param similarityFunction the function invocation to place in the EVAL
 * @param expectedElems      the expected value of the literal on the right-hand side
 */
private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
    String query = String.format(Locale.ROOT, """
        from test | eval similarity = %s
        """, similarityFunction);
    var analyzed = analyze(query, "mapping-dense_vector.json");

    // Expected plan shape: Limit -> Eval -> Alias("similarity") -> similarity function
    var topLimit = as(analyzed, Limit.class);
    var evalNode = as(topLimit.child(), Eval.class);
    var similarityAlias = as(evalNode.fields().get(0), Alias.class);
    assertEquals("similarity", similarityAlias.name());

    var function = as(similarityAlias.child(), VectorSimilarityFunction.class);
    var fieldArg = as(function.left(), FieldAttribute.class);
    assertEquals("vector", fieldArg.name());

    var literalArg = as(function.right(), Literal.class);
    assertThat(literalArg.dataType(), is(DENSE_VECTOR));
    assertThat(literalArg.value(), equalTo(expectedElems));
}
/**
 * A scalar second argument must be rejected by the similarity function. Does nothing when the cosine capability
 * is disabled.
 */
public void testNoDenseVectorFailsSimilarityFunction() {
    if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled() == false) {
        return;
    }
    checkNoDenseVectorFailsSimilarityFunction("v_cosine([0, 1, 2], 0.342)");
}
/**
 * Asserts that analyzing {@code row a = 1 | eval similarity = <similarityFunction>} fails with a
 * {@link VerificationException} reporting that the second argument is a double instead of a dense_vector.
 *
 * @param similarityFunction the function invocation to place in the EVAL
 */
private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {
    final String query = "row a = 1 | eval similarity = " + similarityFunction;
    final String expectedMessage = "second argument of ["
        + similarityFunction
        + "] must be [dense_vector], found value [0.342] type [double]";

    VerificationException error = expectThrows(VerificationException.class, () -> analyze(query));
    assertThat(error.getMessage(), containsString(expectedMessage));
}
public void testRateRequiresCounterTypes() {

View File

@ -2300,6 +2300,20 @@ public class VerifierTests extends ESTestCase {
);
}
/**
 * A null literal in either argument position must be reported with the matching ordinal. Does nothing when the
 * cosine capability is disabled.
 */
public void testVectorSimilarityFunctionsNullArgs() throws Exception {
    if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled() == false) {
        return;
    }
    checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first");
    checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second");
}
/**
 * Asserts that the verification error for {@code from test | eval similarity = <functionInvocation>} states that
 * the argument at {@code argOrdinal} cannot be null.
 *
 * @param functionInvocation the function call text, containing a null argument
 * @param argOrdinal         ordinal word expected at the start of the message, e.g. {@code "first"}
 */
private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {
    final String query = "from test | eval similarity = " + functionInvocation;
    final String expectedMessage = argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]";
    assertThat(error(query, fullTextAnalyzer), containsString(expectedMessage));
}
/**
 * Runs the query through the overload that takes an analyzer, passing {@code defaultAnalyzer}.
 */
private void query(String query) {
query(query, defaultAnalyzer);
}

View File

@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
import com.carrotsearch.randomizedtesting.annotations.Name;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.hamcrest.Matcher;
import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
import static org.hamcrest.Matchers.equalTo;
public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase {
/**
 * Resolves the supplied test case immediately and stores it in {@code testCase}.
 *
 * @param testCaseSupplier supplier of the test case for this parameterized run
 */
protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
this.testCase = testCaseSupplier.get();
}
@Before
public void checkCapability() {
assumeTrue("Similarity function is not enabled", capability().isEnabled());
}
/**
* Get the capability of the vector similarity function to check
*/
protected abstract EsqlCapabilities.Cap capability();
protected static Iterable<Object[]> similarityParameters(
String className,
VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction
) {
final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]";
List<TestCaseSupplier> suppliers = new ArrayList<>();
// Basic test with two dense vectors
suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DENSE_VECTOR), () -> {
int dimensions = between(64, 128);
List<Float> left = randomDenseVector(dimensions);
List<Float> right = randomDenseVector(dimensions);
float[] leftArray = listToFloatArray(left);
float[] rightArray = listToFloatArray(right);
double expected = similarityFunction.calculateSimilarity(leftArray, rightArray);
return new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"),
new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2")
),
evaluatorName,
DOUBLE,
equalTo(expected) // Random vectors should have cosine similarity close to 0
);
}));
return parameterSuppliersFromTypedData(suppliers);
}
private static float[] listToFloatArray(List<Float> floatList) {
float[] floatArray = new float[floatList.size()];
for (int i = 0; i < floatList.size(); i++) {
floatArray[i] = floatList.get(i);
}
return floatArray;
}
protected double calculateSimilarity(List<Float> left, List<Float> right) {
return 0;
}
/**
* @return A random dense vector for testing
* @param dimensions
*/
private static List<Float> randomDenseVector(int dimensions) {
List<Float> vector = new ArrayList<>();
for (int i = 0; i < dimensions; i++) {
vector.add(randomFloat());
}
return vector;
}
@Override
protected Matcher<Object> allNullsMatcher() {
// A null value on the left or right vector. Similarity is 0
return equalTo(0.0);
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.FunctionName;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import java.util.List;
import java.util.function.Supplier;
/**
 * Tests for the {@code v_cosine} function, backed by {@link CosineSimilarity}.
 */
@FunctionName("v_cosine")
public class CosineSimilarityTests extends AbstractVectorSimilarityFunctionTestCase {

    public CosineSimilarityTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
        super(testCaseSupplier);
    }

    /**
     * Supplies the test cases, using {@link CosineSimilarity#SIMILARITY_FUNCTION} to compute expected values.
     */
    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION);
    }

    // Implements the abstract method of the parent class; @Override lets the compiler catch signature drift.
    @Override
    protected EsqlCapabilities.Cap capability() {
        return EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION;
    }

    @Override
    protected Expression build(Source source, List<Expression> args) {
        return new CosineSimilarity(source, args.get(0), args.get(1));
    }
}

View File

@ -41,6 +41,7 @@ setup:
- sum_over_time
- count_over_time
- distinct_over_time
- cosine_vector_similarity_function
reason: "Test that should only be executed on snapshot versions"
- do: {xpack.usage: {}}
@ -130,7 +131,7 @@ setup:
- match: {esql.functions.coalesce: $functions_coalesce}
- gt: {esql.functions.categorize: $functions_categorize}
# Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation.
- length: {esql.functions: 156} # check the "sister" test below for a likely update to the same esql.functions length check
- length: {esql.functions: 157} # check the "sister" test below for a likely update to the same esql.functions length check
---
"Basic ESQL usage output (telemetry) non-snapshot version":