ESQL: dense_vector cosine similarity function (#130641)
This commit is contained in:
parent
730308c689
commit
f1ddd4c312
|
@ -0,0 +1 @@
|
|||
<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="420" height="46" viewbox="0 0 420 46"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m116 0h10m32 0h10m68 0h10m32 0h10m80 0h10m32 0h5"/><rect class="s" x="5" y="5" width="116" height="36"/><text class="k" x="15" y="31">V_COSINE</text><rect class="s" x="131" y="5" width="32" height="36" rx="7"/><text class="syn" x="141" y="31">(</text><rect class="s" x="173" y="5" width="68" height="36" rx="7"/><text class="k" x="183" y="31">left</text><rect class="s" x="251" y="5" width="32" height="36" rx="7"/><text class="syn" x="261" y="31">,</text><rect class="s" x="293" y="5" width="80" height="36" rx="7"/><text class="k" x="303" y="31">right</text><rect class="s" x="383" y="5" width="32" height="36" rx="7"/><text class="syn" x="393" y="31">)</text></svg>
|
After Width: | Height: | Size: 1.2 KiB |
12
docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json
generated
Normal file
12
docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json
generated
Normal file
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
|
||||
"type" : "scalar",
|
||||
"name" : "v_cosine",
|
||||
"description" : "Calculates the cosine similarity between two dense_vectors.",
|
||||
"signatures" : [ ],
|
||||
"examples" : [
|
||||
" from colors\n | where color != \"black\"\n | eval similarity = v_cosine(rgb_vector, [0, 255, 255])\n | sort similarity desc, color asc"
|
||||
],
|
||||
"preview" : true,
|
||||
"snapshot_only" : true
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
|
||||
|
||||
### V COSINE
|
||||
Calculates the cosine similarity between two dense_vectors.
|
||||
|
||||
```esql
|
||||
from colors
|
||||
| where color != "black"
|
||||
| eval similarity = v_cosine(rgb_vector, [0, 255, 255])
|
||||
| sort similarity desc, color asc
|
||||
```
|
|
@ -0,0 +1,93 @@
|
|||
# Tests for cosine similarity function
|
||||
|
||||
similarityWithVectorField
|
||||
required_capability: cosine_vector_similarity_function
|
||||
|
||||
// tag::vector-cosine-similarity[]
|
||||
from colors
|
||||
| where color != "black"
|
||||
| eval similarity = v_cosine(rgb_vector, [0, 255, 255])
|
||||
| sort similarity desc, color asc
|
||||
// end::vector-cosine-similarity[]
|
||||
| limit 10
|
||||
| keep color, similarity
|
||||
;
|
||||
|
||||
// tag::vector-cosine-similarity-result[]
|
||||
color:text | similarity:double
|
||||
cyan | 1.0
|
||||
teal | 1.0
|
||||
turquoise | 0.9890533685684204
|
||||
aqua marine | 0.964962363243103
|
||||
azure | 0.916246771812439
|
||||
lavender | 0.9136701822280884
|
||||
mint cream | 0.9122757911682129
|
||||
honeydew | 0.9122424125671387
|
||||
gainsboro | 0.9082483053207397
|
||||
gray | 0.9082483053207397
|
||||
// end::vector-cosine-similarity-result[]
|
||||
;
|
||||
|
||||
similarityAsPartOfExpression
|
||||
required_capability: cosine_vector_similarity_function
|
||||
|
||||
from colors
|
||||
| where color != "black"
|
||||
| eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3)
|
||||
| sort score desc, color asc
|
||||
| limit 10
|
||||
| keep color, score
|
||||
;
|
||||
|
||||
color:text | score:double
|
||||
cyan | 1.5
|
||||
teal | 1.5
|
||||
turquoise | 1.495
|
||||
aqua marine | 1.482
|
||||
azure | 1.458
|
||||
lavender | 1.457
|
||||
honeydew | 1.456
|
||||
mint cream | 1.456
|
||||
gainsboro | 1.454
|
||||
gray | 1.454
|
||||
;
|
||||
|
||||
similarityWithLiteralVectors
|
||||
required_capability: cosine_vector_similarity_function
|
||||
|
||||
row a = 1
|
||||
| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3)
|
||||
| keep similarity
|
||||
;
|
||||
|
||||
similarity:double
|
||||
0.978
|
||||
;
|
||||
|
||||
similarityWithStats
|
||||
required_capability: cosine_vector_similarity_function
|
||||
|
||||
from colors
|
||||
| where color != "black"
|
||||
| eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3)
|
||||
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
|
||||
;
|
||||
|
||||
avg:double | min:double | max:double
|
||||
0.832 | 0.5 | 1.0
|
||||
;
|
||||
|
||||
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
|
||||
similarityWithRow-Ignore
|
||||
required_capability: cosine_vector_similarity_function
|
||||
|
||||
row vector = [1, 2, 3]
|
||||
| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
|
||||
| sort similarity desc, color asc
|
||||
| limit 10
|
||||
| keep color, similarity
|
||||
;
|
||||
|
||||
similarity:double
|
||||
0.978
|
||||
;
|
|
@ -0,0 +1,208 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.vector;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.elasticsearch.action.index.IndexRequestBuilder;
|
||||
import org.elasticsearch.cluster.metadata.IndexMetadata;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xpack.esql.EsqlClientException;
|
||||
import org.elasticsearch.xpack.esql.EsqlTestUtils;
|
||||
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
|
||||
|
||||
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
|
||||
|
||||
@ParametersFactory
|
||||
public static Iterable<Object[]> parameters() throws Exception {
|
||||
List<Object[]> params = new ArrayList<>();
|
||||
|
||||
params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
private final String functionName;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private int numDims;
|
||||
|
||||
public VectorSimilarityFunctionsIT(
|
||||
@Name("functionName") String functionName,
|
||||
@Name("similarityFunction") VectorSimilarityFunction similarityFunction
|
||||
) {
|
||||
this.functionName = functionName;
|
||||
this.similarityFunction = similarityFunction;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testSimilarityBetweenVectors() {
|
||||
var query = String.format(Locale.ROOT, """
|
||||
FROM test
|
||||
| EVAL similarity = %s(left_vector, right_vector)
|
||||
| KEEP left_vector, right_vector, similarity
|
||||
""", functionName);
|
||||
|
||||
try (var resp = run(query)) {
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
valuesList.forEach(values -> {
|
||||
float[] left = readVector((List<Float>) values.get(0));
|
||||
float[] right = readVector((List<Float>) values.get(1));
|
||||
Double similarity = (Double) values.get(2);
|
||||
|
||||
assertNotNull(similarity);
|
||||
float expectedSimilarity = similarityFunction.compare(left, right);
|
||||
assertEquals(expectedSimilarity, similarity, 0.0001);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testSimilarityBetweenConstantVectorAndField() {
|
||||
var randomVector = randomVectorArray();
|
||||
var query = String.format(Locale.ROOT, """
|
||||
FROM test
|
||||
| EVAL similarity = %s(left_vector, %s)
|
||||
| KEEP left_vector, similarity
|
||||
""", functionName, Arrays.toString(randomVector));
|
||||
|
||||
try (var resp = run(query)) {
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
valuesList.forEach(values -> {
|
||||
float[] left = readVector((List<Float>) values.get(0));
|
||||
Double similarity = (Double) values.get(1);
|
||||
|
||||
assertNotNull(similarity);
|
||||
float expectedSimilarity = similarityFunction.compare(left, randomVector);
|
||||
assertEquals(expectedSimilarity, similarity, 0.0001);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void testDifferentDimensions() {
|
||||
var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
|
||||
var query = String.format(Locale.ROOT, """
|
||||
FROM test
|
||||
| EVAL similarity = %s(left_vector, %s)
|
||||
| KEEP left_vector, similarity
|
||||
""", functionName, Arrays.toString(randomVector));
|
||||
|
||||
EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); });
|
||||
assertTrue(iae.getMessage().contains("Vectors must have the same dimensions"));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testSimilarityBetweenConstantVectors() {
|
||||
var vectorLeft = randomVectorArray();
|
||||
var vectorRight = randomVectorArray();
|
||||
var query = String.format(Locale.ROOT, """
|
||||
ROW a = 1
|
||||
| EVAL similarity = %s(%s, %s)
|
||||
| KEEP similarity
|
||||
""", functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));
|
||||
|
||||
try (var resp = run(query)) {
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
assertEquals(1, valuesList.size());
|
||||
|
||||
Double similarity = (Double) valuesList.get(0).get(0);
|
||||
assertNotNull(similarity);
|
||||
float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
|
||||
assertEquals(expectedSimilarity, similarity, 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
private static float[] readVector(List<Float> leftVector) {
|
||||
float[] leftScratch = new float[leftVector.size()];
|
||||
for (int i = 0; i < leftVector.size(); i++) {
|
||||
leftScratch[i] = leftVector.get(i);
|
||||
}
|
||||
return leftScratch;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setup() throws IOException {
|
||||
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
|
||||
|
||||
createIndexWithDenseVector("test");
|
||||
|
||||
numDims = randomIntBetween(32, 64) * 2; // min 64, even number
|
||||
int numDocs = randomIntBetween(10, 100);
|
||||
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
List<Float> leftVector = randomVector();
|
||||
List<Float> rightVector = randomVector();
|
||||
docs[i] = prepareIndex("test").setId("" + i)
|
||||
.setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
|
||||
}
|
||||
|
||||
indexRandom(true, docs);
|
||||
}
|
||||
|
||||
private List<Float> randomVector() {
|
||||
assert numDims != 0 : "numDims must be set before calling randomVector()";
|
||||
List<Float> vector = new ArrayList<>(numDims);
|
||||
for (int j = 0; j < numDims; j++) {
|
||||
vector.add(randomFloat());
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
private float[] randomVectorArray() {
|
||||
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
|
||||
return randomVectorArray(numDims);
|
||||
}
|
||||
|
||||
private static float[] randomVectorArray(int dimensions) {
|
||||
float[] vector = new float[dimensions];
|
||||
for (int j = 0; j < dimensions; j++) {
|
||||
vector[j] = randomFloat();
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
private void createIndexWithDenseVector(String indexName) throws IOException {
|
||||
var client = client().admin().indices();
|
||||
XContentBuilder mapping = XContentFactory.jsonBuilder()
|
||||
.startObject()
|
||||
.startObject("properties")
|
||||
.startObject("id")
|
||||
.field("type", "integer")
|
||||
.endObject();
|
||||
createDenseVectorField(mapping, "left_vector");
|
||||
createDenseVectorField(mapping, "right_vector");
|
||||
mapping.endObject().endObject();
|
||||
Settings.Builder settingsBuilder = Settings.builder()
|
||||
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
|
||||
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
|
||||
|
||||
var CreateRequest = client.prepareCreate(indexName)
|
||||
.setSettings(Settings.builder().put("index.number_of_shards", 1))
|
||||
.setMapping(mapping)
|
||||
.setSettings(settingsBuilder.build());
|
||||
assertAcked(CreateRequest);
|
||||
}
|
||||
|
||||
private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
|
||||
mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
|
||||
mapping.endObject();
|
||||
}
|
||||
}
|
|
@ -1254,7 +1254,12 @@ public class EsqlCapabilities {
|
|||
* Forbid usage of brackets in unquoted index and enrich policy names
|
||||
* https://github.com/elastic/elasticsearch/issues/130378
|
||||
*/
|
||||
NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES;
|
||||
NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES,
|
||||
|
||||
/*
|
||||
* Cosine vector similarity function
|
||||
*/
|
||||
COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
|
||||
|
||||
private final boolean enabled;
|
||||
|
||||
|
|
|
@ -1400,15 +1400,15 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
|
|||
if (f instanceof In in) {
|
||||
return processIn(in);
|
||||
}
|
||||
if (f instanceof VectorFunction) {
|
||||
return processVectorFunction(f);
|
||||
}
|
||||
if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
|
||||
return processScalarOrGroupingFunction(f, registry);
|
||||
}
|
||||
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
|
||||
return processBinaryOperator((BinaryOperator) f);
|
||||
}
|
||||
if (f instanceof VectorFunction vectorFunction) {
|
||||
return processVectorFunction(f);
|
||||
}
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -1613,6 +1613,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
|
|||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
|
||||
List<Expression> args = vectorFunction.arguments();
|
||||
List<Expression> newArgs = new ArrayList<>();
|
||||
|
@ -1620,7 +1621,14 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
|
|||
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
|
||||
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
|
||||
if (folded instanceof List) {
|
||||
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
|
||||
// Convert to floats so blocks are created accordingly
|
||||
List<Float> floatVector;
|
||||
if (arg.dataType() == FLOAT) {
|
||||
floatVector = (List<Float>) folded;
|
||||
} else {
|
||||
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
|
||||
}
|
||||
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
|
||||
newArgs.add(denseVector);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
package org.elasticsearch.xpack.esql.expression;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
|
||||
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
|
||||
|
@ -85,7 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLik
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
|
||||
|
@ -259,9 +258,6 @@ public class ExpressionWritables {
|
|||
}
|
||||
|
||||
private static List<NamedWriteableRegistry.Entry> vector() {
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
|
||||
return List.of(Knn.ENTRY);
|
||||
}
|
||||
return List.of();
|
||||
return VectorWritables.getNamedWritables();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -180,6 +180,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.parser.ParsingException;
|
||||
import org.elasticsearch.xpack.esql.session.Configuration;
|
||||
|
@ -489,7 +490,8 @@ public class EsqlFunctionRegistry {
|
|||
def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
|
||||
def(StGeohex.class, StGeohex::new, "st_geohex"),
|
||||
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
|
||||
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
|
||||
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
|
||||
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } };
|
||||
}
|
||||
|
||||
public EsqlFunctionRegistry snapshotRegistry() {
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
|
||||
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
import org.elasticsearch.xpack.esql.expression.function.Example;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
|
||||
import org.elasticsearch.xpack.esql.expression.function.Param;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||
|
||||
public class CosineSimilarity extends VectorSimilarityFunction {
|
||||
|
||||
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
|
||||
Expression.class,
|
||||
"CosineSimilarity",
|
||||
CosineSimilarity::new
|
||||
);
|
||||
static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = COSINE::compare;
|
||||
|
||||
@FunctionInfo(
|
||||
returnType = "double",
|
||||
preview = true,
|
||||
description = "Calculates the cosine similarity between two dense_vectors.",
|
||||
examples = { @Example(file = "vector-cosine-similarity", tag = "vector-cosine-similarity") },
|
||||
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
|
||||
)
|
||||
public CosineSimilarity(
|
||||
Source source,
|
||||
@Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left,
|
||||
@Param(
|
||||
name = "right",
|
||||
type = { "dense_vector" },
|
||||
description = "second dense_vector to calculate cosine similarity"
|
||||
) Expression right
|
||||
) {
|
||||
super(source, left, right);
|
||||
}
|
||||
|
||||
private CosineSimilarity(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
|
||||
return new CosineSimilarity(source(), newLeft, newRight);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SimilarityEvaluatorFunction getSimilarityFunction() {
|
||||
return SIMILARITY_FUNCTION;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NodeInfo<? extends Expression> info() {
|
||||
return NodeInfo.create(this, CosineSimilarity::new, left(), right());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return ENTRY.name;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.compute.data.Block;
|
||||
import org.elasticsearch.compute.data.DoubleVector;
|
||||
import org.elasticsearch.compute.data.FloatBlock;
|
||||
import org.elasticsearch.compute.data.Page;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.compute.operator.EvalOperator;
|
||||
import org.elasticsearch.xpack.esql.EsqlClientException;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
|
||||
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
|
||||
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
import org.elasticsearch.xpack.esql.core.type.DataType;
|
||||
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
|
||||
|
||||
/**
|
||||
* Base class for vector similarity functions, which compute a similarity score between two dense vectors
|
||||
*/
|
||||
public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction {
|
||||
|
||||
protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
|
||||
super(source, left, right);
|
||||
}
|
||||
|
||||
protected VectorSimilarityFunction(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType dataType() {
|
||||
return DataType.DOUBLE;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TypeResolution resolveType() {
|
||||
if (childrenResolved() == false) {
|
||||
return new TypeResolution("Unresolved children");
|
||||
}
|
||||
|
||||
return checkDenseVectorParam(left(), FIRST).and(checkDenseVectorParam(right(), SECOND));
|
||||
}
|
||||
|
||||
private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
|
||||
return isNotNull(param, sourceText(), paramOrdinal).and(
|
||||
isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector")
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Functional interface for evaluating the similarity between two float arrays
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface SimilarityEvaluatorFunction {
|
||||
float calculateSimilarity(float[] leftScratch, float[] rightScratch);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object fold(FoldContext ctx) {
|
||||
return EvaluatorMapper.super.fold(source(), ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
|
||||
return new SimilarityEvaluatorFactory(
|
||||
toEvaluator.apply(left()),
|
||||
toEvaluator.apply(right()),
|
||||
getSimilarityFunction(),
|
||||
getClass().getSimpleName() + "Evaluator"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the similarity function to be used for evaluating the similarity between two vectors.
|
||||
*/
|
||||
protected abstract SimilarityEvaluatorFunction getSimilarityFunction();
|
||||
|
||||
private record SimilarityEvaluatorFactory(
|
||||
EvalOperator.ExpressionEvaluator.Factory left,
|
||||
EvalOperator.ExpressionEvaluator.Factory right,
|
||||
SimilarityEvaluatorFunction similarityFunction,
|
||||
String evaluatorName
|
||||
) implements EvalOperator.ExpressionEvaluator.Factory {
|
||||
|
||||
@Override
|
||||
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
|
||||
// TODO check whether to use this custom evaluator or reuse / define an existing one
|
||||
return new EvalOperator.ExpressionEvaluator() {
|
||||
@Override
|
||||
public Block eval(Page page) {
|
||||
try (
|
||||
FloatBlock leftBlock = (FloatBlock) left.get(context).eval(page);
|
||||
FloatBlock rightBlock = (FloatBlock) right.get(context).eval(page)
|
||||
) {
|
||||
int positionCount = page.getPositionCount();
|
||||
int dimensions = 0;
|
||||
// Get the first non-empty vector to calculate the dimension
|
||||
for (int p = 0; p < positionCount; p++) {
|
||||
if (leftBlock.getValueCount(p) != 0) {
|
||||
dimensions = leftBlock.getValueCount(p);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (dimensions == 0) {
|
||||
return context.blockFactory().newConstantFloatBlockWith(0F, 0);
|
||||
}
|
||||
|
||||
float[] leftScratch = new float[dimensions];
|
||||
float[] rightScratch = new float[dimensions];
|
||||
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
|
||||
for (int p = 0; p < positionCount; p++) {
|
||||
int dimsLeft = leftBlock.getValueCount(p);
|
||||
int dimsRight = rightBlock.getValueCount(p);
|
||||
|
||||
if (dimsLeft == 0 || dimsRight == 0) {
|
||||
// A null value on the left or right vector. Similarity is 0
|
||||
builder.appendDouble(0.0);
|
||||
continue;
|
||||
} else if (dimsLeft != dimsRight) {
|
||||
throw new EsqlClientException(
|
||||
"Vectors must have the same dimensions; first vector has {}, and second has {}",
|
||||
dimsLeft,
|
||||
dimsRight
|
||||
);
|
||||
}
|
||||
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
|
||||
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
|
||||
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
|
||||
builder.appendDouble(result);
|
||||
}
|
||||
return builder.build().asBlock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
};
|
||||
}
|
||||
|
||||
private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
scratch[i] = block.getFloat(position + i);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Defines the named writables for vector functions in ESQL.
|
||||
*/
|
||||
public final class VectorWritables {
|
||||
|
||||
private VectorWritables() {
|
||||
// Utility class
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
|
||||
entries.add(Knn.ENTRY);
|
||||
}
|
||||
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
||||
entries.add(CosineSimilarity.ENTRY);
|
||||
}
|
||||
|
||||
return Collections.unmodifiableList(entries);
|
||||
}
|
||||
}
|
|
@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
|
||||
|
@ -92,6 +93,7 @@ import java.io.IOException;
|
|||
import java.time.Period;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
|
@ -123,6 +125,7 @@ import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
|
|||
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
|
||||
|
@ -2337,7 +2340,7 @@ public class AnalyzerTests extends ESTestCase {
|
|||
assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
|
||||
}
|
||||
|
||||
public void testDenseVectorImplicitCasting() {
|
||||
public void testDenseVectorImplicitCastingKnn() {
|
||||
assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
|
||||
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
|
||||
|
||||
|
@ -2351,7 +2354,46 @@ public class AnalyzerTests extends ESTestCase {
|
|||
var field = knn.field();
|
||||
var queryVector = as(knn.query(), Literal.class);
|
||||
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
|
||||
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
|
||||
assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f)));
|
||||
}
|
||||
|
||||
public void testDenseVectorImplicitCastingSimilarityFunctions() {
|
||||
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
||||
checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
|
||||
checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
|
||||
}
|
||||
}
|
||||
|
||||
private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
|
||||
var plan = analyze(String.format(Locale.ROOT, """
|
||||
from test | eval similarity = %s
|
||||
""", similarityFunction), "mapping-dense_vector.json");
|
||||
|
||||
var limit = as(plan, Limit.class);
|
||||
var eval = as(limit.child(), Eval.class);
|
||||
var alias = as(eval.fields().get(0), Alias.class);
|
||||
assertEquals("similarity", alias.name());
|
||||
var similarity = as(alias.child(), VectorSimilarityFunction.class);
|
||||
var left = as(similarity.left(), FieldAttribute.class);
|
||||
assertEquals("vector", left.name());
|
||||
var right = as(similarity.right(), Literal.class);
|
||||
assertThat(right.dataType(), is(DENSE_VECTOR));
|
||||
assertThat(right.value(), equalTo(expectedElems));
|
||||
}
|
||||
|
||||
public void testNoDenseVectorFailsSimilarityFunction() {
|
||||
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
||||
checkNoDenseVectorFailsSimilarityFunction("v_cosine([0, 1, 2], 0.342)");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {
|
||||
var query = String.format(Locale.ROOT, "row a = 1 | eval similarity = %s", similarityFunction);
|
||||
VerificationException error = expectThrows(VerificationException.class, () -> analyze(query));
|
||||
assertThat(
|
||||
error.getMessage(),
|
||||
containsString("second argument of [" + similarityFunction + "] must be" + " [dense_vector], found value [0.342] type [double]")
|
||||
);
|
||||
}
|
||||
|
||||
public void testRateRequiresCounterTypes() {
|
||||
|
|
|
@ -2300,6 +2300,20 @@ public class VerifierTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
public void testVectorSimilarityFunctionsNullArgs() throws Exception {
|
||||
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
||||
checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first");
|
||||
checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {
|
||||
assertThat(
|
||||
error("from test | eval similarity = " + functionInvocation, fullTextAnalyzer),
|
||||
containsString(argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]")
|
||||
);
|
||||
}
|
||||
|
||||
private void query(String query) {
|
||||
query(query, defaultAnalyzer);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
|
||||
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
|
||||
import org.hamcrest.Matcher;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase {
|
||||
|
||||
protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
|
||||
this.testCase = testCaseSupplier.get();
|
||||
}
|
||||
|
||||
@Before
|
||||
public void checkCapability() {
|
||||
assumeTrue("Similarity function is not enabled", capability().isEnabled());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the capability of the vector similarity function to check
|
||||
*/
|
||||
protected abstract EsqlCapabilities.Cap capability();
|
||||
|
||||
protected static Iterable<Object[]> similarityParameters(
|
||||
String className,
|
||||
VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction
|
||||
) {
|
||||
|
||||
final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]";
|
||||
|
||||
List<TestCaseSupplier> suppliers = new ArrayList<>();
|
||||
|
||||
// Basic test with two dense vectors
|
||||
suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DENSE_VECTOR), () -> {
|
||||
int dimensions = between(64, 128);
|
||||
List<Float> left = randomDenseVector(dimensions);
|
||||
List<Float> right = randomDenseVector(dimensions);
|
||||
float[] leftArray = listToFloatArray(left);
|
||||
float[] rightArray = listToFloatArray(right);
|
||||
double expected = similarityFunction.calculateSimilarity(leftArray, rightArray);
|
||||
return new TestCaseSupplier.TestCase(
|
||||
List.of(
|
||||
new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"),
|
||||
new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2")
|
||||
),
|
||||
evaluatorName,
|
||||
DOUBLE,
|
||||
equalTo(expected) // Random vectors should have cosine similarity close to 0
|
||||
);
|
||||
}));
|
||||
|
||||
return parameterSuppliersFromTypedData(suppliers);
|
||||
}
|
||||
|
||||
private static float[] listToFloatArray(List<Float> floatList) {
|
||||
float[] floatArray = new float[floatList.size()];
|
||||
for (int i = 0; i < floatList.size(); i++) {
|
||||
floatArray[i] = floatList.get(i);
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
|
||||
protected double calculateSimilarity(List<Float> left, List<Float> right) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A random dense vector for testing
|
||||
* @param dimensions
|
||||
*/
|
||||
private static List<Float> randomDenseVector(int dimensions) {
|
||||
List<Float> vector = new ArrayList<>();
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
vector.add(randomFloat());
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Matcher<Object> allNullsMatcher() {
|
||||
// A null value on the left or right vector. Similarity is 0
|
||||
return equalTo(0.0);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionName;
|
||||
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
@FunctionName("v_cosine")
|
||||
public class CosineSimilarityTests extends AbstractVectorSimilarityFunctionTestCase {
|
||||
|
||||
public CosineSimilarityTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
|
||||
super(testCaseSupplier);
|
||||
}
|
||||
|
||||
@ParametersFactory
|
||||
public static Iterable<Object[]> parameters() {
|
||||
return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION);
|
||||
}
|
||||
|
||||
protected EsqlCapabilities.Cap capability() {
|
||||
return EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression build(Source source, List<Expression> args) {
|
||||
return new CosineSimilarity(source, args.get(0), args.get(1));
|
||||
}
|
||||
}
|
|
@ -41,6 +41,7 @@ setup:
|
|||
- sum_over_time
|
||||
- count_over_time
|
||||
- distinct_over_time
|
||||
- cosine_vector_similarity_function
|
||||
reason: "Test that should only be executed on snapshot versions"
|
||||
|
||||
- do: {xpack.usage: {}}
|
||||
|
@ -130,7 +131,7 @@ setup:
|
|||
- match: {esql.functions.coalesce: $functions_coalesce}
|
||||
- gt: {esql.functions.categorize: $functions_categorize}
|
||||
# Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation.
|
||||
- length: {esql.functions: 156} # check the "sister" test below for a likely update to the same esql.functions length check
|
||||
- length: {esql.functions: 157} # check the "sister" test below for a likely update to the same esql.functions length check
|
||||
|
||||
---
|
||||
"Basic ESQL usage output (telemetry) non-snapshot version":
|
||||
|
|
Loading…
Reference in New Issue