mirror of https://mirror.osredm.com/root/redis.git
Fix possible crash with random projection
This commit is contained in:
parent
f330d6175a
commit
31bc07955c
|
@ -0,0 +1,69 @@
|
|||
from test import TestCase, generate_random_vector
|
||||
import struct
|
||||
import redis.exceptions
|
||||
|
||||
class DimensionValidation(TestCase):
|
||||
def getname(self):
|
||||
return "Dimension Validation with Projection"
|
||||
|
||||
def estimated_runtime(self):
|
||||
return 0.5
|
||||
|
||||
def test(self):
|
||||
# Test scenario 1: Create a set with projection
|
||||
original_dim = 100
|
||||
reduced_dim = 50
|
||||
|
||||
# Create the initial vector and set with projection
|
||||
vec1 = generate_random_vector(original_dim)
|
||||
vec1_bytes = struct.pack(f'{original_dim}f', *vec1)
|
||||
|
||||
# Add first vector with projection
|
||||
result = self.redis.execute_command('VADD', self.test_key,
|
||||
'REDUCE', reduced_dim,
|
||||
'FP32', vec1_bytes, f'{self.test_key}:item:1')
|
||||
assert result == 1, "First VADD with REDUCE should return 1"
|
||||
|
||||
# Check VINFO returns the correct projection information
|
||||
info = self.redis.execute_command('VINFO', self.test_key)
|
||||
assert isinstance(info, dict), "VINFO should return a dictionary"
|
||||
assert 'vector-dim' in info, "VINFO should contain vector-dim"
|
||||
assert info['vector-dim'] == reduced_dim, f"Expected reduced dimension {reduced_dim}, got {info['vector-dim']}"
|
||||
assert 'proj-input-dim' in info, "VINFO should contain proj-input-dim"
|
||||
assert info['proj-input-dim'] == original_dim, f"Expected original dimension {original_dim}, got {info['proj-input-dim']}"
|
||||
assert 'proj-enabled' in info, "VINFO should contain proj-enabled"
|
||||
assert info['proj-enabled'] is True, "Projection should be enabled"
|
||||
|
||||
# Test scenario 2: Try adding a mismatched vector - should fail
|
||||
wrong_dim = 80
|
||||
wrong_vec = generate_random_vector(wrong_dim)
|
||||
wrong_vec_bytes = struct.pack(f'{wrong_dim}f', *wrong_vec)
|
||||
|
||||
# This should fail with dimension mismatch error
|
||||
try:
|
||||
self.redis.execute_command('VADD', self.test_key,
|
||||
'REDUCE', reduced_dim,
|
||||
'FP32', wrong_vec_bytes, f'{self.test_key}:item:2')
|
||||
assert False, "VADD with wrong dimension should fail"
|
||||
except redis.exceptions.ResponseError as e:
|
||||
assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error, got: {e}"
|
||||
|
||||
# Test scenario 3: Add a correctly-sized vector
|
||||
vec2 = generate_random_vector(original_dim)
|
||||
vec2_bytes = struct.pack(f'{original_dim}f', *vec2)
|
||||
|
||||
# This should succeed
|
||||
result = self.redis.execute_command('VADD', self.test_key,
|
||||
'REDUCE', reduced_dim,
|
||||
'FP32', vec2_bytes, f'{self.test_key}:item:3')
|
||||
assert result == 1, "VADD with correct dimensions should succeed"
|
||||
|
||||
# Check VSIM also validates input dimensions
|
||||
wrong_query = generate_random_vector(wrong_dim)
|
||||
try:
|
||||
self.redis.execute_command('VSIM', self.test_key,
|
||||
'VALUES', wrong_dim, *[str(x) for x in wrong_query],
|
||||
'COUNT', 10)
|
||||
assert False, "VSIM with wrong dimension should fail"
|
||||
except redis.exceptions.ResponseError as e:
|
||||
assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error in VSIM, got: {e}"
|
36
vset.c
36
vset.c
|
@ -531,10 +531,19 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
|
||||
/* Apply projection if needed */
|
||||
if (vset->proj_matrix) {
|
||||
/* Ensure input dimension matches the projection matrix's expected input dimension */
|
||||
if (dim != vset->proj_input_size) {
|
||||
RedisModule_Free(vec);
|
||||
return RedisModule_ReplyWithErrorFormat(ctx,
|
||||
"ERR Input dimension mismatch for projection - got %d but projection expects %d",
|
||||
(int)dim, (int)vset->proj_input_size);
|
||||
}
|
||||
|
||||
float *projected = applyProjection(vec, vset->proj_matrix,
|
||||
vset->proj_input_size, dim);
|
||||
vset->proj_input_size, dim);
|
||||
RedisModule_Free(vec);
|
||||
vec = projected;
|
||||
dim = vset->hnsw->vector_dim;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -764,6 +773,14 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
/* Apply projection if the set uses it, with the exception
|
||||
* of ELE type, that will already have the right dimension. */
|
||||
if (vset->proj_matrix && dim != vset->hnsw->vector_dim) {
|
||||
/* Ensure input dimension matches the projection matrix's expected input dimension */
|
||||
if (dim != vset->proj_input_size) {
|
||||
RedisModule_Free(vec);
|
||||
return RedisModule_ReplyWithErrorFormat(ctx,
|
||||
"ERR Input dimension mismatch for projection - got %d but projection expects %d",
|
||||
(int)dim, (int)vset->proj_input_size);
|
||||
}
|
||||
|
||||
float *projected = applyProjection(vec, vset->proj_matrix,
|
||||
vset->proj_input_size, dim);
|
||||
RedisModule_Free(vec);
|
||||
|
@ -1251,8 +1268,12 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
|
|||
|
||||
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
|
||||
|
||||
/* Calculate map size based on projection presence */
|
||||
int map_size = 8; /* Base fields */
|
||||
if (vset->proj_matrix) map_size += 2; /* Additional fields for projection */
|
||||
|
||||
/* Reply with hash */
|
||||
RedisModule_ReplyWithMap(ctx, 8);
|
||||
RedisModule_ReplyWithMap(ctx, map_size);
|
||||
|
||||
/* Quantization type */
|
||||
RedisModule_ReplyWithSimpleString(ctx, "quant-type");
|
||||
|
@ -1266,6 +1287,17 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
|
|||
RedisModule_ReplyWithSimpleString(ctx, "vector-dim");
|
||||
RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
|
||||
|
||||
/* Add projection information if present */
|
||||
if (vset->proj_matrix) {
|
||||
/* Original input dimension before projection */
|
||||
RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim");
|
||||
RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size);
|
||||
|
||||
/* Projection enabled flag */
|
||||
RedisModule_ReplyWithSimpleString(ctx, "projection-enabled");
|
||||
RedisModule_ReplyWithBool(ctx, 1);
|
||||
}
|
||||
|
||||
/* Number of elements. */
|
||||
RedisModule_ReplyWithSimpleString(ctx, "size");
|
||||
RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);
|
||||
|
|
Loading…
Reference in New Issue