Fix possible crash with random projection

This commit is contained in:
Rowan Trollope 2025-03-22 09:11:20 -07:00
parent f330d6175a
commit 31bc07955c
2 changed files with 103 additions and 2 deletions

View File

@ -0,0 +1,69 @@
from test import TestCase, generate_random_vector
import struct
import redis.exceptions
class DimensionValidation(TestCase):
    """Verify input-dimension validation on a vector set created with REDUCE.

    The test creates a set with a random projection (100 -> 50 dimensions),
    checks the projection metadata reported by VINFO, and then confirms that
    both VADD and VSIM reject vectors whose dimension does not match the
    projection's expected input dimension.
    """

    def getname(self):
        """Return the human-readable name of this test."""
        return "Dimension Validation with Projection"

    def estimated_runtime(self):
        """Return the runtime estimate for this test (presumably in seconds)."""
        return 0.5

    def test(self):
        """Run the projection / dimension-mismatch scenarios."""
        # Test scenario 1: Create a set with projection
        original_dim = 100
        reduced_dim = 50

        # Create the initial vector and set with projection
        vec1 = generate_random_vector(original_dim)
        vec1_bytes = struct.pack(f'{original_dim}f', *vec1)

        # Add first vector with projection
        result = self.redis.execute_command('VADD', self.test_key,
                                            'REDUCE', reduced_dim,
                                            'FP32', vec1_bytes, f'{self.test_key}:item:1')
        assert result == 1, "First VADD with REDUCE should return 1"

        # Check VINFO returns the correct projection information.
        # The field names must match what VINFO actually emits:
        # "projection-input-dim" and "projection-enabled".
        info = self.redis.execute_command('VINFO', self.test_key)
        assert isinstance(info, dict), "VINFO should return a dictionary"
        assert 'vector-dim' in info, "VINFO should contain vector-dim"
        assert info['vector-dim'] == reduced_dim, f"Expected reduced dimension {reduced_dim}, got {info['vector-dim']}"
        assert 'projection-input-dim' in info, "VINFO should contain projection-input-dim"
        assert info['projection-input-dim'] == original_dim, f"Expected original dimension {original_dim}, got {info['projection-input-dim']}"
        assert 'projection-enabled' in info, "VINFO should contain projection-enabled"
        assert info['projection-enabled'] is True, "Projection should be enabled"

        # Test scenario 2: Try adding a mismatched vector - should fail
        wrong_dim = 80
        wrong_vec = generate_random_vector(wrong_dim)
        wrong_vec_bytes = struct.pack(f'{wrong_dim}f', *wrong_vec)

        # This should fail with dimension mismatch error. The AssertionError
        # raised by "assert False" is not a ResponseError, so it propagates
        # and fails the test if the server wrongly accepts the vector.
        try:
            self.redis.execute_command('VADD', self.test_key,
                                       'REDUCE', reduced_dim,
                                       'FP32', wrong_vec_bytes, f'{self.test_key}:item:2')
            assert False, "VADD with wrong dimension should fail"
        except redis.exceptions.ResponseError as e:
            assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error, got: {e}"

        # Test scenario 3: Add a correctly-sized vector
        vec2 = generate_random_vector(original_dim)
        vec2_bytes = struct.pack(f'{original_dim}f', *vec2)

        # This should succeed
        result = self.redis.execute_command('VADD', self.test_key,
                                            'REDUCE', reduced_dim,
                                            'FP32', vec2_bytes, f'{self.test_key}:item:3')
        assert result == 1, "VADD with correct dimensions should succeed"

        # Check VSIM also validates input dimensions
        wrong_query = generate_random_vector(wrong_dim)
        try:
            self.redis.execute_command('VSIM', self.test_key,
                                       'VALUES', wrong_dim, *[str(x) for x in wrong_query],
                                       'COUNT', 10)
            assert False, "VSIM with wrong dimension should fail"
        except redis.exceptions.ResponseError as e:
            assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error in VSIM, got: {e}"

36
vset.c
View File

@ -531,10 +531,19 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
/* Apply projection if needed */
if (vset->proj_matrix) {
/* Ensure input dimension matches the projection matrix's expected input dimension */
if (dim != vset->proj_input_size) {
RedisModule_Free(vec);
return RedisModule_ReplyWithErrorFormat(ctx,
"ERR Input dimension mismatch for projection - got %d but projection expects %d",
(int)dim, (int)vset->proj_input_size);
}
float *projected = applyProjection(vec, vset->proj_matrix,
vset->proj_input_size, dim);
vset->proj_input_size, dim);
RedisModule_Free(vec);
vec = projected;
dim = vset->hnsw->vector_dim;
}
}
@ -764,6 +773,14 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
/* Apply projection if the set uses it, with the exception
* of ELE type, that will already have the right dimension. */
if (vset->proj_matrix && dim != vset->hnsw->vector_dim) {
/* Ensure input dimension matches the projection matrix's expected input dimension */
if (dim != vset->proj_input_size) {
RedisModule_Free(vec);
return RedisModule_ReplyWithErrorFormat(ctx,
"ERR Input dimension mismatch for projection - got %d but projection expects %d",
(int)dim, (int)vset->proj_input_size);
}
float *projected = applyProjection(vec, vset->proj_matrix,
vset->proj_input_size, dim);
RedisModule_Free(vec);
@ -1251,8 +1268,12 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
/* Calculate map size based on projection presence */
int map_size = 8; /* Base fields */
if (vset->proj_matrix) map_size += 2; /* Additional fields for projection */
/* Reply with hash */
RedisModule_ReplyWithMap(ctx, 8);
RedisModule_ReplyWithMap(ctx, map_size);
/* Quantization type */
RedisModule_ReplyWithSimpleString(ctx, "quant-type");
@ -1266,6 +1287,17 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
RedisModule_ReplyWithSimpleString(ctx, "vector-dim");
RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
/* Add projection information if present */
if (vset->proj_matrix) {
/* Original input dimension before projection */
RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim");
RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size);
/* Projection enabled flag */
RedisModule_ReplyWithSimpleString(ctx, "projection-enabled");
RedisModule_ReplyWithBool(ctx, 1);
}
/* Number of elements. */
RedisModule_ReplyWithSimpleString(ctx, "size");
RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);