From 31bc07955ca2458e0df4c7295bfc3f6c49c2cf8c Mon Sep 17 00:00:00 2001
From: Rowan Trollope
Date: Sat, 22 Mar 2025 09:11:20 -0700
Subject: [PATCH] Fix possible crash with random projection

---
Note: line breaks and indentation in this patch were reconstructed from a
whitespace-collapsed copy; verify the vset.c context whitespace before
applying (or use `git apply --ignore-whitespace`). The test file differs
from the originally generated one, so the blob id on its `index` line is
stale.

 tests/dimension_validation.py | 99 +++++++++++++++++++++++++++++++++++
 vset.c                        | 36 ++++++++++++-
 2 files changed, 133 insertions(+), 2 deletions(-)
 create mode 100644 tests/dimension_validation.py

diff --git a/tests/dimension_validation.py b/tests/dimension_validation.py
new file mode 100644
index 000000000..a90be65bf
--- /dev/null
+++ b/tests/dimension_validation.py
@@ -0,0 +1,99 @@
+"""Regression test for input dimension validation with random projection."""
+
+from test import TestCase, generate_random_vector
+import struct
+import redis.exceptions
+
+
+class DimensionValidation(TestCase):
+    """Check that VADD and VSIM reject vectors of the wrong input dimension.
+
+    The set is created with REDUCE, so later vectors must have the input
+    dimension the projection was created with.
+    """
+
+    def getname(self):
+        """Return the display name of this test."""
+        return "Dimension Validation with Projection"
+
+    def estimated_runtime(self):
+        """Return the estimated runtime of this test."""
+        return 0.5
+
+    @staticmethod
+    def _reply_to_dict(reply):
+        """Normalize a VINFO reply into a dict with str keys.
+
+        NOTE(review): a module map reply presumably arrives as a dict under
+        RESP3 and as a flat [key, value, ...] array under RESP2 - confirm
+        against the client configuration used by the test harness.
+        """
+        if not isinstance(reply, dict):
+            assert isinstance(reply, (list, tuple)) and len(reply) % 2 == 0, \
+                "VINFO should return a map or a flat key/value array"
+            reply = dict(zip(reply[::2], reply[1::2]))
+        return {k.decode() if isinstance(k, bytes) else k: v
+                for k, v in reply.items()}
+
+    def test(self):
+        """Exercise VADD with REDUCE, VINFO, and the VADD/VSIM dimension checks."""
+        # Test scenario 1: Create a set with projection
+        original_dim = 100
+        reduced_dim = 50
+
+        # Create the initial vector and set with projection
+        vec1 = generate_random_vector(original_dim)
+        vec1_bytes = struct.pack(f'{original_dim}f', *vec1)
+
+        # Add first vector with projection
+        result = self.redis.execute_command('VADD', self.test_key,
+                                            'REDUCE', reduced_dim,
+                                            'FP32', vec1_bytes, f'{self.test_key}:item:1')
+        assert result == 1, "First VADD with REDUCE should return 1"
+
+        # Check VINFO returns the correct projection information. The key
+        # names asserted below are the ones VINFO emits in vset.c.
+        info = self.redis.execute_command('VINFO', self.test_key)
+        info = self._reply_to_dict(info)
+        assert 'vector-dim' in info, "VINFO should contain vector-dim"
+        assert info['vector-dim'] == reduced_dim, f"Expected reduced dimension {reduced_dim}, got {info['vector-dim']}"
+        assert 'projection-input-dim' in info, "VINFO should contain projection-input-dim"
+        assert info['projection-input-dim'] == original_dim, f"Expected original dimension {original_dim}, got {info['projection-input-dim']}"
+        assert 'projection-enabled' in info, "VINFO should contain projection-enabled"
+        # Truthiness rather than "is True": presumably the flag arrives as
+        # the integer 1 when the reply is not RESP3 - confirm.
+        assert info['projection-enabled'], "Projection should be enabled"
+
+        # Test scenario 2: Try adding a mismatched vector - should fail
+        wrong_dim = 80
+        wrong_vec = generate_random_vector(wrong_dim)
+        wrong_vec_bytes = struct.pack(f'{wrong_dim}f', *wrong_vec)
+
+        # This should fail with dimension mismatch error
+        try:
+            self.redis.execute_command('VADD', self.test_key,
+                                       'REDUCE', reduced_dim,
+                                       'FP32', wrong_vec_bytes, f'{self.test_key}:item:2')
+            assert False, "VADD with wrong dimension should fail"
+        except redis.exceptions.ResponseError as e:
+            assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error, got: {e}"
+
+        # Test scenario 3: Add a correctly-sized vector
+        vec2 = generate_random_vector(original_dim)
+        vec2_bytes = struct.pack(f'{original_dim}f', *vec2)
+
+        # This should succeed
+        result = self.redis.execute_command('VADD', self.test_key,
+                                            'REDUCE', reduced_dim,
+                                            'FP32', vec2_bytes, f'{self.test_key}:item:3')
+        assert result == 1, "VADD with correct dimensions should succeed"
+
+        # Check VSIM also validates input dimensions
+        wrong_query = generate_random_vector(wrong_dim)
+        try:
+            self.redis.execute_command('VSIM', self.test_key,
+                                       'VALUES', wrong_dim, *[str(x) for x in wrong_query],
+                                       'COUNT', 10)
+            assert False, "VSIM with wrong dimension should fail"
+        except redis.exceptions.ResponseError as e:
+            assert "Input dimension mismatch for projection" in str(e), f"Expected dimension mismatch error in VSIM, got: {e}"
diff --git a/vset.c b/vset.c
index 34df75a1e..ef00c53cc 100644
--- a/vset.c
+++ b/vset.c
@@ -531,10 +531,19 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
 
         /* Apply projection if needed */
         if (vset->proj_matrix) {
+            /* Ensure input dimension matches the projection matrix's expected input dimension */
+            if (dim != vset->proj_input_size) {
+                RedisModule_Free(vec);
+                return RedisModule_ReplyWithErrorFormat(ctx,
+                    "ERR Input dimension mismatch for projection - got %d but projection expects %d",
+                    (int)dim, (int)vset->proj_input_size);
+            }
+
             float *projected = applyProjection(vec, vset->proj_matrix,
-                                             vset->proj_input_size, dim);
+                                               vset->proj_input_size, dim);
             RedisModule_Free(vec);
             vec = projected;
