diff --git a/modules/vector-sets/vset.c b/modules/vector-sets/vset.c index 88e54a64f..e904adeea 100644 --- a/modules/vector-sets/vset.c +++ b/modules/vector-sets/vset.c @@ -1734,15 +1734,15 @@ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { } } -/* Load object from RDB. Please note that we don't do any cleanup - * on errors, and just return NULL, as Redis will abort completely - * not just the module but the server itself in this case. */ +/* Load object from RDB. Recover from recoverable errors (read errors) + * by performing cleanup. */ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { if (encver != 0) return NULL; // Invalid version uint32_t dim = RedisModule_LoadUnsigned(rdb); uint64_t elements = RedisModule_LoadUnsigned(rdb); uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) return NULL; uint32_t quant_type = hnsw_config & 0xff; uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff; @@ -1754,22 +1754,21 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { /* Load projection matrix if present */ uint32_t save_flags = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) goto ioerr; int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX; int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS; if (has_projection) { uint32_t input_dim = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) goto ioerr; uint32_t output_dim = dim; size_t matrix_size = sizeof(float) * input_dim * output_dim; vset->proj_matrix = RedisModule_Alloc(matrix_size); - if (!vset->proj_matrix) { - vectorSetReleaseObject(vset); - return NULL; - } vset->proj_input_size = input_dim; // Load projection matrix as a binary blob char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL); + if (RedisModule_IsIOError(rdb)) goto ioerr; memcpy(vset->proj_matrix, matrix_blob, matrix_size); RedisModule_Free(matrix_blob); } @@ -1777,9 +1776,14 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { while(elements--) { // Load associated string element. RedisModuleString *ele = RedisModule_LoadString(rdb); + if (RedisModule_IsIOError(rdb)) goto ioerr; RedisModuleString *attrib = NULL; if (has_attribs) { attrib = RedisModule_LoadString(rdb); + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + goto ioerr; + } size_t attrlen; RedisModule_StringPtrLen(attrib,&attrlen); if (attrlen == 0) { @@ -1789,6 +1793,11 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { } size_t vector_len; void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len); + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + goto ioerr; + } uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw); if (vector_len != vector_bytes) { RedisModule_LogIOError(rdb,"warning", @@ -1798,9 +1807,25 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { // Load node parameters back. uint32_t params_count = RedisModule_LoadUnsigned(rdb); + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + RedisModule_Free(vector); + goto ioerr; + } + uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t)); - for (uint32_t j = 0; j < params_count; j++) + for (uint32_t j = 0; j < params_count; j++) { + // Ignore loading errors here: handled at the end of the loop. params[j] = RedisModule_LoadUnsigned(rdb); + } + if (RedisModule_IsIOError(rdb)) { + RedisModule_FreeString(NULL,ele); + if (attrib) RedisModule_FreeString(NULL,attrib); + RedisModule_Free(vector); + RedisModule_Free(params); + goto ioerr; + } struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); nv->item = ele; @@ -1809,15 +1834,22 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { if (node == NULL) { RedisModule_LogIOError(rdb,"warning", "Vector set node index loading error"); - return NULL; // Loading error. + return NULL; // Loading error: likely a corruption. } if (nv->attrib) vset->numattribs++; RedisModule_DictSet(vset->dict,ele,node); RedisModule_Free(vector); RedisModule_Free(params); } - hnsw_deserialize_index(vset->hnsw); + if (!hnsw_deserialize_index(vset->hnsw)) goto ioerr; + return vset; + +ioerr: + /* We want to recover from I/O errors and free the partially allocated + * data structure to support diskless replication. */ + vectorSetReleaseObject(vset); + return NULL; } /* Calculate memory usage */ @@ -1944,7 +1976,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1) == REDISMODULE_ERR) return REDISMODULE_ERR; - /* TODO: Added to pass CI, need to make changes in order to support these options */ RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_IO_ERRORS|REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD); RedisModuleTypeMethods tm = { diff --git a/tests/integration/corrupt-dump-fuzzer.tcl b/tests/integration/corrupt-dump-fuzzer.tcl index ed6a15bd1..5c7c9923a 100644 --- a/tests/integration/corrupt-dump-fuzzer.tcl +++ b/tests/integration/corrupt-dump-fuzzer.tcl @@ -15,6 +15,10 @@ if { ! [ catch { proc generate_collections {suffix elements} { set rd [redis_deferring_client] + set numcmd 7 + set has_vsets [server_has_command vadd] + if {$has_vsets} {incr numcmd} + for {set j 0} {$j < $elements} {incr j} { # add both string values and integers if {$j % 2 == 0} {set val $j} else {set val "_$j"} @@ -25,8 +29,11 @@ proc generate_collections {suffix elements} { $rd zadd zset$suffix $j $val $rd sadd set$suffix $val $rd xadd stream$suffix * item 1 value $val + if {$has_vsets} { + $rd vadd vset$suffix VALUES 3 1 1 1 $j + } } - for {set j 0} {$j < $elements * 7} {incr j} { + for {set j 0} {$j < $elements * $numcmd} {incr j} { $rd read ; # Discard replies } $rd close diff --git a/tests/integration/replication.tcl b/tests/integration/replication.tcl index da7983cb7..10e7a1f25 100644 --- a/tests/integration/replication.tcl +++ b/tests/integration/replication.tcl @@ -756,6 +756,8 @@ test {diskless loading short read} { redis.register_function('test', function() return 'hello1' end) } + set has_vector_sets [server_has_command vadd] + for {set k 0} {$k < 3} {incr k} { for {set i 0} {$i < 10} {incr i} { r set "$k int_$i" [expr {int(rand()*10000)}] @@ -769,6 +771,11 @@ test {diskless loading short read} { r zadd "$k zset_large" [expr {rand()}] [string repeat A [expr {int(rand()*1000000)}]] r lpush "$k list_small" [string repeat A [expr {int(rand()*10)}]] r lpush "$k list_large" [string repeat A [expr {int(rand()*1000000)}]] + + if {$has_vector_sets} { + r vadd "$k vector_set" VALUES 3 [expr {rand()}] [expr {rand()}] [expr {rand()}] [string repeat A [expr {int(rand()*1000)}]] + } + for {set j 0} {$j < 10} {incr j} { r xadd "$k stream" * foo "asdf" bar "1234" } diff --git a/tests/support/util.tcl b/tests/support/util.tcl index 3e4b0a896..0d7d88516 100644 --- a/tests/support/util.tcl +++ b/tests/support/util.tcl @@ -738,7 +738,8 @@ proc generate_fuzzy_traffic_on_key {key type duration} { set list_commands {LINDEX LINSERT LLEN LPOP LPOS LPUSH LPUSHX LRANGE LREM LSET LTRIM RPOP RPOPLPUSH RPUSH RPUSHX} set set_commands {SADD SCARD SDIFF SDIFFSTORE SINTER SINTERSTORE SISMEMBER SMEMBERS SMOVE SPOP SRANDMEMBER SREM SSCAN SUNION SUNIONSTORE} set stream_commands {XACK XADD XCLAIM XDEL XGROUP XINFO XLEN XPENDING XRANGE XREAD XREADGROUP XREVRANGE XTRIM} - set commands [dict create string $string_commands hash $hash_commands zset $zset_commands list $list_commands set $set_commands stream $stream_commands] + set vset_commands {VADD VREM} + set commands [dict create string $string_commands hash $hash_commands zset $zset_commands list $list_commands set $set_commands stream $stream_commands vectorset $vset_commands] set cmds [dict get $commands $type] set start_time [clock seconds] @@ -788,6 +789,18 @@ proc generate_fuzzy_traffic_on_key {key type duration} { lappend cmd [randomValue] incr i 4 } + if {$cmd == "VADD"} { + lappend cmd $key + lappend cmd VALUES 3 1 1 1 + lappend cmd [randomValue] + incr i 7 + } + if {$cmd == "VREM"} { + lappend cmd $key + lappend cmd [randomValue] + incr i 2 + } + for {} {$i < $arity} {incr i} { if {$i == $firstkey || $i == $lastkey} { lappend cmd $key @@ -1144,6 +1157,15 @@ proc memory_usage {key} { return $usage } +# Test if the server supports the specified command. +proc server_has_command {cmd_wanted} { + set lowercase_commands {} + foreach cmd [r command list] { + lappend lowercase_commands [string tolower $cmd] + } + expr {[lsearch $lowercase_commands [string tolower $cmd_wanted]] != -1} +} + # forward compatibility, lmap missing in TCL 8.5 proc lmap args { set body [lindex $args end]