diff --git a/docs/changelog/124581.yaml b/docs/changelog/124581.yaml new file mode 100644 index 000000000000..cc978981b1ef --- /dev/null +++ b/docs/changelog/124581.yaml @@ -0,0 +1,5 @@ +pr: 124581 +summary: New `vector_rescore` parameter as a quantized index type option +area: Vector Search +type: enhancement +issues: [] diff --git a/docs/reference/elasticsearch/mapping-reference/dense-vector.md b/docs/reference/elasticsearch/mapping-reference/dense-vector.md index 13e8706b3c12..ca7124679927 100644 --- a/docs/reference/elasticsearch/mapping-reference/dense-vector.md +++ b/docs/reference/elasticsearch/mapping-reference/dense-vector.md @@ -287,6 +287,14 @@ $$$dense-vector-index-options$$$ `confidence_interval` : (Optional, float) Only applicable to `int8_hnsw`, `int4_hnsw`, `int8_flat`, and `int4_flat` index types. The confidence interval to use when quantizing the vectors. Can be any value between and including `0.90` and `1.0` or exactly `0`. When the value is `0`, this indicates that dynamic quantiles should be calculated for optimized quantization. When between `0.90` and `1.0`, this value restricts the values used when calculating the quantization thresholds. For example, a value of `0.95` will only use the middle 95% of the values when calculating the quantization thresholds (e.g. the highest and lowest 2.5% of values will be ignored). Defaults to `1/(dims + 1)` for `int8` quantized vectors and `0` for `int4` for dynamic quantile calculation. + `rescore_vector` + : (Optional, object) Functionality in [preview]. An optional section that configures automatic vector rescoring on knn queries for the given field. Only applicable to quantized index types. + :::::{dropdown} Properties of `rescore_vector` + `oversample` + : (required, float) The amount to oversample the search results by. This value should be greater than `1.0` and less than `10.0`. The higher the value, the more vectors will be gathered and rescored with the raw values per shard. + : In case a knn query specifies a `rescore_vector` parameter, the query `rescore_vector` parameter will be used instead. + : See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details. + ::::: :::: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index abde3e86dd05..97fd51e2e4aa 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -244,3 +244,93 @@ setup: index: dynamic_dim_bbq_hnsw body: vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index 229d705bc317..fb45521cb47c 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -611,3 +611,92 @@ setup: - match: { hits.hits.0._id: "1"} - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int8_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int8_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int8_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 9b27aea4b1db..1faaba2be11a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -642,3 +642,92 @@ setup: index: dynamic_dim_hnsw_quantized body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int4_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int4_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int4_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 2541de7023bf..4145eabdbcb3 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -249,3 +249,93 @@ setup: index: dynamic_dim_bbq_flat body: vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index 0a2da6e14a6a..64bf85344980 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -405,3 +405,93 @@ setup: - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int4_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int4_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int4_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 6b59b8f641ee..1087b5b264cf 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -346,3 +346,93 @@ setup: index: true index_options: type: int8_flat +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: int8_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: int8_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: int8_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 69c9d050d592..8ab1277cd15a 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -151,6 +151,7 @@ public class IndexVersions { public static final IndexVersion TIME_SERIES_ID_DOC_VALUES_SPARSE_INDEX = def(9_012_0_00, Version.LUCENE_10_1_0); public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_KEYWORD = def(9_013_0_00, Version.LUCENE_10_1_0); public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_IP = def(9_014_0_00, Version.LUCENE_10_1_0); + public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = def(9_015_0_00, Version.LUCENE_10_1_0); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index b21018f7dfb5..fcd54590bac4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -14,6 +14,8 @@ import org.elasticsearch.features.NodeFeature; import java.util.Set; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING; + /** * Spec for mapper-related features. */ @@ -58,7 +60,8 @@ public class MapperFeatures implements FeatureSpecification { SourceFieldMapper.SYNTHETIC_RECOVERY_SOURCE, ObjectMapper.SUBOBJECTS_FALSE_MAPPING_UPDATE_FIX, UKNOWN_FIELD_MAPPING_UPDATE_ERROR_MESSAGE, - DOC_VALUES_SKIPPER + DOC_VALUES_SKIPPER, + RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 54cae36dd164..3e85ef79d2e5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -38,6 +38,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat; @@ -73,6 +74,7 @@ import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser.Token; @@ -112,6 +114,9 @@ public class DenseVectorFieldMapper extends FieldMapper { public static final IndexVersion NORMALIZE_COSINE = IndexVersions.NORMALIZED_VECTOR_COSINE; public static final IndexVersion DEFAULT_TO_INT8 = DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersions.V_8_9_0; + public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS; + + public static final NodeFeature RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING = new NodeFeature("mapper.dense_vector.rescore_vector"); public static final String CONTENT_TYPE = "dense_vector"; public static final short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions @@ -210,10 +215,11 @@ public class DenseVectorFieldMapper extends FieldMapper { ? new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + null, null ) : null, - (n, c, o) -> o == null ? null : parseIndexOptions(n, o), + (n, c, o) -> o == null ? null : parseIndexOptions(n, o, indexVersionCreated), m -> toType(m).indexOptions, (b, n, v) -> { if (v != null) { @@ -1258,10 +1264,19 @@ public class DenseVectorFieldMapper extends FieldMapper { } } + abstract static class QuantizedIndexOptions extends IndexOptions { + final RescoreVector rescoreVector; + + QuantizedIndexOptions(VectorIndexType type, RescoreVector rescoreVector) { + super(type); + this.rescoreVector = rescoreVector; + } + } + public enum VectorIndexType { HNSW("hnsw", false) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1288,7 +1303,7 @@ public class DenseVectorFieldMapper extends FieldMapper { }, INT8_HNSW("int8_hnsw", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1304,8 +1319,12 @@ public class DenseVectorFieldMapper extends FieldMapper { if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1319,7 +1338,7 @@ public class DenseVectorFieldMapper extends FieldMapper { } }, INT4_HNSW("int4_hnsw", true) { - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1335,8 +1354,12 @@ public class DenseVectorFieldMapper extends FieldMapper { if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1351,7 +1374,7 @@ public class DenseVectorFieldMapper extends FieldMapper { }, FLAT("flat", false) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); return new FlatIndexOptions(); } @@ -1368,14 +1391,18 @@ public class DenseVectorFieldMapper extends FieldMapper { }, INT8_FLAT("int8_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8FlatIndexOptions(confidenceInterval); + return new Int8FlatIndexOptions(confidenceInterval, rescoreVector); } @Override @@ -1390,14 +1417,18 @@ public class DenseVectorFieldMapper extends FieldMapper { }, INT4_FLAT("int4_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4FlatIndexOptions(confidenceInterval); + return new Int4FlatIndexOptions(confidenceInterval, rescoreVector); } @Override @@ -1412,7 +1443,7 @@ public class DenseVectorFieldMapper extends FieldMapper { }, BBQ_HNSW("bbq_hnsw", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); if (mNode == null) { @@ -1423,8 +1454,12 @@ public class DenseVectorFieldMapper extends FieldMapper { } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQHnswIndexOptions(m, efConstruction); + return new BBQHnswIndexOptions(m, efConstruction, rescoreVector); } @Override @@ -1439,9 +1474,13 @@ public class DenseVectorFieldMapper extends FieldMapper { }, BBQ_FLAT("bbq_flat", true) { @Override - public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { + public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { + RescoreVector rescoreVector = null; + if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { + rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); + } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQFlatIndexOptions(); + return new BBQFlatIndexOptions(rescoreVector); } @Override @@ -1467,7 +1506,7 @@ public class DenseVectorFieldMapper extends FieldMapper { this.quantized = quantized; } - abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap); + abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion); public abstract boolean supportsElementType(ElementType elementType); @@ -1483,11 +1522,11 @@ public class DenseVectorFieldMapper extends FieldMapper { } } - static class Int8FlatIndexOptions extends IndexOptions { + static class Int8FlatIndexOptions extends QuantizedIndexOptions { private final Float confidenceInterval; - Int8FlatIndexOptions(Float confidenceInterval) { - super(VectorIndexType.INT8_FLAT); + Int8FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_FLAT, rescoreVector); this.confidenceInterval = confidenceInterval; } @@ -1498,6 +1537,9 @@ public class DenseVectorFieldMapper extends FieldMapper { if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1511,12 +1553,12 @@ public class DenseVectorFieldMapper extends FieldMapper { @Override boolean doEquals(IndexOptions o) { Int8FlatIndexOptions that = (Int8FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval); + return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(confidenceInterval); + return Objects.hash(confidenceInterval, rescoreVector); } @Override @@ -1567,13 +1609,13 @@ public class DenseVectorFieldMapper extends FieldMapper { } } - static class Int4HnswIndexOptions extends IndexOptions { + static class Int4HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final float confidenceInterval; - Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT4_HNSW); + Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is @@ -1594,6 +1636,9 @@ public class DenseVectorFieldMapper extends FieldMapper { builder.field("m", m); builder.field("ef_construction", efConstruction); builder.field("confidence_interval", confidenceInterval); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1601,12 +1646,15 @@ public class DenseVectorFieldMapper extends FieldMapper { @Override public boolean doEquals(IndexOptions o) { Int4HnswIndexOptions that = (Int4HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return m == that.m + && efConstruction == that.efConstruction + && Objects.equals(confidenceInterval, that.confidenceInterval) + && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1619,6 +1667,8 @@ public class DenseVectorFieldMapper extends FieldMapper { + efConstruction + ", confidence_interval=" + confidenceInterval + + ", rescore_vector=" + + (rescoreVector == null ? "none" : rescoreVector) + "}"; } @@ -1635,11 +1685,11 @@ public class DenseVectorFieldMapper extends FieldMapper { } } - static class Int4FlatIndexOptions extends IndexOptions { + static class Int4FlatIndexOptions extends QuantizedIndexOptions { private final float confidenceInterval; - Int4FlatIndexOptions(Float confidenceInterval) { - super(VectorIndexType.INT4_FLAT); + Int4FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_FLAT, rescoreVector); // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is // effectively required for int4 to behave well across a wide range of data. this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; @@ -1656,6 +1706,9 @@ public class DenseVectorFieldMapper extends FieldMapper { builder.startObject(); builder.field("type", type); builder.field("confidence_interval", confidenceInterval); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1665,17 +1718,17 @@ public class DenseVectorFieldMapper extends FieldMapper { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int4FlatIndexOptions that = (Int4FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval); + return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(confidenceInterval); + return Objects.hash(confidenceInterval, rescoreVector); } @Override public String toString() { - return "{type=" + type + ", confidence_interval=" + confidenceInterval + "}"; + return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; } @Override @@ -1689,13 +1742,13 @@ public class DenseVectorFieldMapper extends FieldMapper { } - static class Int8HnswIndexOptions extends IndexOptions { + static class Int8HnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; private final Float confidenceInterval; - Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT8_HNSW); + Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; this.confidenceInterval = confidenceInterval; @@ -1716,6 +1769,9 @@ public class DenseVectorFieldMapper extends FieldMapper { if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1725,12 +1781,15 @@ public class DenseVectorFieldMapper extends FieldMapper { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int8HnswIndexOptions that = (Int8HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return m == that.m + && efConstruction == that.efConstruction + && Objects.equals(confidenceInterval, that.confidenceInterval) + && Objects.equals(rescoreVector, that.rescoreVector); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } @Override @@ -1743,6 +1802,8 @@ public class DenseVectorFieldMapper extends FieldMapper { + efConstruction + ", confidence_interval=" + confidenceInterval + + ", rescore_vector=" + + (rescoreVector == null ? "none" : rescoreVector) + "}"; } @@ -1824,12 +1885,12 @@ public class DenseVectorFieldMapper extends FieldMapper { } } - static class BBQHnswIndexOptions extends IndexOptions { + static class BBQHnswIndexOptions extends QuantizedIndexOptions { private final int m; private final int efConstruction; - BBQHnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.BBQ_HNSW); + BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_HNSW, rescoreVector); this.m = m; this.efConstruction = efConstruction; } @@ -1848,12 +1909,12 @@ public class DenseVectorFieldMapper extends FieldMapper { @Override boolean doEquals(IndexOptions other) { BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction; + return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(m, efConstruction); + return Objects.hash(m, efConstruction, rescoreVector); } @Override @@ -1862,6 +1923,9 @@ public class DenseVectorFieldMapper extends FieldMapper { builder.field("type", type); builder.field("m", m); builder.field("ef_construction", efConstruction); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1875,11 +1939,11 @@ public class DenseVectorFieldMapper extends FieldMapper { } } - static class BBQFlatIndexOptions extends IndexOptions { + static class BBQFlatIndexOptions extends QuantizedIndexOptions { private final int CLASS_NAME_HASH = this.getClass().getName().hashCode(); - BBQFlatIndexOptions() { - super(VectorIndexType.BBQ_FLAT); + BBQFlatIndexOptions(RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_FLAT, rescoreVector); } @Override @@ -1907,6 +1971,9 @@ public class DenseVectorFieldMapper extends FieldMapper { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("type", type); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -1920,6 +1987,41 @@ public class DenseVectorFieldMapper extends FieldMapper { } } + record RescoreVector(float oversample) implements ToXContentObject { + static final String NAME = "rescore_vector"; + static final String OVERSAMPLE = "oversample"; + + static RescoreVector fromIndexOptions(Map indexOptionsMap) { + Object rescoreVectorNode = indexOptionsMap.remove(NAME); + if (rescoreVectorNode == null) { + return null; + } + Map mappedNode = XContentMapValues.nodeMapValue(rescoreVectorNode, NAME); + Object oversampleNode = mappedNode.get(OVERSAMPLE); + if (oversampleNode == null) { + throw new IllegalArgumentException("Invalid rescore_vector value. Missing required field " + OVERSAMPLE); + } + return new RescoreVector((float) XContentMapValues.nodeDoubleValue(oversampleNode)); + } + + RescoreVector { + if (oversample < 1) { + throw new IllegalArgumentException("oversample must be greater than 1"); + } + if (oversample > 10) { + throw new IllegalArgumentException("oversample must be less than or equal to 10"); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field(OVERSAMPLE, oversample); + builder.endObject(); + return builder; + } + } + public static final TypeParser PARSER = new TypeParser( (n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE) @@ -2135,7 +2237,7 @@ public class DenseVectorFieldMapper extends FieldMapper { float[] queryVector, int k, int numCands, - Float oversample, + Float queryOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2157,6 +2259,14 @@ public class DenseVectorFieldMapper extends FieldMapper { } int adjustedK = k; + // By default utilize the quantized oversample is configured + // allow the user provided at query time overwrite + Float oversample = queryOversample; + if (oversample == null + && indexOptions instanceof QuantizedIndexOptions quantizedIndexOptions + && quantizedIndexOptions.rescoreVector != null) { + oversample = quantizedIndexOptions.rescoreVector.oversample; + } boolean rescore = needsRescore(oversample); if (rescore) { // Will get k * oversample for rescoring, and get the top k @@ -2352,7 +2462,7 @@ public class DenseVectorFieldMapper extends FieldMapper { return new Builder(leafName(), indexCreatedVersion).init(this); } - private static IndexOptions parseIndexOptions(String fieldName, Object propNode) { + private static IndexOptions parseIndexOptions(String fieldName, Object propNode, IndexVersion indexVersion) { @SuppressWarnings("unchecked") Map indexOptionsMap = (Map) propNode; Object typeNode = indexOptionsMap.remove("type"); @@ -2365,7 +2475,7 @@ public class DenseVectorFieldMapper extends FieldMapper { throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]"); } VectorIndexType parsedType = vectorIndexType.get(); - return parsedType.parseIndexOptions(fieldName, indexOptionsMap); + return parsedType.parseIndexOptions(fieldName, indexOptionsMap, indexVersion); } /** diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 3f574a29469c..036d2e62c8f4 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -59,6 +59,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Set; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; @@ -883,6 +884,120 @@ public class DenseVectorFieldMapperTests extends MapperTestCase { @Override public void testAggregatableConsistency() {} + public void testRescoreVectorForNonQuantized() { + for (String indexType : List.of("hnsw", "flat")) { + Exception e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 1.5f)) + .endObject() + ) + ) + ); + e.getMessage().contains("Mapping definition for [field] has unsupported parameters:"); + } + } + + public void tesetRescoreVectorOldIndexVersion() { + IndexVersion incompatibleVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersionUtils.getLowestReadCompatibleVersion(), + IndexVersionUtils.getPreviousVersion(DenseVectorFieldMapper.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) + ); + for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { + expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + incompatibleVersion, + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 1.5f)) + .endObject() + ) + ) + ); + } + } + + public void testInvalidRescoreVector() { + for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { + Exception e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("foo", 1.5f)) + .endObject() + ) + ) + ); + e.getMessage().contains("Invalid rescore_vector value. Missing required field oversample"); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", "foo")) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 0.1f)) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of()) + .endObject() + ) + ) + ); + e = expectThrows( + MapperParsingException.class, + () -> createDocumentMapper( + fieldMapping( + b -> b.field("type", "dense_vector") + .field("index", true) + .startObject("index_options") + .field("type", indexType) + .field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 10.1f)) + .endObject() + ) + ) + ); + } + } + public void testDims() { { Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5c067cb2d0a2..e98038b7a075 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -49,6 +49,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { this.indexed = randomBoolean(); } + private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { + return new DenseVectorFieldMapper.RescoreVector(randomFloatBetween(1.0F, 10.0F, false)); + } + private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() { return randomFrom( new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), @@ -62,18 +66,30 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.FlatIndexOptions(), - new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), - new DenseVectorFieldMapper.BBQFlatIndexOptions() + new DenseVectorFieldMapper.Int8FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.Int4FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) ); } @@ -82,14 +98,20 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ) ); } @@ -195,6 +217,9 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { queryVector[i] = randomFloat(); } Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); + if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { + query = rescoreKnnVectorQuery.innerQuery(); + } assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -346,6 +371,9 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { queryVector[i] = randomFloat(); } Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); + if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { + query = rescoreKnnVectorQuery.innerQuery(); + } assertThat(query, instanceOf(KnnFloatVectorQuery.class)); }