+            dim = vset->hnsw->vector_dim;
         }
     }
 
@@ -764,6 +773,14 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
     /* Apply projection if the set uses it, with the exception
      * of ELE type, that will already have the right dimension. */
     if (vset->proj_matrix && dim != vset->hnsw->vector_dim) {
+        /* Ensure input dimension matches the projection matrix's expected input dimension */
+        if (dim != vset->proj_input_size) {
+            RedisModule_Free(vec);
+            return RedisModule_ReplyWithErrorFormat(ctx,
+                "ERR Input dimension mismatch for projection - got %d but projection expects %d",
+                (int)dim, (int)vset->proj_input_size);
+        }
+
         float *projected = applyProjection(vec, vset->proj_matrix,
                                          vset->proj_input_size, dim);
         RedisModule_Free(vec);
@@ -1251,8 +1268,12 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
 
     struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
 
+    /* Calculate map size based on projection presence */
+    int map_size = 8;  /* Base fields */
+    if (vset->proj_matrix) map_size += 2;  /* Additional fields for projection */
+
     /* Reply with hash */
-    RedisModule_ReplyWithMap(ctx, 8);
+    RedisModule_ReplyWithMap(ctx, map_size);
 
     /* Quantization type */
     RedisModule_ReplyWithSimpleString(ctx, "quant-type");
@@ -1266,6 +1287,17 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
     RedisModule_ReplyWithSimpleString(ctx, "vector-dim");
     RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
 
+    /* Add projection information if present */
+    if (vset->proj_matrix) {
+        /* Original input dimension before projection */
+        RedisModule_ReplyWithSimpleString(ctx, "projection-input-dim");
+        RedisModule_ReplyWithLongLong(ctx, vset->proj_input_size);
+
+        /* Projection enabled flag */
+        RedisModule_ReplyWithSimpleString(ctx, "projection-enabled");
+        RedisModule_ReplyWithBool(ctx, 1);
+    }
+
     /* Number of elements. */
     RedisModule_ReplyWithSimpleString(ctx, "size");
     RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count);