Fix README conflict.

This commit is contained in:
antirez 2025-03-03 13:12:25 +01:00
commit b49bc14f96
11 changed files with 5835 additions and 70 deletions

View File

@ -1,2 +1,2 @@
This code is Copyright (C) 2024-2025 Salvatore Sanfilippo.
This code is Copyright (c) 2024-Present, Redis Ltd.
All Rights Reserved.

View File

@ -53,9 +53,9 @@ all: vset.so
.c.xo:
$(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@
vset.xo: redismodule.h
vset.xo: redismodule.h expr.c
vset.so: vset.xo hnsw.xo
vset.so: vset.xo hnsw.xo cJSON.xo
$(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc
# Example sources / objects

238
README.md
View File

@ -2,12 +2,14 @@ This module implements Vector Sets for Redis, a new Redis data type similar
to Sorted Sets but having string elements associated to a vector instead of
a score. The fundamental goal of Vector Sets is to make possible adding items,
and later get a subset of the added items that are the most similar to a
specified vector (often a learned embedding) of the most similar to the vector
specified vector (often a learned embedding), or the most similar to the vector
of an element that is already part of the Vector Set.
Moreover, Vector sets implement optional hybrid search capabilities: it is possible to associate attributes to all or to a subset of elements in the set, and then, using the `FILTER` option of the `VSIM` command, to ask for items similar to a given vector but also passing a filter specified as a simple mathematical expression (Like `".year > 1950"` or similar).
## Installation
Buil with:
Build with:
make
@ -27,7 +29,8 @@ The execute the tests with:
**VADD: add items into a vector set**
VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN] [EF build-exploration-factor]
VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN]
[EF build-exploration-factor] [SETATTR <attributes>]
Add a new element into the vector set specified by the key.
The vector can be provided as FP32 blob of values, or as floating point
@ -35,27 +38,31 @@ numbers as strings, prefixed by the number of elements (3 in the example):
VADD mykey VALUES 3 0.1 1.2 0.5 my-element
The `REDUCE` option implements random projection, in order to reduce the
Meaning of the options:
`REDUCE` implements random projection, in order to reduce the
dimensionality of the vector. The projection matrix is saved and reloaded
along with the vector set.
The `CAS` option performs the operation partially using threads, in a
`CAS` performs the operation partially using threads, in a
check-and-set style. The neighbor candidates collection, which is slow, is
performed in the background, while the command is executed in the main thread.
The `NOQUANT` option forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default.
`NOQUANT` forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default.
The `BIN` option forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality.
`BIN` forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality.
The `Q8` option forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format.
`Q8` forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format.
The `EF` option plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches.
`EF` plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches.
`SETATTR` associates attributes to the newly created entry or update the entry attributes (if it already exists). It is the same as calling the `VSETATTR` attribute separately, so please check the documentation of that command in the hybrid search section of this documentation.
**VSIM: return elements by vector similarity**
VSIM key [ELE|FP32|VALUES] <vector or element> [WITHSCORES] [COUNT num] [EF search-exploration-factor]
VSIM key [ELE|FP32|VALUES] <vector or element> [WITHSCORES] [COUNT num] [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort]
The command returns similar vectors, in the example instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements associated with a vector similar to a given element already in the sorted set:
The command returns similar vectors, for simplicity (and verbosity) in the following example, instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements having a vector similar to a given element already in the sorted set:
> VSIM word_embeddings ELE apple
1) "apple"
@ -81,6 +88,8 @@ It is possible to specify a `COUNT` and also to get the similarity score (from 1
The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the better the index is explored to find nodes that are near to our query. Sensible values are from 50 to 1000.
For `FILTER` and `FILTER-EF` options, please check the hybrid search section of this documentation.
**VDIM: return the dimension of the vectors inside the vector set**
VDIM keyname
@ -180,17 +189,218 @@ Example:
11) hnsw-max-node-uid
12) (integer) 3000000
## Known bugs
**VSETATTR: associate or remove the JSON attributes of elements**
VSETATTR key element "{... json ...}"
Each element of a vector set can be optionally associated with a JSON string
in order to use the `FILTER` option of `VSIM` to filter elements by scalars
(see the hybrid search section for more information). This command can set,
update (if already set) or delete (if you set to an empty string) the
associated JSON attributes of an element.
The command returns 0 if the element or the key don't exist, without
raising an error, otherwise 1 is returned, and the element attributes
are set or updated.
**VGETATTR: retrieve the JSON attributes of elements**
VGET key element
The command returns the JSON attribute associated with an element, or
null if there is no element associated, or no element at all, or no key.
# Hybrid search
Each element of the vector set can be associated with a set of attributes specified as a JSON blob:
> VADD vset VALUES 3 1 1 1 a SETATTR '{"year": 1950}'
(integer) 1
> VADD vset VALUES 3 -1 -1 -1 b SETATTR '{"year": 1951}'
(integer) 1
Specifying an attribute with the `SETATTR` option of `VADD` is exactly equivalent to adding an element and then setting (or updating, if already set) the attributes JSON string. Also the symmetrical `VGETATTR` command returns the attribute associated to a given element.
> VAD vset VALUES 3 0 1 0 c
(integer) 1
> VSETATTR vset c '{"year": 1952}'
(integer) 1
> VGETATTR vset c
"{\"year\": 1952}"
At this point, I may use the FILTER option of VSIM to only ask for the subset of elements that are verified by my expression:
> VSIM vset VALUES 3 0 0 0 FILTER '.year > 1950'
1) "c"
2) "b"
The items will be returned again in order of similarity (most similar first), but only the items with the year field matching the expression is returned.
The expressions are similar to what you would write inside the `if` statement of JavaScript or other familiar programming languages: you can use `and`, `or`, the obvious math operators like `+`, `-`, `/`, `>=`, `<`, ... and so forth (see the expressions section for more info). The selectors of the JSON object attributes start with a dot followed by the name of the key inside the JSON objects.
Elements with invalid JSON or not having a given specified field **are considered as not matching** the expression, but will not generate any error at runtime.
I'll draft the missing sections for the README following the style and format of the existing content.
## FILTER expressions capabilities
FILTER expressions allow you to perform complex filtering on vector similarity results using a JavaScript-like syntax. The expression is evaluated against each element's JSON attributes, with only elements that satisfy the expression being included in the results.
### Expression Syntax
Expressions support the following operators and capabilities:
1. **Arithmetic operators**: `+`, `-`, `*`, `/`, `%` (modulo), `**` (exponentiation)
2. **Comparison operators**: `>`, `>=`, `<`, `<=`, `==`, `!=`
3. **Logical operators**: `and`/`&&`, `or`/`||`, `!`/`not`
4. **Containment operator**: `in`
5. **Parentheses** for grouping: `(...)`
### Selector Notation
Attributes are accessed using dot notation:
- `.year` references the "year" attribute
- `.movie.year` would **NOT** reference the "year" field inside a "movie" object, only keys that are at the first level of the JSON object are accessible.
### JSON and expressions data types
Expressions can work with:
- Numbers (dobule precision floats)
- Strings (enclosed in single or double quotes)
- Booleans (no native type: they are represented as 1 for true, 0 for false)
- Arrays (for use with the `in` operator: `value in [1, 2, 3]`)
JSON attributes are converted in this way:
- Numbers will be converted to numbers.
- Strings to strings.
- Booleans to 0 or 1 number.
- Arrays to tuples (for "in" operator), but only if composed of just numbers and strings.
Any other type is ignored, and accessig it will make the expression evaluate to false.
### Examples
```
# Find items from the 1980s
VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.year >= 1980 and .year < 1990'
# Find action movies with high ratings
VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.genre == "action" and .rating > 8.0'
# Find movies directed by either Spielberg or Nolan
VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.director in ["Spielberg", "Nolan"]'
# Complex condition with numerical operations
VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '(.year - 2000) ** 2 < 100 and .rating / 2 > 4'
```
### Error Handling
Elements with any of the following conditions are considered not matching:
- Missing the queried JSON attribute
- Having invalid JSON in their attributes
- Having a JSON value that cannot be converted to the expected type
This behavior allows you to safely filter on optional attributes without generating errors.
### FILTER effort
The `FILTER-EF` option controls the maximum effort spent when filtering vector search results.
When performing vector similarity search with filtering, Vector Sets perform the standard similarity search as they apply the filter expression to each node. Since many results might be filtered out, Vector Sets may need to examine a lot more candidates than the requested `COUNT` to ensure sufficient matching results are returned. Actually, if the elements matching the filter are very rare or if there are less than elements matching than the specified count, this would trigger a full scan of the HNSW graph.
For this reason, by default, the maximum effort is limited to a reasonable amount of nodes explored.
### Modifying the FILTER effort
1. By default, Vector Sets will explore up to `COUNT * 100` candidates to find matching results.
2. You can control this exploration with the `FILTER-EF` parameter.
3. A higher `FILTER-EF` value increases the chances of finding all relevant matches at the cost of increased processing time.
4. A `FILTER-EF` of zero will explore as many nodes as needed in order to actually return the number of elements specified by `COUNT`.
5. Even when a high `FILTER-EF` value is specified **the implementation will do a lot less work** if the elements passing the filter are very common, because of the early stop conditions of the HNSW implementation (once the specified amount of elements is reached and the quality check of the other candidates trigger an early stop).
```
VSIM key [ELE|FP32|VALUES] <vector or element> COUNT 10 FILTER '.year > 2000' FILTER-EF 500
```
In this example, Vector Sets will examine up to 500 potential nodes. Of course if count is reached before exploring 500 nodes, and the quality checks show that it is not possible to make progresses on similarity, the search is ended sooner.
### Performance Considerations
- If you have highly selective filters (few items match), use a higher `FILTER-EF`, or just design your application in order to handle a result set that is smaller than the requested count. Note that anyway the additional elements may be too distant than the query vector.
- For less selective filters, the default should be sufficient.
- Very selective filters with low `FILTER-EF` values may return fewer items than requested.
- Extremely high values may impact performance without significantly improving results.
The optimal `FILTER-EF` value depends on:
1. The selectivity of your filter.
2. The distribution of your data.
3. The required recall quality.
A good practice is to start with the default and increase if needed when you observe fewer results than expected.
### Testing a larg-ish data set
To really see how things work at scale, you can [download](https://antirez.com/word2vec_with_attribs.rdb) the following dataset:
wget https://antirez.com/word2vec_with_attribs.rdb
It contains the 3 million words in Word2Vec having as attribute a JSON with just the length of the word. Because of the length distribution of words in large amounts of texts, where longer words become less and less common, this is ideal to check how filtering behaves with a filter verifying as true with less and less elements in a vector set.
For instance:
> VSIM word_embeddings_bin ele "pasta" FILTER ".len == 6"
1) "pastas"
2) "rotini"
3) "gnocci"
4) "panino"
5) "salads"
6) "breads"
7) "salame"
8) "sauces"
9) "cheese"
10) "fritti"
This will easily retrieve the desired amount of items (`COUNT` is 10 by default) since there are many items of length 6. However:
> VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33"
1) "skinless_boneless_chicken_breasts"
2) "boneless_skinless_chicken_breasts"
3) "Boneless_skinless_chicken_breasts"
This time even if we asked for 10 items, we only get 3, since the default filter effort will be `10*100 = 1000`. We can tune this giving the effort in an explicit way, with the risk of our query being slower, of course:
> VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" FILTER-EF 10000
1) "skinless_boneless_chicken_breasts"
2) "boneless_skinless_chicken_breasts"
3) "Boneless_skinless_chicken_breasts"
4) "mozzarella_feta_provolone_cheddar"
5) "Greatfood.com_R_www.greatfood.com"
6) "Pepperidge_Farm_Goldfish_crackers"
7) "Prosecuted_Mobsters_Rebuilt_Dying"
8) "Crispy_Snacker_Sandwiches_Popcorn"
9) "risultati_delle_partite_disputate"
10) "Peppermint_Mocha_Twist_Gingersnap"
This time we get all the ten items, even if the last one will be quite far from our query vector. We encourage to experiment with this test dataset in order to understand better the dynamics of the implementation and the natural tradeoffs of hybrid search.
**Keep in mind** that by default, Redis Vector Sets will try to avoid a likely very useless huge scan of the HNSW graph, and will be more happy to return few or no elements at all, since this is almost always what the user actually wants in the context of retrieving *similar* items to the query.
# Known bugs
* When VADD with REDUCE is replicated, we should probably send the replicas the random matrix, in order for VEMB to read the same things. This is not critical, because the behavior of VADD / VSIM should be transparent if you don't look at the transformed vectors, but still probably worth doing.
* Replication code is pretty much untested, and very vanilla (replicating the commands verbatim).
## Implementation details
# Implementation details
Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality.
The main features are:
* Proper nodes deletion with relinking.
* 8 bits quantization.
* 8 bits and binary quantization.
* Threaded queries.
* Hybrid search with predicate callback.

3110
cJSON.c Normal file

File diff suppressed because it is too large Load Diff

293
cJSON.h Normal file
View File

@ -0,0 +1,293 @@
/*
Copyright (c) 2009-2017 Dave Gamble and cJSON contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#ifndef cJSON__h
#define cJSON__h
#ifdef __cplusplus
extern "C"
{
#endif
#if !defined(__WINDOWS__) && (defined(WIN32) || defined(WIN64) || defined(_MSC_VER) || defined(_WIN32))
#define __WINDOWS__
#endif
#ifdef __WINDOWS__
/* When compiling for windows, we specify a specific calling convention to avoid issues where we are being called from a project with a different default calling convention. For windows you have 3 define options:
CJSON_HIDE_SYMBOLS - Define this in the case where you don't want to ever dllexport symbols
CJSON_EXPORT_SYMBOLS - Define this on library build when you want to dllexport symbols (default)
CJSON_IMPORT_SYMBOLS - Define this if you want to dllimport symbol
For *nix builds that support visibility attribute, you can define similar behavior by
setting default visibility to hidden by adding
-fvisibility=hidden (for gcc)
or
-xldscope=hidden (for sun cc)
to CFLAGS
then using the CJSON_API_VISIBILITY flag to "export" the same symbols the way CJSON_EXPORT_SYMBOLS does
*/
#define CJSON_CDECL __cdecl
#define CJSON_STDCALL __stdcall
/* export symbols by default, this is necessary for copy pasting the C and header file */
#if !defined(CJSON_HIDE_SYMBOLS) && !defined(CJSON_IMPORT_SYMBOLS) && !defined(CJSON_EXPORT_SYMBOLS)
#define CJSON_EXPORT_SYMBOLS
#endif
#if defined(CJSON_HIDE_SYMBOLS)
#define CJSON_PUBLIC(type) type CJSON_STDCALL
#elif defined(CJSON_EXPORT_SYMBOLS)
#define CJSON_PUBLIC(type) __declspec(dllexport) type CJSON_STDCALL
#elif defined(CJSON_IMPORT_SYMBOLS)
#define CJSON_PUBLIC(type) __declspec(dllimport) type CJSON_STDCALL
#endif
#else /* !__WINDOWS__ */
#define CJSON_CDECL
#define CJSON_STDCALL
#if (defined(__GNUC__) || defined(__SUNPRO_CC) || defined (__SUNPRO_C)) && defined(CJSON_API_VISIBILITY)
#define CJSON_PUBLIC(type) __attribute__((visibility("default"))) type
#else
#define CJSON_PUBLIC(type) type
#endif
#endif
/* project version */
#define CJSON_VERSION_MAJOR 1
#define CJSON_VERSION_MINOR 7
#define CJSON_VERSION_PATCH 14
#include <stddef.h>
/* cJSON Types: */
#define cJSON_Invalid (0)
#define cJSON_False (1 << 0)
#define cJSON_True (1 << 1)
#define cJSON_NULL (1 << 2)
#define cJSON_Number (1 << 3)
#define cJSON_String (1 << 4)
#define cJSON_Array (1 << 5)
#define cJSON_Object (1 << 6)
#define cJSON_Raw (1 << 7) /* raw json */
#define cJSON_IsReference 256
#define cJSON_StringIsConst 512
/* The cJSON structure: */
typedef struct cJSON
{
/* next/prev allow you to walk array/object chains. Alternatively, use GetArraySize/GetArrayItem/GetObjectItem */
struct cJSON *next;
struct cJSON *prev;
/* An array or object item will have a child pointer pointing to a chain of the items in the array/object. */
struct cJSON *child;
/* The type of the item, as above. */
int type;
/* The item's string, if type==cJSON_String and type == cJSON_Raw */
char *valuestring;
/* writing to valueint is DEPRECATED, use cJSON_SetNumberValue instead */
int valueint;
/* The item's number, if type==cJSON_Number */
double valuedouble;
/* The item's name string, if this item is the child of, or is in the list of subitems of an object. */
char *string;
} cJSON;
typedef struct cJSON_Hooks
{
/* malloc/free are CDECL on Windows regardless of the default calling convention of the compiler, so ensure the hooks allow passing those functions directly. */
void *(CJSON_CDECL *malloc_fn)(size_t sz);
void (CJSON_CDECL *free_fn)(void *ptr);
} cJSON_Hooks;
typedef int cJSON_bool;
/* Limits how deeply nested arrays/objects can be before cJSON rejects to parse them.
* This is to prevent stack overflows. */
#ifndef CJSON_NESTING_LIMIT
#define CJSON_NESTING_LIMIT 1000
#endif
/* returns the version of cJSON as a string */
CJSON_PUBLIC(const char*) cJSON_Version(void);
/* Supply malloc, realloc and free functions to cJSON */
CJSON_PUBLIC(void) cJSON_InitHooks(cJSON_Hooks* hooks);
/* Memory Management: the caller is always responsible to free the results from all variants of cJSON_Parse (with cJSON_Delete) and cJSON_Print (with stdlib free, cJSON_Hooks.free_fn, or cJSON_free as appropriate). The exception is cJSON_PrintPreallocated, where the caller has full responsibility of the buffer. */
/* Supply a block of JSON, and this returns a cJSON object you can interrogate. */
CJSON_PUBLIC(cJSON *) cJSON_Parse(const char *value);
CJSON_PUBLIC(cJSON *) cJSON_ParseWithLength(const char *value, size_t buffer_length);
/* ParseWithOpts allows you to require (and check) that the JSON is null terminated, and to retrieve the pointer to the final byte parsed. */
/* If you supply a ptr in return_parse_end and parsing fails, then return_parse_end will contain a pointer to the error so will match cJSON_GetErrorPtr(). */
CJSON_PUBLIC(cJSON *) cJSON_ParseWithOpts(const char *value, const char **return_parse_end, cJSON_bool require_null_terminated);
CJSON_PUBLIC(cJSON *) cJSON_ParseWithLengthOpts(const char *value, size_t buffer_length, const char **return_parse_end, cJSON_bool require_null_terminated);
/* Render a cJSON entity to text for transfer/storage. */
CJSON_PUBLIC(char *) cJSON_Print(const cJSON *item);
/* Render a cJSON entity to text for transfer/storage without any formatting. */
CJSON_PUBLIC(char *) cJSON_PrintUnformatted(const cJSON *item);
/* Render a cJSON entity to text using a buffered strategy. prebuffer is a guess at the final size. guessing well reduces reallocation. fmt=0 gives unformatted, =1 gives formatted */
CJSON_PUBLIC(char *) cJSON_PrintBuffered(const cJSON *item, int prebuffer, cJSON_bool fmt);
/* Render a cJSON entity to text using a buffer already allocated in memory with given length. Returns 1 on success and 0 on failure. */
/* NOTE: cJSON is not always 100% accurate in estimating how much memory it will use, so to be safe allocate 5 bytes more than you actually need */
CJSON_PUBLIC(cJSON_bool) cJSON_PrintPreallocated(cJSON *item, char *buffer, const int length, const cJSON_bool format);
/* Delete a cJSON entity and all subentities. */
CJSON_PUBLIC(void) cJSON_Delete(cJSON *item);
/* Returns the number of items in an array (or object). */
CJSON_PUBLIC(int) cJSON_GetArraySize(const cJSON *array);
/* Retrieve item number "index" from array "array". Returns NULL if unsuccessful. */
CJSON_PUBLIC(cJSON *) cJSON_GetArrayItem(const cJSON *array, int index);
/* Get item "string" from object. Case insensitive. */
CJSON_PUBLIC(cJSON *) cJSON_GetObjectItem(const cJSON * const object, const char * const string);
CJSON_PUBLIC(cJSON *) cJSON_GetObjectItemCaseSensitive(const cJSON * const object, const char * const string);
CJSON_PUBLIC(cJSON_bool) cJSON_HasObjectItem(const cJSON *object, const char *string);
/* For analysing failed parses. This returns a pointer to the parse error. You'll probably need to look a few chars back to make sense of it. Defined when cJSON_Parse() returns 0. 0 when cJSON_Parse() succeeds. */
CJSON_PUBLIC(const char *) cJSON_GetErrorPtr(void);
/* Check item type and return its value */
CJSON_PUBLIC(char *) cJSON_GetStringValue(const cJSON * const item);
CJSON_PUBLIC(double) cJSON_GetNumberValue(const cJSON * const item);
/* These functions check the type of an item */
CJSON_PUBLIC(cJSON_bool) cJSON_IsInvalid(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsFalse(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsTrue(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsBool(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsNull(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsNumber(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsString(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsArray(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsObject(const cJSON * const item);
CJSON_PUBLIC(cJSON_bool) cJSON_IsRaw(const cJSON * const item);
/* These calls create a cJSON item of the appropriate type. */
CJSON_PUBLIC(cJSON *) cJSON_CreateNull(void);
CJSON_PUBLIC(cJSON *) cJSON_CreateTrue(void);
CJSON_PUBLIC(cJSON *) cJSON_CreateFalse(void);
CJSON_PUBLIC(cJSON *) cJSON_CreateBool(cJSON_bool boolean);
CJSON_PUBLIC(cJSON *) cJSON_CreateNumber(double num);
CJSON_PUBLIC(cJSON *) cJSON_CreateString(const char *string);
/* raw json */
CJSON_PUBLIC(cJSON *) cJSON_CreateRaw(const char *raw);
CJSON_PUBLIC(cJSON *) cJSON_CreateArray(void);
CJSON_PUBLIC(cJSON *) cJSON_CreateObject(void);
/* Create a string where valuestring references a string so
* it will not be freed by cJSON_Delete */
CJSON_PUBLIC(cJSON *) cJSON_CreateStringReference(const char *string);
/* Create an object/array that only references it's elements so
* they will not be freed by cJSON_Delete */
CJSON_PUBLIC(cJSON *) cJSON_CreateObjectReference(const cJSON *child);
CJSON_PUBLIC(cJSON *) cJSON_CreateArrayReference(const cJSON *child);
/* These utilities create an Array of count items.
* The parameter count cannot be greater than the number of elements in the number array, otherwise array access will be out of bounds.*/
CJSON_PUBLIC(cJSON *) cJSON_CreateIntArray(const int *numbers, int count);
CJSON_PUBLIC(cJSON *) cJSON_CreateFloatArray(const float *numbers, int count);
CJSON_PUBLIC(cJSON *) cJSON_CreateDoubleArray(const double *numbers, int count);
CJSON_PUBLIC(cJSON *) cJSON_CreateStringArray(const char *const *strings, int count);
/* Append item to the specified array/object. */
CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToArray(cJSON *array, cJSON *item);
CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObject(cJSON *object, const char *string, cJSON *item);
/* Use this when string is definitely const (i.e. a literal, or as good as), and will definitely survive the cJSON object.
* WARNING: When this function was used, make sure to always check that (item->type & cJSON_StringIsConst) is zero before
* writing to `item->string` */
CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObjectCS(cJSON *object, const char *string, cJSON *item);
/* Append reference to item to the specified array/object. Use this when you want to add an existing cJSON to a new cJSON, but don't want to corrupt your existing cJSON. */
CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item);
CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToObject(cJSON *object, const char *string, cJSON *item);
/* Remove/Detach items from Arrays/Objects. */
CJSON_PUBLIC(cJSON *) cJSON_DetachItemViaPointer(cJSON *parent, cJSON * const item);
CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromArray(cJSON *array, int which);
CJSON_PUBLIC(void) cJSON_DeleteItemFromArray(cJSON *array, int which);
CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObject(cJSON *object, const char *string);
CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObjectCaseSensitive(cJSON *object, const char *string);
CJSON_PUBLIC(void) cJSON_DeleteItemFromObject(cJSON *object, const char *string);
CJSON_PUBLIC(void) cJSON_DeleteItemFromObjectCaseSensitive(cJSON *object, const char *string);
/* Update array items. */
CJSON_PUBLIC(cJSON_bool) cJSON_InsertItemInArray(cJSON *array, int which, cJSON *newitem); /* Shifts pre-existing items to the right. */
CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemViaPointer(cJSON * const parent, cJSON * const item, cJSON * replacement);
CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInArray(cJSON *array, int which, cJSON *newitem);
CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObject(cJSON *object,const char *string,cJSON *newitem);
CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObjectCaseSensitive(cJSON *object,const char *string,cJSON *newitem);
/* Duplicate a cJSON item */
CJSON_PUBLIC(cJSON *) cJSON_Duplicate(const cJSON *item, cJSON_bool recurse);
/* Duplicate will create a new, identical cJSON item to the one you pass, in new memory that will
* need to be released. With recurse!=0, it will duplicate any children connected to the item.
* The item->next and ->prev pointers are always zero on return from Duplicate. */
/* Recursively compare two cJSON items for equality. If either a or b is NULL or invalid, they will be considered unequal.
* case_sensitive determines if object keys are treated case sensitive (1) or case insensitive (0) */
CJSON_PUBLIC(cJSON_bool) cJSON_Compare(const cJSON * const a, const cJSON * const b, const cJSON_bool case_sensitive);
/* Minify a strings, remove blank characters(such as ' ', '\t', '\r', '\n') from strings.
* The input pointer json cannot point to a read-only address area, such as a string constant,
* but should point to a readable and writable adress area. */
CJSON_PUBLIC(void) cJSON_Minify(char *json);
/* Helper functions for creating and adding items to an object at the same time.
* They return the added item or NULL on failure. */
CJSON_PUBLIC(cJSON*) cJSON_AddNullToObject(cJSON * const object, const char * const name);
CJSON_PUBLIC(cJSON*) cJSON_AddTrueToObject(cJSON * const object, const char * const name);
CJSON_PUBLIC(cJSON*) cJSON_AddFalseToObject(cJSON * const object, const char * const name);
CJSON_PUBLIC(cJSON*) cJSON_AddBoolToObject(cJSON * const object, const char * const name, const cJSON_bool boolean);
CJSON_PUBLIC(cJSON*) cJSON_AddNumberToObject(cJSON * const object, const char * const name, const double number);
CJSON_PUBLIC(cJSON*) cJSON_AddStringToObject(cJSON * const object, const char * const name, const char * const string);
CJSON_PUBLIC(cJSON*) cJSON_AddRawToObject(cJSON * const object, const char * const name, const char * const raw);
CJSON_PUBLIC(cJSON*) cJSON_AddObjectToObject(cJSON * const object, const char * const name);
CJSON_PUBLIC(cJSON*) cJSON_AddArrayToObject(cJSON * const object, const char * const name);
/* When assigning an integer value, it needs to be propagated to valuedouble too. */
#define cJSON_SetIntValue(object, number) ((object) ? (object)->valueint = (object)->valuedouble = (number) : (number))
/* helper for the cJSON_SetNumberValue macro */
CJSON_PUBLIC(double) cJSON_SetNumberHelper(cJSON *object, double number);
#define cJSON_SetNumberValue(object, number) ((object != NULL) ? cJSON_SetNumberHelper(object, (double)number) : (number))
/* Change the valuestring of a cJSON_String object, only takes effect when type of object is cJSON_String */
CJSON_PUBLIC(char*) cJSON_SetValuestring(cJSON *object, const char *valuestring);
/* Macro for iterating over an array or object */
#define cJSON_ArrayForEach(element, array) for(element = (array != NULL) ? (array)->child : NULL; element != NULL; element = element->next)
/* malloc/free objects using the malloc/free functions that have been set with cJSON_InitHooks */
CJSON_PUBLIC(void *) cJSON_malloc(size_t size);
CJSON_PUBLIC(void) cJSON_free(void *object);
#ifdef __cplusplus
}
#endif
#endif

994
expr.c Normal file
View File

@ -0,0 +1,994 @@
/* Filtering of objects based on simple expressions.
* This powers the FILTER option of Vector Sets, but it is otherwise
* general code to be used when we want to tell if a given object (with fields)
* passes or fails a given test for scalars, strings, ...
*
* Copyright(C) 2024-2025 Salvatore Sanfilippo. All Rights Reserved.
*/
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <math.h>
#include <string.h>
#include "cJSON.h"
#ifdef TEST_MAIN
#define RedisModule_Alloc malloc
#define RedisModule_Realloc realloc
#define RedisModule_Free free
#define RedisModule_Strdup strdup
#endif
#define EXPR_TOKEN_EOF 0
#define EXPR_TOKEN_NUM 1
#define EXPR_TOKEN_STR 2
#define EXPR_TOKEN_TUPLE 3
#define EXPR_TOKEN_SELECTOR 4
#define EXPR_TOKEN_OP 5
#define EXPR_OP_OPAREN 0 /* ( */
#define EXPR_OP_CPAREN 1 /* ) */
#define EXPR_OP_NOT 2 /* ! */
#define EXPR_OP_POW 3 /* ** */
#define EXPR_OP_MULT 4 /* * */
#define EXPR_OP_DIV 5 /* / */
#define EXPR_OP_MOD 6 /* % */
#define EXPR_OP_SUM 7 /* + */
#define EXPR_OP_DIFF 8 /* - */
#define EXPR_OP_GT 9 /* > */
#define EXPR_OP_GTE 10 /* >= */
#define EXPR_OP_LT 11 /* < */
#define EXPR_OP_LTE 12 /* <= */
#define EXPR_OP_EQ 13 /* == */
#define EXPR_OP_NEQ 14 /* != */
#define EXPR_OP_IN 15 /* in */
#define EXPR_OP_AND 16 /* and */
#define EXPR_OP_OR 17 /* or */
/* This structure represents a token in our expression. It's either
* literals like 4, "foo", or operators like "+", "-", "and", or
* json selectors, that start with a dot: ".age", ".properties.somearray[1]" */
typedef struct exprtoken {
int refcount; // Reference counting for memory reclaiming.
int token_type; // Token type of the just parsed token.
int offset; // Chars offset in expression.
union {
double num; // Value for EXPR_TOKEN_NUM.
struct {
char *start; // String pointer for EXPR_TOKEN_STR / SELECTOR.
size_t len; // String len for EXPR_TOKEN_STR / SELECTOR.
char *heapstr; // True if we have a private allocation for this
// string. When possible, it just references to the
// string expression we compiled, exprstate->expr.
} str;
int opcode; // Opcode ID for EXPR_TOKEN_OP.
struct {
struct exprtoken **ele;
size_t len;
} tuple; // Tuples are like [1, 2, 3] for "in" operator.
};
} exprtoken;
/* Simple stack of expr tokens. This is used both to represent the stack
* of values and the stack of operands during VM execution. */
typedef struct exprstack {
exprtoken **items;
int numitems;
int allocsize;
} exprstack;
typedef struct exprstate {
char *expr; /* Expression string to compile. Note that
* expression token strings point directly to this
* string. */
char *p; // Currnet position inside 'expr', while parsing.
// Virtual machine state.
exprstack values_stack;
exprstack ops_stack; // Operator stack used during compilation.
exprstack tokens; // Expression processed into a sequence of tokens.
exprstack program; // Expression compiled into opcodes and values.
} exprstate;
/* Valid operators. */
struct {
char *opname;
int oplen;
int opcode;
int precedence;
int arity;
} ExprOptable[] = {
{"(", 1, EXPR_OP_OPAREN, 7, 0},
{")", 1, EXPR_OP_CPAREN, 7, 0},
{"!", 1, EXPR_OP_NOT, 6, 1},
{"not", 3, EXPR_OP_NOT, 6, 1},
{"**", 2, EXPR_OP_POW, 5, 2},
{"*", 1, EXPR_OP_MULT, 4, 2},
{"/", 1, EXPR_OP_DIV, 4, 2},
{"%", 1, EXPR_OP_MOD, 4, 2},
{"+", 1, EXPR_OP_SUM, 3, 2},
{"-", 1, EXPR_OP_DIFF, 3, 2},
{">", 1, EXPR_OP_GT, 2, 2},
{">=", 2, EXPR_OP_GTE, 2, 2},
{"<", 1, EXPR_OP_LT, 2, 2},
{"<=", 2, EXPR_OP_LTE, 2, 2},
{"==", 2, EXPR_OP_EQ, 2, 2},
{"!=", 2, EXPR_OP_NEQ, 2, 2},
{"in", 2, EXPR_OP_IN, 2, 2},
{"and", 3, EXPR_OP_AND, 1, 2},
{"&&", 2, EXPR_OP_AND, 1, 2},
{"or", 2, EXPR_OP_OR, 0, 2},
{"||", 2, EXPR_OP_OR, 0, 2},
{NULL, 0, 0, 0, 0} // Terminator.
};
#define EXPR_OP_SPECIALCHARS "+-*%/!()<>=|&"
#define EXPR_SELECTOR_SPECIALCHARS "_-"
/* ================================ Expr token ============================== */
/* Return an heap allocated token of the specified type, setting the
* reference count to 1. */
exprtoken *exprNewToken(int type) {
exprtoken *t = RedisModule_Alloc(sizeof(exprtoken));
memset(t,0,sizeof(*t));
t->token_type = type;
t->refcount = 1;
return t;
}
/* Generic free token function, can be used to free stack allocated
* objects (in this case the pointer itself will not be freed) or
* heap allocated objects. See the wrappers below. */
void exprTokenRelease(exprtoken *t) {
if (t == NULL) return;
if (t->refcount <= 0) {
printf("exprTokenRelease() against a token with refcount %d!\n"
"Aborting program execution\n",
t->refcount);
exit(1);
}
t->refcount--;
if (t->refcount > 0) return;
// We reached refcount 0: free the object.
if (t->token_type == EXPR_TOKEN_STR) {
if (t->str.heapstr != NULL) RedisModule_Free(t->str.heapstr);
} else if (t->token_type == EXPR_TOKEN_TUPLE) {
for (size_t j = 0; j < t->tuple.len; j++)
exprTokenRelease(t->tuple.ele[j]);
if (t->tuple.ele) RedisModule_Free(t->tuple.ele);
}
RedisModule_Free(t);
}
void exprTokenRetain(exprtoken *t) {
t->refcount++;
}
/* ============================== Stack handling ============================ */
#include <stdlib.h>
#include <string.h>
#define EXPR_STACK_INITIAL_SIZE 16
/* Initialize a new expression stack. */
void exprStackInit(exprstack *stack) {
stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
stack->numitems = 0;
stack->allocsize = EXPR_STACK_INITIAL_SIZE;
}
/* Push a token pointer onto the stack. Does not increment the refcount
* of the token: it is up to the caller doing this. */
void exprStackPush(exprstack *stack, exprtoken *token) {
/* Check if we need to grow the stack. */
if (stack->numitems == stack->allocsize) {
size_t newsize = stack->allocsize * 2;
exprtoken **newitems =
RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize);
stack->items = newitems;
stack->allocsize = newsize;
}
stack->items[stack->numitems] = token;
stack->numitems++;
}
/* Pop a token pointer from the stack. Return NULL if the stack is
* empty. Does NOT recrement the refcount of the token, it's up to the
* caller to do so, as the new owner of the reference. */
exprtoken *exprStackPop(exprstack *stack) {
if (stack->numitems == 0) return NULL;
stack->numitems--;
return stack->items[stack->numitems];
}
/* Just return the last element pushed, without consuming it nor altering
* the reference count. */
exprtoken *exprStackPeek(exprstack *stack) {
if (stack->numitems == 0) return NULL;
return stack->items[stack->numitems-1];
}
/* Free the stack structure state, including the items it contains, that are
* assumed to be heap allocated. The passed pointer itself is not freed. */
void exprStackFree(exprstack *stack) {
for (int j = 0; j < stack->numitems; j++)
exprTokenRelease(stack->items[j]);
RedisModule_Free(stack->items);
}
/* Just reset the stack removing all the items, but leaving it in a state
* that makes it still usable for new elements. */
void exprStackReset(exprstack *stack) {
for (int j = 0; j < stack->numitems; j++)
exprTokenRelease(stack->items[j]);
stack->numitems = 0;
}
/* =========================== Expression compilation ======================= */
void exprConsumeSpaces(exprstate *es) {
while(es->p[0] && isspace(es->p[0])) es->p++;
}
/* Parse an operator, trying to match the longer match in the
* operators table. */
exprtoken *exprParseOperator(exprstate *es) {
exprtoken *t = exprNewToken(EXPR_TOKEN_OP);
char *start = es->p;
while(es->p[0] &&
(isalpha(es->p[0]) ||
strchr(EXPR_OP_SPECIALCHARS,es->p[0]) != NULL))
{
es->p++;
}
int matchlen = es->p - start;
int bestlen = 0;
int j;
// Find the longest matching operator.
for (j = 0; ExprOptable[j].opname != NULL; j++) {
if (ExprOptable[j].oplen > matchlen) continue;
if (memcmp(ExprOptable[j].opname, start, ExprOptable[j].oplen) != 0)
{
continue;
}
if (ExprOptable[j].oplen > bestlen) {
t->opcode = ExprOptable[j].opcode;
bestlen = ExprOptable[j].oplen;
}
}
if (bestlen == 0) {
exprTokenRelease(t);
return NULL;
} else {
es->p = start + bestlen;
}
return t;
}
// Valid selector charset.
static int is_selector_char(int c) {
return (isalpha(c) ||
isdigit(c) ||
strchr(EXPR_SELECTOR_SPECIALCHARS,c) != NULL);
}
/* Parse selectors, they start with a dot and can have alphanumerical
* or few special chars. */
exprtoken *exprParseSelector(exprstate *es) {
exprtoken *t = exprNewToken(EXPR_TOKEN_SELECTOR);
es->p++; // Skip dot.
char *start = es->p;
while(es->p[0] && is_selector_char(es->p[0])) es->p++;
int matchlen = es->p - start;
t->str.start = start;
t->str.len = matchlen;
return t;
}
exprtoken *exprParseNumber(exprstate *es) {
exprtoken *t = exprNewToken(EXPR_TOKEN_NUM);
char num[64];
int idx = 0;
while(isdigit(es->p[0]) || es->p[0] == '.' || es->p[0] == 'e' ||
es->p[0] == 'E' || (idx == 0 && es->p[0] == '-'))
{
if (idx >= (int)sizeof(num)-1) {
exprTokenRelease(t);
return NULL;
}
num[idx++] = es->p[0];
es->p++;
}
num[idx] = 0;
char *endptr;
t->num = strtod(num, &endptr);
if (*endptr != '\0') {
exprTokenRelease(t);
return NULL;
}
return t;
}
exprtoken *exprParseString(exprstate *es) {
char quote = es->p[0]; /* Store the quote type (' or "). */
es->p++; /* Skip opening quote. */
exprtoken *t = exprNewToken(EXPR_TOKEN_STR);
t->str.start = es->p;
while(es->p[0] != '\0') {
if (es->p[0] == '\\' && es->p[1] != '\0') {
es->p += 2; // Skip escaped char.
continue;
}
if (es->p[0] == quote) {
t->str.len = es->p - t->str.start;
es->p++; // Skip closing quote.
return t;
}
es->p++;
}
/* If we reach here, string was not terminated. */
exprTokenRelease(t);
return NULL;
}
/* Parse a tuple of the form [1, "foo", 42]. No nested tuples are
* supported. This type is useful mostly to be used with the "IN"
* operator. */
exprtoken *exprParseTuple(exprstate *es) {
exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE);
t->tuple.ele = NULL;
t->tuple.len = 0;
es->p++; /* Skip opening '['. */
size_t allocated = 0;
while(1) {
exprConsumeSpaces(es);
/* Check for empty tuple or end. */
if (es->p[0] == ']') {
es->p++;
break;
}
/* Grow tuple array if needed. */
if (t->tuple.len == allocated) {
size_t newsize = allocated == 0 ? 4 : allocated * 2;
exprtoken **newele = RedisModule_Realloc(t->tuple.ele,
sizeof(exprtoken*) * newsize);
t->tuple.ele = newele;
allocated = newsize;
}
/* Parse tuple element. */
exprtoken *ele = NULL;
if (isdigit(es->p[0]) || es->p[0] == '-') {
ele = exprParseNumber(es);
} else if (es->p[0] == '"' || es->p[0] == '\'') {
ele = exprParseString(es);
} else {
exprTokenRelease(t);
return NULL;
}
/* Error parsing number/string? */
if (ele == NULL) {
exprTokenRelease(t);
return NULL;
}
/* Store element if no error was detected. */
t->tuple.ele[t->tuple.len] = ele;
t->tuple.len++;
/* Check for next element. */
exprConsumeSpaces(es);
if (es->p[0] == ']') {
es->p++;
break;
}
if (es->p[0] != ',') {
exprTokenRelease(t);
return NULL;
}
es->p++; /* Skip comma. */
}
return t;
}
/* Deallocate the object returned by exprCompile(). */
void exprFree(exprstate *es) {
if (es == NULL) return;
/* Free the original expression string. */
if (es->expr) RedisModule_Free(es->expr);
/* Free all stacks. */
exprStackFree(&es->values_stack);
exprStackFree(&es->ops_stack);
exprStackFree(&es->tokens);
exprStackFree(&es->program);
/* Free the state object itself. */
RedisModule_Free(es);
}
/* Split the provided expression into a stack of tokens. Returns
* 0 on success, 1 on error. */
int exprTokenize(exprstate *es, int *errpos) {
/* Main parsing loop. */
while(1) {
exprConsumeSpaces(es);
/* Set a flag to see if we can consider the - part of the
* number, or an operator. */
int minus_is_number = 0; // By default is an operator.
exprtoken *last = exprStackPeek(&es->tokens);
if (last == NULL) {
/* If we are at the start of an expression, the minus is
* considered a number. */
minus_is_number = 1;
} else if (last->token_type == EXPR_TOKEN_OP &&
last->opcode != EXPR_OP_CPAREN)
{
/* Also, if the previous token was an operator, the minus
* is considered a number, unless the previous operator is
* a closing parens. In such case it's like (...) -5, or alike
* and we want to emit an operator. */
minus_is_number = 1;
}
/* Parse based on the current character. */
exprtoken *current = NULL;
if (*es->p == '\0') {
current = exprNewToken(EXPR_TOKEN_EOF);
} else if (isdigit(*es->p) ||
(minus_is_number && *es->p == '-' && isdigit(es->p[1])))
{
current = exprParseNumber(es);
} else if (*es->p == '"' || *es->p == '\'') {
current = exprParseString(es);
} else if (*es->p == '.' && is_selector_char(es->p[1])) {
current = exprParseSelector(es);
} else if (isalpha(*es->p) || strchr(EXPR_OP_SPECIALCHARS, *es->p)) {
current = exprParseOperator(es);
} else if (*es->p == '[') {
current = exprParseTuple(es);
}
if (current == NULL) {
if (errpos) *errpos = es->p - es->expr;
return 1; // Syntax Error.
}
/* Push the current token to tokens stack. */
exprStackPush(&es->tokens, current);
if (current->token_type == EXPR_TOKEN_EOF) break;
}
return 0;
}
/* Helper function to get operator precedence from the operator table. */
int exprGetOpPrecedence(int opcode) {
for (int i = 0; ExprOptable[i].opname != NULL; i++) {
if (ExprOptable[i].opcode == opcode)
return ExprOptable[i].precedence;
}
return -1;
}
/* Helper function to get operator arity from the operator table. */
int exprGetOpArity(int opcode) {
for (int i = 0; ExprOptable[i].opname != NULL; i++) {
if (ExprOptable[i].opcode == opcode)
return ExprOptable[i].arity;
}
return -1;
}
/* Process an operator during compilation. Returns 0 on success, 1 on error.
* This function will retain a reference of the operator 'op' in case it
* is pushed on the operators stack. */
int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *errpos) {
if (op->opcode == EXPR_OP_OPAREN) {
// This is just a marker for us. Do nothing.
exprStackPush(&es->ops_stack, op);
exprTokenRetain(op);
return 0;
}
if (op->opcode == EXPR_OP_CPAREN) {
/* Process operators until we find the matching opening parenthesis. */
while (1) {
exprtoken *top_op = exprStackPop(&es->ops_stack);
if (top_op == NULL) {
if (errpos) *errpos = op->offset;
return 1;
}
if (top_op->opcode == EXPR_OP_OPAREN) {
/* Open parethesis found. Our work finished. */
exprTokenRelease(top_op);
return 0;
}
int arity = exprGetOpArity(top_op->opcode);
if (*stack_items < arity) {
exprTokenRelease(top_op);
if (errpos) *errpos = top_op->offset;
return 1;
}
/* Move the operator on the program stack. */
exprStackPush(&es->program, top_op);
*stack_items = *stack_items - arity + 1;
}
}
int curr_prec = exprGetOpPrecedence(op->opcode);
/* Process operators with higher or equal precedence. */
while (1) {
exprtoken *top_op = exprStackPeek(&es->ops_stack);
if (top_op == NULL || top_op->opcode == EXPR_OP_OPAREN) break;
int top_prec = exprGetOpPrecedence(top_op->opcode);
if (top_prec < curr_prec) break;
/* Special case for **: only pop if precedence is strictly higher
* so that the operator is right associative, that is:
* 2 ** 3 ** 2 is evaluated as 2 ** (3 ** 2) == 512 instead
* of (2 ** 3) ** 2 == 64. */
if (op->opcode == EXPR_OP_POW && top_prec <= curr_prec) break;
/* Pop and add to program. */
top_op = exprStackPop(&es->ops_stack);
int arity = exprGetOpArity(top_op->opcode);
if (*stack_items < arity) {
exprTokenRelease(top_op);
if (errpos) *errpos = top_op->offset;
return 1;
}
/* Move to the program stack. */
exprStackPush(&es->program, top_op);
*stack_items = *stack_items - arity + 1;
}
/* Push current operator. */
exprStackPush(&es->ops_stack, op);
exprTokenRetain(op);
return 0;
}
/* Compile the expression into a set of push-value and exec-operator
* that exprRun() can execute. The function returns an expstate object
* that can be used for execution of the program. On error, NULL
* is returned, and optionally the position of the error into the
* expression is returned by reference. */
exprstate *exprCompile(char *expr, int *errpos) {
/* Initialize expression state. */
exprstate *es = RedisModule_Alloc(sizeof(exprstate));
es->expr = RedisModule_Strdup(expr);
es->p = es->expr;
/* Initialize all stacks. */
exprStackInit(&es->values_stack);
exprStackInit(&es->ops_stack);
exprStackInit(&es->tokens);
exprStackInit(&es->program);
/* Tokenization. */
if (exprTokenize(es, errpos)) {
exprFree(es);
return NULL;
}
/* Compile the expression into a sequence of operations. */
int stack_items = 0; // Track # of items that would be on the stack
// during execution. This way we can detect arity
// issues at compile time.
/* Process each token. */
for (int i = 0; i < es->tokens.numitems; i++) {
exprtoken *token = es->tokens.items[i];
if (token->token_type == EXPR_TOKEN_EOF) break;
/* Handle values (numbers, strings, selectors). */
if (token->token_type == EXPR_TOKEN_NUM ||
token->token_type == EXPR_TOKEN_STR ||
token->token_type == EXPR_TOKEN_TUPLE ||
token->token_type == EXPR_TOKEN_SELECTOR)
{
exprStackPush(&es->program, token);
exprTokenRetain(token);
stack_items++;
continue;
}
/* Handle operators. */
if (token->token_type == EXPR_TOKEN_OP) {
if (exprProcessOperator(es, token, &stack_items, errpos)) {
exprFree(es);
return NULL;
}
continue;
}
}
/* Process remaining operators on the stack. */
while (es->ops_stack.numitems > 0) {
exprtoken *op = exprStackPop(&es->ops_stack);
if (op->opcode == EXPR_OP_OPAREN) {
if (errpos) *errpos = op->offset;
exprTokenRelease(op);
exprFree(es);
return NULL;
}
int arity = exprGetOpArity(op->opcode);
if (stack_items < arity) {
if (errpos) *errpos = op->offset;
exprTokenRelease(op);
exprFree(es);
return NULL;
}
exprStackPush(&es->program, op);
stack_items = stack_items - arity + 1;
}
/* Verify that exactly one value would remain on the stack after
* execution. We could also check that such value is a number, but this
* would make the code more complex without much gains. */
if (stack_items != 1) {
if (errpos) {
/* Point to the last token's offset for error reporting. */
exprtoken *last = es->tokens.items[es->tokens.numitems - 1];
*errpos = last->offset;
}
exprFree(es);
return NULL;
}
return es;
}
/* ============================ Expression execution ======================== */
/* Convert a token to its numeric value. For strings we attempt to parse them
* as numbers, returning 0 if conversion fails. */
double exprTokenToNum(exprtoken *t) {
char buf[128];
if (t->token_type == EXPR_TOKEN_NUM) {
return t->num;
} else if (t->token_type == EXPR_TOKEN_STR && t->str.len < sizeof(buf)) {
memcpy(buf, t->str.start, t->str.len);
buf[t->str.len] = '\0';
char *endptr;
double val = strtod(buf, &endptr);
return *endptr == '\0' ? val : 0;
} else {
return 0;
}
}
/* Conver obejct to true/false (0 or 1) */
double exprTokenToBool(exprtoken *t) {
if (t->token_type == EXPR_TOKEN_NUM) {
return t->num != 0;
} else if (t->token_type == EXPR_TOKEN_STR && t->str.len == 0) {
return 0; // Empty string are false, like in Javascript.
} else {
return 1; // Every non numerical type is true.
}
}
/* Compare two tokens. Returns true if they are equal. */
int exprTokensEqual(exprtoken *a, exprtoken *b) {
// If both are strings, do string comparison.
if (a->token_type == EXPR_TOKEN_STR && b->token_type == EXPR_TOKEN_STR) {
return a->str.len == b->str.len &&
memcmp(a->str.start, b->str.start, a->str.len) == 0;
}
// If both are numbers, do numeric comparison.
if (a->token_type == EXPR_TOKEN_NUM && b->token_type == EXPR_TOKEN_NUM) {
return a->num == b->num;
}
// Mixed types - convert to numbers and compare.
return exprTokenToNum(a) == exprTokenToNum(b);
}
/* Convert a json object to an expression token. There is only
* limited support for JSON arrays: they must be composed of
* just numbers and strings. Returns NULL if the JSON object
* cannot be converted. */
exprtoken *exprJsonToToken(cJSON *js) {
if (cJSON_IsNumber(js)) {
exprtoken *obj = exprNewToken(EXPR_TOKEN_NUM);
obj->num = cJSON_GetNumberValue(js);
return obj;
} else if (cJSON_IsString(js)) {
exprtoken *obj = exprNewToken(EXPR_TOKEN_STR);
char *strval = cJSON_GetStringValue(js);
obj->str.heapstr = RedisModule_Strdup(strval);
obj->str.start = obj->str.heapstr;
obj->str.len = strlen(obj->str.heapstr);
return obj;
} else if (cJSON_IsBool(js)) {
exprtoken *obj = exprNewToken(EXPR_TOKEN_NUM);
obj->num = cJSON_IsTrue(js);
return obj;
} else if (cJSON_IsArray(js)) {
// First, scan the array to ensure it only
// contains strings and numbers. Otherwise the
// expression will evaluate to false.
int array_size = cJSON_GetArraySize(js);
for (int j = 0; j < array_size; j++) {
cJSON *item = cJSON_GetArrayItem(js, j);
if (!cJSON_IsNumber(item) && !cJSON_IsString(item)) return NULL;
}
// Create a tuple token for the array.
exprtoken *obj = exprNewToken(EXPR_TOKEN_TUPLE);
obj->tuple.len = array_size;
obj->tuple.ele = NULL;
if (obj->tuple.len == 0) return obj; // No elements, already ok.
obj->tuple.ele =
RedisModule_Alloc(sizeof(exprtoken*) * obj->tuple.len);
// Convert each array element to a token.
for (size_t j = 0; j < obj->tuple.len; j++) {
cJSON *item = cJSON_GetArrayItem(js, j);
if (cJSON_IsNumber(item)) {
exprtoken *eleToken = exprNewToken(EXPR_TOKEN_NUM);
eleToken->num = cJSON_GetNumberValue(item);
obj->tuple.ele[j] = eleToken;
} else if (cJSON_IsString(item)) {
exprtoken *eleToken = exprNewToken(EXPR_TOKEN_STR);
char *strval = cJSON_GetStringValue(item);
eleToken->str.heapstr = RedisModule_Strdup(strval);
eleToken->str.start = eleToken->str.heapstr;
eleToken->str.len = strlen(eleToken->str.heapstr);
obj->tuple.ele[j] = eleToken;
}
}
return obj;
}
return NULL; // No conversion possible for this type.
}
/* Execute the compiled expression program. Returns 1 if the final stack value
* evaluates to true, 0 otherwise. Also returns 0 if any selector callback
* fails. */
int exprRun(exprstate *es, char *json, size_t json_len) {
exprStackReset(&es->values_stack);
cJSON *parsed_json = NULL;
// Execute each instruction in the program.
for (int i = 0; i < es->program.numitems; i++) {
exprtoken *t = es->program.items[i];
// Handle selectors by calling the callback.
if (t->token_type == EXPR_TOKEN_SELECTOR) {
if (json != NULL) {
cJSON *attrib = NULL;
if (parsed_json == NULL) {
parsed_json = cJSON_ParseWithLength(json,json_len);
// Will be left to NULL if the above fails.
}
if (parsed_json) {
char item_name[128];
if (t->str.len > 0 && t->str.len < sizeof(item_name)) {
memcpy(item_name,t->str.start,t->str.len);
item_name[t->str.len] = 0;
attrib = cJSON_GetObjectItem(parsed_json,item_name);
}
/* Fill the token according to the JSON type stored
* at the attribute. */
if (attrib) {
exprtoken *obj = exprJsonToToken(attrib);
if (obj) {
exprStackPush(&es->values_stack, obj);
continue;
}
}
}
}
// Selector not found or JSON object not convertible to
// expression tokens. Evaluate the expression to false.
if (parsed_json) cJSON_Delete(parsed_json);
return 0;
}
// Push non-operator values directly onto the stack.
if (t->token_type != EXPR_TOKEN_OP) {
exprStackPush(&es->values_stack, t);
exprTokenRetain(t);
continue;
}
// Handle operators.
exprtoken *result = exprNewToken(EXPR_TOKEN_NUM);
// Pop operands - we know we have enough from compile-time checks.
exprtoken *b = exprStackPop(&es->values_stack);
exprtoken *a = NULL;
if (exprGetOpArity(t->opcode) == 2) {
a = exprStackPop(&es->values_stack);
}
switch(t->opcode) {
case EXPR_OP_NOT:
result->num = exprTokenToBool(b) == 0 ? 1 : 0;
break;
case EXPR_OP_POW: {
double base = exprTokenToNum(a);
double exp = exprTokenToNum(b);
result->num = pow(base, exp);
break;
}
case EXPR_OP_MULT:
result->num = exprTokenToNum(a) * exprTokenToNum(b);
break;
case EXPR_OP_DIV:
result->num = exprTokenToNum(a) / exprTokenToNum(b);
break;
case EXPR_OP_MOD: {
double va = exprTokenToNum(a);
double vb = exprTokenToNum(b);
result->num = fmod(va, vb);
break;
}
case EXPR_OP_SUM:
result->num = exprTokenToNum(a) + exprTokenToNum(b);
break;
case EXPR_OP_DIFF:
result->num = exprTokenToNum(a) - exprTokenToNum(b);
break;
case EXPR_OP_GT:
result->num = exprTokenToNum(a) > exprTokenToNum(b) ? 1 : 0;
break;
case EXPR_OP_GTE:
result->num = exprTokenToNum(a) >= exprTokenToNum(b) ? 1 : 0;
break;
case EXPR_OP_LT:
result->num = exprTokenToNum(a) < exprTokenToNum(b) ? 1 : 0;
break;
case EXPR_OP_LTE:
result->num = exprTokenToNum(a) <= exprTokenToNum(b) ? 1 : 0;
break;
case EXPR_OP_EQ:
result->num = exprTokensEqual(a, b) ? 1 : 0;
break;
case EXPR_OP_NEQ:
result->num = !exprTokensEqual(a, b) ? 1 : 0;
break;
case EXPR_OP_IN: {
// For 'in' operator, b must be a tuple.
result->num = 0; // Default to false.
if (b->token_type == EXPR_TOKEN_TUPLE) {
for (size_t j = 0; j < b->tuple.len; j++) {
if (exprTokensEqual(a, b->tuple.ele[j])) {
result->num = 1; // Found a match.
break;
}
}
}
break;
}
case EXPR_OP_AND:
result->num =
exprTokenToBool(a) != 0 && exprTokenToBool(b) != 0 ? 1 : 0;
break;
case EXPR_OP_OR:
result->num =
exprTokenToBool(a) != 0 || exprTokenToBool(b) != 0 ? 1 : 0;
break;
default:
// Do nothing: we don't want runtime errors.
break;
}
// Free operands and push result.
if (a) exprTokenRelease(a);
exprTokenRelease(b);
exprStackPush(&es->values_stack, result);
}
if (parsed_json) cJSON_Delete(parsed_json);
// Get final result from stack.
exprtoken *final = exprStackPop(&es->values_stack);
if (final == NULL) return 0;
// Convert result to boolean.
int retval = exprTokenToBool(final);
exprTokenRelease(final);
return retval;
}
/* ============================ Simple test main ============================ */
#ifdef TEST_MAIN
void exprPrintToken(exprtoken *t) {
switch(t->token_type) {
case EXPR_TOKEN_EOF:
printf("EOF");
break;
case EXPR_TOKEN_NUM:
printf("NUM:%g", t->num);
break;
case EXPR_TOKEN_STR:
printf("STR:\"%.*s\"", (int)t->str.len, t->str.start);
break;
case EXPR_TOKEN_SELECTOR:
printf("SEL:%.*s", (int)t->str.len, t->str.start);
break;
case EXPR_TOKEN_OP:
printf("OP:");
for (int i = 0; ExprOptable[i].opname != NULL; i++) {
if (ExprOptable[i].opcode == t->opcode) {
printf("%s", ExprOptable[i].opname);
break;
}
}
break;
default:
printf("UNKNOWN");
break;
}
}
void exprPrintStack(exprstack *stack, const char *name) {
printf("%s (%d items):", name, stack->numitems);
for (int j = 0; j < stack->numitems; j++) {
printf(" ");
exprPrintToken(stack->items[j]);
}
printf("\n");
}
int main(int argc, char **argv) {
char *testexpr = "(5+2)*3 and .year > 1980 and 'foo' == 'foo'";
char *testjson = "{\"year\": 1984, \"name\": \"The Matrix\"}";
if (argc >= 2) testexpr = argv[1];
if (argc >= 3) testjson = argv[2];
printf("Compiling expression: %s\n", testexpr);
int errpos = 0;
exprstate *es = exprCompile(testexpr,&errpos);
if (es == NULL) {
printf("Compilation failed near \"...%s\"\n", testexpr+errpos);
return 1;
}
exprPrintStack(&es->tokens, "Tokens");
exprPrintStack(&es->program, "Program");
printf("Running against object: %s\n", testjson);
int result = exprRun(es,testjson,strlen(testjson));
printf("Result1: %s\n", result ? "True" : "False");
result = exprRun(es,testjson,strlen(testjson));
printf("Result2: %s\n", result ? "True" : "False");
exprFree(es);
return 0;
}
#endif

113
hnsw.c
View File

@ -70,9 +70,9 @@
* orphaned of one link. */
void (*hfree)(void *p) = free;
void *(*hmalloc)(size_t s) = malloc;
void *(*hrealloc)(void *old, size_t s) = realloc;
static void (*hfree)(void *p) = free;
static void *(*hmalloc)(size_t s) = malloc;
static void *(*hrealloc)(void *old, size_t s) = realloc;
void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t),
void *(*realloc_ptr)(void*, size_t))
@ -618,9 +618,19 @@ void hnsw_add_node(HNSW *index, hnswNode *node) {
}
/* Search the specified layer starting from the specified entry point
* to collect 'ef' nodes that are near to 'query'. */
pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
uint32_t ef, uint32_t layer, uint32_t slot)
* to collect 'ef' nodes that are near to 'query'.
*
* This function implements optional hybrid search, so that each node
* can be accepted or not based on its associated value. In this case
* a callback 'filter_callback' should be passed, together with a maximum
* effort for the search (number of candidates to evaluate), since even
* with a a low "EF" value we risk that there are too few nodes that satisfy
* the provided filter, and we could trigger a full scan. */
pqueue *search_layer_with_filter(
HNSW *index, hnswNode *query, hnswNode *entry_point,
uint32_t ef, uint32_t layer, uint32_t slot,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates)
{
// Mark visited nodes with a never seen epoch.
index->current_epoch[slot]++;
@ -633,27 +643,33 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
return NULL;
}
// Take track of the total effort: only used when filtering via
// a callback to have a bound effort.
uint32_t evaluated_candidates = 1;
// Add entry point.
float dist = hnsw_distance(index, query, entry_point);
pq_push(candidates, entry_point, dist);
pq_push(results, entry_point, dist);
if (filter_callback == NULL ||
filter_callback(entry_point->value, filter_privdata))
{
pq_push(results, entry_point, dist);
}
entry_point->visited_epoch[slot] = index->current_epoch[slot];
// Process candidates.
while (candidates->count > 0) {
// Pop closest element and use its saved distance.
// Max effort. If zero, we keep scanning.
if (filter_callback &&
max_candidates &&
evaluated_candidates >= max_candidates) break;
float cur_dist;
hnswNode *current = pq_pop(candidates, &cur_dist);
evaluated_candidates++;
/* Stop if we can't get better results. Note that this can
* be true only if we already collected 'ef' elements in
* the priority queue. This is why: if we have less than EF
* elements, later in the for loop that checks the neighbors we
* add new elements BOTH in the results and candidates pqueue: this
* means that before accumulating EF elements, the worst candidate
* can be as bad as the worst result, but not worse. */
float furthest = pq_max_distance(results);
if (cur_dist > furthest) break;
if (results->count >= ef && cur_dist > furthest) break;
/* Check neighbors. */
for (uint32_t i = 0; i < current->layers[layer].num_links; i++) {
@ -664,11 +680,29 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
neighbor->visited_epoch[slot] = index->current_epoch[slot];
float neighbor_dist = hnsw_distance(index, query, neighbor);
// Add to results if better than current max or results not full.
furthest = pq_max_distance(results);
if (neighbor_dist < furthest || results->count < ef) {
pq_push(candidates, neighbor, neighbor_dist);
pq_push(results, neighbor, neighbor_dist);
if (filter_callback == NULL) {
/* Original HNSW logic when no filtering:
* Add to results if better than current max or
* results not full. */
if (neighbor_dist < furthest || results->count < ef) {
pq_push(candidates, neighbor, neighbor_dist);
pq_push(results, neighbor, neighbor_dist);
}
} else {
/* With filtering: we add candidates even if doesn't match
* the filter, in order to continue to explore the graph. */
if (neighbor_dist < furthest || candidates->count < ef) {
pq_push(candidates, neighbor, neighbor_dist);
}
/* Add results only if passes filter. */
if (filter_callback(neighbor->value, filter_privdata)) {
if (neighbor_dist < furthest || results->count < ef) {
pq_push(results, neighbor, neighbor_dist);
}
}
}
}
}
@ -677,6 +711,14 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
return results;
}
/* Just a wrapper without hybrid search callback. */
pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
uint32_t ef, uint32_t layer, uint32_t slot)
{
return search_layer_with_filter(index, query, entry_point, ef, layer, slot,
NULL, NULL, 0);
}
/* This function is used in order to initialize a node allocated in the
* function stack with the specified vector. The idea is that we can
* easily use hnsw_distance() from a vector and the HNSW nodes this way:
@ -738,10 +780,21 @@ void hnsw_free_tmp_node(hnswNode *node, const float *vector) {
* arrays must have space for at least 'k' items.
* norm_query should be set to 1 if the query vector is already
* normalized, otherwise, if 0, the function will copy the vector,
* L2-normalize the copy and search using the normalized version. */
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
* L2-normalize the copy and search using the normalized version.
*
* If the filter_privdata callback is passed, only elements passing the
* specified filter (invoked with privdata and the value associated
* to the node as arguments) are returned. In such case, if max_candidates
* is not NULL, it represents the maximum number of nodes to explore, since
* the search may be otherwise unbound if few or no elements pass the
* filter. */
int hnsw_search_with_filter
(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized)
int query_vector_is_normalized,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates)
{
if (!index || !query_vector || !neighbors || k == 0) return -1;
if (!index->enter_point) return 0; // Empty index.
@ -769,7 +822,9 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
}
/* Search bottom layer (the most densely populated) with ef = k */
pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot);
pqueue *results = search_layer_with_filter(
index, &query, curr_ep, k, 0, slot, filter_callback,
filter_privdata, max_candidates);
if (!results) {
hnsw_free_tmp_node(&query, query_vector);
return -1;
@ -789,6 +844,16 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
return found;
}
/* Wrapper to hnsw_search_with_filter() when no filter is needed. */
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized)
{
return hnsw_search_with_filter(index,query_vector,k,neighbors,
distances,slot,query_vector_is_normalized,
NULL,NULL,0);
}
/* Rescan a node and update the wortst neighbor index.
* The followinng two functions are variants of this function to be used
* when links are added or removed: they may do less work than a full scan. */

6
hnsw.h
View File

@ -119,6 +119,12 @@ hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector,
int hnsw_search(HNSW *index, const float *query, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized);
int hnsw_search_with_filter
(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates);
void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec);
void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value));

177
tests/filter_expr.py Normal file
View File

@ -0,0 +1,177 @@
from test import TestCase
class VSIMFilterExpressions(TestCase):
def getname(self):
return "VSIM FILTER expressions basic functionality"
def test(self):
# Create a small set of vectors with different attributes
# Basic vectors for testing - all orthogonal for clear results
vec1 = [1, 0, 0, 0]
vec2 = [0, 1, 0, 0]
vec3 = [0, 0, 1, 0]
vec4 = [0, 0, 0, 1]
vec5 = [0.5, 0.5, 0, 0]
# Add vectors with various attributes
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1], f'{self.test_key}:item:1')
self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1',
'{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}')
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec2], f'{self.test_key}:item:2')
self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2',
'{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}')
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec3], f'{self.test_key}:item:3')
self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:3',
'{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}')
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec4], f'{self.test_key}:item:4')
# Item 4 has no attribute at all
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec5], f'{self.test_key}:item:5')
self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:5',
'invalid json') # Intentionally malformed JSON
# Test 1: Basic equality with numbers
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age == 25')
assert len(result) == 1, "Expected 1 result for age == 25"
assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for age == 25"
# Test 2: Greater than
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age > 25')
assert len(result) == 2, "Expected 2 results for age > 25"
# Test 3: Less than or equal
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age <= 30')
assert len(result) == 2, "Expected 2 results for age <= 30"
# Test 4: String equality
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.name == "Alice"')
assert len(result) == 1, "Expected 1 result for name == Alice"
assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for name == Alice"
# Test 5: String inequality
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.name != "Alice"')
assert len(result) == 2, "Expected 2 results for name != Alice"
# Test 6: Boolean value
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.active')
assert len(result) == 1, "Expected 1 result for .active being true"
# Test 7: Logical AND
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age > 20 and .age < 30')
assert len(result) == 1, "Expected 1 result for 20 < age < 30"
assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for 20 < age < 30"
# Test 8: Logical OR
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age < 30 or .age > 35')
assert len(result) == 1, "Expected 1 result for age < 30 or age > 35"
# Test 9: Logical NOT
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '!(.age == 25)')
assert len(result) == 2, "Expected 2 results for NOT(age == 25)"
# Test 10: The "in" operator with array
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age in [25, 35]')
assert len(result) == 2, "Expected 2 results for age in [25, 35]"
# Test 11: The "in" operator with strings in array
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.name in ["Alice", "David"]')
assert len(result) == 1, "Expected 1 result for name in [Alice, David]"
# Test 12: Arithmetic operations - addition
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age + 10 > 40')
assert len(result) == 1, "Expected 1 result for age + 10 > 40"
# Test 13: Arithmetic operations - multiplication
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age * 2 > 60')
assert len(result) == 1, "Expected 1 result for age * 2 > 60"
# Test 14: Arithmetic operations - division
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age / 5 == 5')
assert len(result) == 1, "Expected 1 result for age / 5 == 5"
# Test 15: Arithmetic operations - modulo
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age % 2 == 0')
assert len(result) == 1, "Expected 1 result for age % 2 == 0"
# Test 16: Power operator
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age ** 2 > 900')
assert len(result) == 1, "Expected 1 result for age^2 > 900"
# Test 17: Missing attribute (should exclude items missing that attribute)
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.missing_field == "value"')
assert len(result) == 0, "Expected 0 results for missing_field == value"
# Test 18: No attribute set at all
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.any_field')
assert f'{self.test_key}:item:4' not in [item.decode() for item in result], "Item with no attribute should be excluded"
# Test 19: Malformed JSON
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.any_field')
assert f'{self.test_key}:item:5' not in [item.decode() for item in result], "Item with malformed JSON should be excluded"
# Test 20: Complex expression combining multiple operators
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '(.age > 20 and .age < 40) and (.city == "Boston" or .city == "New York")')
assert len(result) == 2, "Expected 2 results for the complex expression"
expected_items = [f'{self.test_key}:item:1', f'{self.test_key}:item:2']
assert set([item.decode() for item in result]) == set(expected_items), "Expected item:1 and item:2 for the complex expression"
# Test 21: Parentheses to control operator precedence
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.age > (20 + 10)')
assert len(result) == 1, "Expected 1 result for age > (20 + 10)"
# Test 22: Array access (arrays evaluate to true)
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1],
'FILTER', '.scores')
assert len(result) == 3, "Expected 3 results for .scores (arrays evaluate to true)"

668
tests/filter_int.py Normal file
View File

@ -0,0 +1,668 @@
from test import TestCase, generate_random_vector
import struct
import random
import math
import json
import time
class VSIMFilterAdvanced(TestCase):
def getname(self):
return "VSIM FILTER comprehensive functionality testing"
def estimated_runtime(self):
return 15 # This test might take up to 15 seconds for the large dataset
def setup(self):
super().setup()
self.dim = 32 # Vector dimension
self.count = 5000 # Number of vectors for large tests
self.small_count = 50 # Number of vectors for small/quick tests
# Categories for attributes
self.categories = ["electronics", "furniture", "clothing", "books", "food"]
self.cities = ["New York", "London", "Tokyo", "Paris", "Berlin", "Sydney", "Toronto", "Singapore"]
self.price_ranges = [(10, 50), (50, 200), (200, 1000), (1000, 5000)]
self.years = list(range(2000, 2025))
def create_attributes(self, index):
"""Create realistic attributes for a vector"""
category = random.choice(self.categories)
city = random.choice(self.cities)
min_price, max_price = random.choice(self.price_ranges)
price = round(random.uniform(min_price, max_price), 2)
year = random.choice(self.years)
in_stock = random.random() > 0.3 # 70% chance of being in stock
rating = round(random.uniform(1, 5), 1)
views = int(random.expovariate(1/1000)) # Exponential distribution for page views
tags = random.sample(["popular", "sale", "new", "limited", "exclusive", "clearance"],
k=random.randint(0, 3))
# Add some specific patterns for testing
# Every 10th item has a specific property combination for testing
is_premium = (index % 10 == 0)
# Create attributes dictionary
attrs = {
"id": index,
"category": category,
"location": city,
"price": price,
"year": year,
"in_stock": in_stock,
"rating": rating,
"views": views,
"tags": tags
}
if is_premium:
attrs["is_premium"] = True
attrs["special_features"] = ["premium", "warranty", "support"]
# Add sub-categories for more complex filters
if category == "electronics":
attrs["subcategory"] = random.choice(["phones", "computers", "cameras", "audio"])
elif category == "furniture":
attrs["subcategory"] = random.choice(["chairs", "tables", "sofas", "beds"])
elif category == "clothing":
attrs["subcategory"] = random.choice(["shirts", "pants", "dresses", "shoes"])
# Add some intentionally missing fields for testing
if random.random() > 0.9: # 10% chance of missing price
del attrs["price"]
# Some items have promotion field
if random.random() > 0.7: # 30% chance of having a promotion
attrs["promotion"] = random.choice(["discount", "bundle", "gift"])
# Create invalid JSON for a small percentage of vectors
if random.random() > 0.98: # 2% chance of having invalid JSON
return "{{invalid json}}"
return json.dumps(attrs)
def create_vectors_with_attributes(self, key, count):
"""Create vectors and add attributes to them"""
vectors = []
names = []
attribute_map = {} # To store attributes for verification
# Create vectors
for i in range(count):
vec = generate_random_vector(self.dim)
vectors.append(vec)
name = f"{key}:item:{i}"
names.append(name)
# Add to Redis
vec_bytes = struct.pack(f'{self.dim}f', *vec)
self.redis.execute_command('VADD', key, 'FP32', vec_bytes, name)
# Create and add attributes
attrs = self.create_attributes(i)
self.redis.execute_command('VSETATTR', key, name, attrs)
# Store attributes for later verification
try:
attribute_map[name] = json.loads(attrs) if '{' in attrs else None
except json.JSONDecodeError:
attribute_map[name] = None
return vectors, names, attribute_map
def filter_linear_search(self, vectors, names, query_vector, filter_expr, attribute_map, k=10):
"""Perform a linear search with filtering for verification"""
similarities = []
query_norm = math.sqrt(sum(x*x for x in query_vector))
if query_norm == 0:
return []
for i, vec in enumerate(vectors):
name = names[i]
attributes = attribute_map.get(name)
# Skip if doesn't match filter
if not self.matches_filter(attributes, filter_expr):
continue
vec_norm = math.sqrt(sum(x*x for x in vec))
if vec_norm == 0:
continue
dot_product = sum(a*b for a,b in zip(query_vector, vec))
cosine_sim = dot_product / (query_norm * vec_norm)
distance = 1.0 - cosine_sim
redis_similarity = 1.0 - (distance/2.0)
similarities.append((name, redis_similarity))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:k]
def matches_filter(self, attributes, filter_expr):
"""Filter matching for verification - uses Python eval to handle complex expressions"""
if attributes is None:
return False # No attributes or invalid JSON
# Replace JSON path selectors with Python dictionary access
py_expr = filter_expr
# Handle `.field` notation (replace with attributes['field'])
i = 0
while i < len(py_expr):
if py_expr[i] == '.' and (i == 0 or not py_expr[i-1].isalnum()):
# Find the end of the selector (stops at operators or whitespace)
j = i + 1
while j < len(py_expr) and (py_expr[j].isalnum() or py_expr[j] == '_'):
j += 1
if j > i + 1: # Found a valid selector
field = py_expr[i+1:j]
# Use a safe access pattern that returns a default value based on context
py_expr = py_expr[:i] + f"attributes.get('{field}')" + py_expr[j:]
i = i + len(f"attributes.get('{field}')")
else:
i += 1
else:
i += 1
# Convert not operator if needed
py_expr = py_expr.replace('!', ' not ')
try:
# Custom evaluation that handles exceptions for missing fields
# by returning False for the entire expression
# Split the expression on logical operators
parts = []
for op in [' and ', ' or ']:
if op in py_expr:
parts = py_expr.split(op)
break
if not parts: # No logical operators found
parts = [py_expr]
# Try to evaluate each part - if any part fails,
# the whole expression should fail
try:
result = eval(py_expr, {"attributes": attributes})
return bool(result)
except (TypeError, AttributeError):
# This typically happens when trying to compare None with
# numbers or other types, or when an attribute doesn't exist
return False
except Exception as e:
print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}")
return False
except Exception as e:
print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}")
return False
def safe_decode(self,item):
return item.decode() if isinstance(item, bytes) else item
def calculate_recall(self, redis_results, linear_results, k=10):
"""Calculate recall (percentage of correct results retrieved)"""
redis_set = set(self.safe_decode(item) for item in redis_results)
linear_set = set(item[0] for item in linear_results[:k])
if not linear_set:
return 1.0 # If no linear results, consider it perfect recall
intersection = redis_set.intersection(linear_set)
return len(intersection) / len(linear_set)
def test_recall_with_filter(self, filter_expr, ef=500, filter_ef=None):
"""Test recall for a given filter expression"""
# Create query vector
query_vec = generate_random_vector(self.dim)
# First, get ground truth using linear scan
linear_results = self.filter_linear_search(
self.vectors, self.names, query_vec, filter_expr, self.attribute_map, k=50)
# Calculate true selectivity from ground truth
true_selectivity = len(linear_results) / len(self.names) if self.names else 0
# Perform Redis search with filter
cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
cmd_args.extend([str(x) for x in query_vec])
cmd_args.extend(['COUNT', 50, 'WITHSCORES', 'EF', ef, 'FILTER', filter_expr])
if filter_ef:
cmd_args.extend(['FILTER-EF', filter_ef])
start_time = time.time()
redis_results = self.redis.execute_command(*cmd_args)
query_time = time.time() - start_time
# Convert Redis results to dict
redis_items = {}
for i in range(0, len(redis_results), 2):
key = redis_results[i].decode() if isinstance(redis_results[i], bytes) else redis_results[i]
score = float(redis_results[i+1])
redis_items[key] = score
# Calculate metrics
recall = self.calculate_recall(redis_items.keys(), linear_results)
selectivity = len(redis_items) / len(self.names) if redis_items else 0
# Compare against the true selectivity from linear scan
assert abs(selectivity - true_selectivity) < 0.1, \
f"Redis selectivity {selectivity:.3f} differs significantly from ground truth {true_selectivity:.3f}"
# We expect high recall for standard parameters
if ef >= 500 and (filter_ef is None or filter_ef >= 1000):
try:
assert recall >= 0.7, \
f"Low recall {recall:.2f} for filter '{filter_expr}'"
except AssertionError as e:
# Get items found in each set
redis_items_set = set(redis_items.keys())
linear_items_set = set(item[0] for item in linear_results)
# Find items in each set
only_in_redis = redis_items_set - linear_items_set
only_in_linear = linear_items_set - redis_items_set
in_both = redis_items_set & linear_items_set
# Build comprehensive debug message
debug = f"\nGround Truth: {len(linear_results)} matching items (total vectors: {len(self.vectors)})"
debug += f"\nRedis Found: {len(redis_items)} items with FILTER-EF: {filter_ef or 'default'}"
debug += f"\nItems in both sets: {len(in_both)} (recall: {recall:.4f})"
debug += f"\nItems only in Redis: {len(only_in_redis)}"
debug += f"\nItems only in Ground Truth: {len(only_in_linear)}"
# Show some example items from each set with their scores
if only_in_redis:
debug += "\n\nTOP 5 ITEMS ONLY IN REDIS:"
sorted_redis = sorted([(k, v) for k, v in redis_items.items()], key=lambda x: x[1], reverse=True)
for i, (item, score) in enumerate(sorted_redis[:5]):
if item in only_in_redis:
debug += f"\n {i+1}. {item} (Score: {score:.4f})"
# Show attribute that should match filter
attr = self.attribute_map.get(item)
if attr:
debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}"
if only_in_linear:
debug += "\n\nTOP 5 ITEMS ONLY IN GROUND TRUTH:"
for i, (item, score) in enumerate(linear_results[:5]):
if item in only_in_linear:
debug += f"\n {i+1}. {item} (Score: {score:.4f})"
# Show attribute that should match filter
attr = self.attribute_map.get(item)
if attr:
debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}"
# Help identify parsing issues
debug += "\n\nPARSING CHECK:"
debug += f"\nRedis command: VSIM {self.test_key} VALUES {self.dim} [...] FILTER '{filter_expr}'"
# Check for WITHSCORES handling issues
if len(redis_results) > 0 and len(redis_results) % 2 == 0:
debug += f"\nRedis returned {len(redis_results)} items (looks like item,score pairs)"
debug += f"\nFirst few results: {redis_results[:4]}"
# Check the filter implementation
debug += "\n\nFILTER IMPLEMENTATION CHECK:"
debug += f"\nFilter expression: '{filter_expr}'"
debug += "\nSample attribute matches from attribute_map:"
count_matching = 0
for i, (name, attrs) in enumerate(self.attribute_map.items()):
if attrs and self.matches_filter(attrs, filter_expr):
count_matching += 1
if i < 3: # Show first 3 matches
debug += f"\n - {name}: {attrs}"
debug += f"\nTotal items matching filter in attribute_map: {count_matching}"
# Check if results array handling could be wrong
debug += "\n\nRESULT ARRAYS CHECK:"
if len(linear_results) >= 1:
debug += f"\nlinear_results[0]: {linear_results[0]}"
if isinstance(linear_results[0], tuple) and len(linear_results[0]) == 2:
debug += " (correct tuple format: (name, score))"
else:
debug += " (UNEXPECTED FORMAT!)"
# Debug sort order
debug += "\n\nSORTING CHECK:"
if len(linear_results) >= 2:
debug += f"\nGround truth first item score: {linear_results[0][1]}"
debug += f"\nGround truth second item score: {linear_results[1][1]}"
debug += f"\nCorrectly sorted by similarity? {linear_results[0][1] >= linear_results[1][1]}"
# Re-raise with detailed information
raise AssertionError(str(e) + debug)
return recall, selectivity, query_time, len(redis_items)
def test(self):
print(f"\nRunning comprehensive VSIM FILTER tests...")
# Create a larger dataset for testing
print(f"Creating dataset with {self.count} vectors and attributes...")
self.vectors, self.names, self.attribute_map = self.create_vectors_with_attributes(
self.test_key, self.count)
# ==== 1. Recall and Precision Testing ====
print("Testing recall for various filters...")
# Test basic filters with different selectivity
results = {}
results["category"] = self.test_recall_with_filter('.category == "electronics"')
results["price_high"] = self.test_recall_with_filter('.price > 1000')
results["in_stock"] = self.test_recall_with_filter('.in_stock')
results["rating"] = self.test_recall_with_filter('.rating >= 4')
results["complex1"] = self.test_recall_with_filter('.category == "electronics" and .price < 500')
print("Filter | Recall | Selectivity | Time (ms) | Results")
print("----------------------------------------------------")
for name, (recall, selectivity, time_ms, count) in results.items():
print(f"{name:7} | {recall:.3f} | {selectivity:.3f} | {time_ms*1000:.1f} | {count}")
# ==== 2. Filter Selectivity Performance ====
print("\nTesting filter selectivity performance...")
# High selectivity (very few matches)
high_sel_recall, _, high_sel_time, _ = self.test_recall_with_filter('.is_premium')
# Medium selectivity
med_sel_recall, _, med_sel_time, _ = self.test_recall_with_filter('.price > 100 and .price < 1000')
# Low selectivity (many matches)
low_sel_recall, _, low_sel_time, _ = self.test_recall_with_filter('.year > 2000')
print(f"High selectivity recall: {high_sel_recall:.3f}, time: {high_sel_time*1000:.1f}ms")
print(f"Med selectivity recall: {med_sel_recall:.3f}, time: {med_sel_time*1000:.1f}ms")
print(f"Low selectivity recall: {low_sel_recall:.3f}, time: {low_sel_time*1000:.1f}ms")
# ==== 3. FILTER-EF Parameter Testing ====
print("\nTesting FILTER-EF parameter...")
# Test with different FILTER-EF values
filter_expr = '.category == "electronics" and .price > 200'
ef_values = [100, 500, 2000, 5000]
print("FILTER-EF | Recall | Time (ms)")
print("-----------------------------")
for filter_ef in ef_values:
recall, _, query_time, _ = self.test_recall_with_filter(
filter_expr, ef=500, filter_ef=filter_ef)
print(f"{filter_ef:9} | {recall:.3f} | {query_time*1000:.1f}")
# Assert that higher FILTER-EF generally gives better recall
low_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=100)
high_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=5000)
# This might not always be true due to randomness, but generally holds
# We use a softer assertion to avoid flaky tests
assert high_ef_recall >= low_ef_recall * 0.8, \
f"Higher FILTER-EF should generally give better recall: {high_ef_recall:.3f} vs {low_ef_recall:.3f}"
# ==== 4. Complex Filter Expressions ====
print("\nTesting complex filter expressions...")
# Test a variety of complex expressions
complex_filters = [
'.price > 100 and (.category == "electronics" or .category == "furniture")',
'(.rating > 4 and .in_stock) or (.price < 50 and .views > 1000)',
'.category in ["electronics", "clothing"] and .price > 200 and .rating >= 3',
'(.category == "electronics" and .subcategory == "phones") or (.category == "furniture" and .price > 1000)',
'.year > 2010 and !(.price < 100) and .in_stock'
]
print("Expression | Results | Time (ms)")
print("-----------------------------")
for i, expr in enumerate(complex_filters):
try:
_, _, query_time, result_count = self.test_recall_with_filter(expr)
print(f"Complex {i+1} | {result_count:7} | {query_time*1000:.1f}")
except Exception as e:
print(f"Complex {i+1} | Error: {str(e)}")
# ==== 5. Attribute Type Testing ====
print("\nTesting different attribute types...")
type_filters = [
('.price > 500', "Numeric"),
('.category == "books"', "String equality"),
('.in_stock', "Boolean"),
('.tags in ["sale", "new"]', "Array membership"),
('.rating * 2 > 8', "Arithmetic")
]
for expr, type_name in type_filters:
try:
_, _, query_time, result_count = self.test_recall_with_filter(expr)
print(f"{type_name:16} | {expr:30} | {result_count:5} results | {query_time*1000:.1f}ms")
except Exception as e:
print(f"{type_name:16} | {expr:30} | Error: {str(e)}")
# ==== 6. Filter + Count Interaction ====
print("\nTesting COUNT parameter with filters...")
filter_expr = '.category == "electronics"'
counts = [5, 20, 100]
for count in counts:
query_vec = generate_random_vector(self.dim)
cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
cmd_args.extend([str(x) for x in query_vec])
cmd_args.extend(['COUNT', count, 'WITHSCORES', 'FILTER', filter_expr])
results = self.redis.execute_command(*cmd_args)
result_count = len(results) // 2 # Divide by 2 because WITHSCORES returns pairs
# We expect result count to be at most the requested count
assert result_count <= count, f"Got {result_count} results with COUNT {count}"
print(f"COUNT {count:3} | Got {result_count:3} results")
# ==== 7. Edge Cases ====
print("\nTesting edge cases...")
# Test with no matching items
no_match_expr = '.category == "nonexistent_category"'
results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim,
*[str(x) for x in generate_random_vector(self.dim)],
'FILTER', no_match_expr)
assert len(results) == 0, f"Expected 0 results for non-matching filter, got {len(results)}"
print(f"No matching items: {len(results)} results (expected 0)")
# Test with invalid filter syntax
try:
self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim,
*[str(x) for x in generate_random_vector(self.dim)],
'FILTER', '.category === "books"') # Triple equals is invalid
assert False, "Expected error for invalid filter syntax"
except:
print("Invalid filter syntax correctly raised an error")
# Test with extremely long complex expression
long_expr = ' and '.join([f'.rating > {i/10}' for i in range(10)])
try:
results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim,
*[str(x) for x in generate_random_vector(self.dim)],
'FILTER', long_expr)
print(f"Long expression: {len(results)} results")
except Exception as e:
print(f"Long expression error: {str(e)}")
print("\nComprehensive VSIM FILTER tests completed successfully")
class VSIMFilterSelectivityTest(TestCase):
def getname(self):
return "VSIM FILTER selectivity performance benchmark"
def estimated_runtime(self):
return 8 # This test might take up to 8 seconds
def setup(self):
super().setup()
self.dim = 32
self.count = 10000
self.test_key = f"{self.test_key}:selectivity" # Use a different key
def create_vector_with_age_attribute(self, name, age):
"""Create a vector with a specific age attribute"""
vec = generate_random_vector(self.dim)
vec_bytes = struct.pack(f'{self.dim}f', *vec)
self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name)
self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps({"age": age}))
def test(self):
print("\nRunning VSIM FILTER selectivity benchmark...")
# Create a dataset where we control the exact selectivity
print(f"Creating controlled dataset with {self.count} vectors...")
# Create vectors with age attributes from 1 to 100
for i in range(self.count):
age = (i % 100) + 1 # Ages from 1 to 100
name = f"{self.test_key}:item:{i}"
self.create_vector_with_age_attribute(name, age)
# Create a query vector
query_vec = generate_random_vector(self.dim)
# Test filters with different selectivities
selectivities = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.99]
results = []
print("\nSelectivity | Filter | Results | Time (ms)")
print("--------------------------------------------------")
for target_selectivity in selectivities:
# Calculate age threshold for desired selectivity
# For example, age <= 10 gives 10% selectivity
age_threshold = int(target_selectivity * 100)
filter_expr = f'.age <= {age_threshold}'
# Run query and measure time
start_time = time.time()
cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
cmd_args.extend([str(x) for x in query_vec])
cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr])
results = self.redis.execute_command(*cmd_args)
query_time = time.time() - start_time
actual_selectivity = len(results) / min(100, int(target_selectivity * self.count))
print(f"{target_selectivity:.2f} | {filter_expr:15} | {len(results):7} | {query_time*1000:.1f}")
# Add assertion to ensure reasonable performance for different selectivities
# For very selective queries (1%), we might need more exploration
if target_selectivity <= 0.05:
# For very selective queries, ensure we can find some results
assert len(results) > 0, f"No results found for {filter_expr}"
else:
# For less selective queries, performance should be reasonable
assert query_time < 1.0, f"Query too slow: {query_time:.3f}s for {filter_expr}"
print("\nSelectivity benchmark completed successfully")
class VSIMFilterComparisonTest(TestCase):
def getname(self):
return "VSIM FILTER EF parameter comparison"
def estimated_runtime(self):
return 8 # This test might take up to 8 seconds
def setup(self):
super().setup()
self.dim = 32
self.count = 5000
self.test_key = f"{self.test_key}:efparams" # Use a different key
def create_dataset(self):
"""Create a dataset with specific attribute patterns for testing FILTER-EF"""
vectors = []
names = []
# Create vectors with category and quality score attributes
for i in range(self.count):
vec = generate_random_vector(self.dim)
name = f"{self.test_key}:item:{i}"
# Add vector to Redis
vec_bytes = struct.pack(f'{self.dim}f', *vec)
self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name)
# Create attributes - we want a very selective filter
# Only 2% of items have category=premium AND quality>90
category = "premium" if random.random() < 0.1 else random.choice(["standard", "economy", "basic"])
quality = random.randint(1, 100)
attrs = {
"id": i,
"category": category,
"quality": quality
}
self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs))
vectors.append(vec)
names.append(name)
return vectors, names
def test(self):
print("\nRunning VSIM FILTER-EF parameter comparison...")
# Create dataset
vectors, names = self.create_dataset()
# Create a selective filter that matches ~2% of items
filter_expr = '.category == "premium" and .quality > 90'
# Create query vector
query_vec = generate_random_vector(self.dim)
# Test different FILTER-EF values
ef_values = [50, 100, 500, 1000, 5000]
results = []
print("\nFILTER-EF | Results | Time (ms) | Notes")
print("---------------------------------------")
baseline_count = None
for ef in ef_values:
# Run query and measure time
start_time = time.time()
cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim]
cmd_args.extend([str(x) for x in query_vec])
cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr, 'FILTER-EF', ef])
query_results = self.redis.execute_command(*cmd_args)
query_time = time.time() - start_time
# Set baseline for comparison
if baseline_count is None:
baseline_count = len(query_results)
recall_rate = len(query_results) / max(1, baseline_count) if baseline_count > 0 else 1.0
notes = ""
if ef == 5000:
notes = "Baseline"
elif recall_rate < 0.5:
notes = "Low recall!"
print(f"{ef:9} | {len(query_results):7} | {query_time*1000:.1f} | {notes}")
results.append((ef, len(query_results), query_time))
# If we have enough results at highest EF, check that recall improves with higher EF
if results[-1][1] >= 5: # At least 5 results for highest EF
# Extract result counts
result_counts = [r[1] for r in results]
# The last result (highest EF) should typically find more results than the first (lowest EF)
# but we use a soft assertion to avoid flaky tests
assert result_counts[-1] >= result_counts[0], \
f"Higher FILTER-EF should find at least as many results: {result_counts[-1]} vs {result_counts[0]}"
print("\nFILTER-EF parameter comparison completed successfully")

300
vset.c
View File

@ -20,6 +20,10 @@
#include <pthread.h>
#include "hnsw.h"
// We inline directly the expression implementation here so that building
// the module is trivial.
#include "expr.c"
static RedisModuleType *VectorSetType;
static uint64_t VectorSetTypeNextId = 0;
@ -48,6 +52,15 @@ struct vsetObject {
pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely.
uint64_t id; // Unique ID used by threaded VADD to know the
// object is still the same.
uint64_t numattribs; // Number of nodes associated with an attribute.
};
/* Each node has two associated values: the associated string (the item
* in the set) and potentially a JSON string, that is, the attributes, used
* for hybrid search with the VSIM FILTER option. */
struct vsetNodeVal {
RedisModuleString *item;
RedisModuleString *attrib;
};
/* Create a random projection matrix for dimensionality reduction.
@ -108,13 +121,16 @@ struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type)
o->proj_matrix = NULL;
o->proj_input_size = 0;
o->numattribs = 0;
pthread_rwlock_init(&o->in_use_lock,NULL);
return o;
}
void vectorSetReleaseNodeValue(void *v) {
RedisModule_FreeString(NULL,v);
struct vsetNodeVal *nv = v;
RedisModule_FreeString(NULL,nv->item);
if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib);
RedisModule_Free(nv);
}
/* Free the vector set object. */
@ -142,24 +158,60 @@ const char *vectorSetGetQuantName(struct vsetObject *o) {
*
* Returns 1 if the element was added, or 0 if the element was already there
* and was just updated. */
int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, int update, int ef)
int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef)
{
hnswNode *node = RedisModule_DictGet(o->dict,val,NULL);
if (node != NULL) {
if (update) {
void *old_val = node->value;
struct vsetNodeVal *nv = node->value;
/* Pass NULL as value-free function. We want to reuse
* the old value. */
hnsw_delete_node(o->hnsw, node, NULL);
node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,old_val,ef);
node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef);
RedisModule_DictReplace(o->dict,val,node);
/* If attrib != NULL, the user wants that in case of an update we
* update the attribute as well (otherwise it reamins as it was).
* Note that the order of operations is conceinved so that it
* works in case the old attrib and the new attrib pointer is the
* same. */
if (attrib) {
// Empty attribute string means: unset the attribute during
// the update.
size_t attrlen;
RedisModule_StringPtrLen(attrib,&attrlen);
if (attrlen != 0) {
RedisModule_RetainString(NULL,attrib);
o->numattribs++;
} else {
attrib = NULL;
}
if (nv->attrib) {
o->numattribs--;
RedisModule_FreeString(NULL,nv->attrib);
}
nv->attrib = attrib;
}
}
return 0;
}
node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,val,ef);
if (!node) return 0;
struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
nv->item = val;
nv->attrib = attrib;
node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef);
if (node == NULL) {
// XXX Technically in Redis-land we don't have out of memories as we
// crash. However the HNSW library may fail for error in the locking
// libc call. There is understand if this may actually happen or not.
RedisModule_Free(nv);
return 0;
}
if (attrib != NULL) o->numattribs++;
RedisModule_DictSet(o->dict,val,node);
RedisModule_RetainString(NULL,val);
if (attrib) RedisModule_RetainString(NULL,attrib);
return 1;
}
@ -243,11 +295,10 @@ void *VADD_thread(void *arg) {
RedisModuleBlockedClient *bc = targ[0];
struct vsetObject *vset = targ[1];
float *vec = targ[3];
RedisModuleString *val = targ[4];
int ef = (uint64_t)targ[6];
/* Look for candidates... */
InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, val, ef);
InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, NULL, ef);
targ[5] = ic; // Pass the context to the reply callback.
/* Unblock the client so that our read reply will be invoked. */
@ -268,6 +319,7 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModuleString *val = targ[4];
InsertContext *ic = targ[5];
int ef = (uint64_t)targ[6];
RedisModuleString *attrib = targ[7];
RedisModule_Free(targ);
/* Open the key: there are no guarantees it still exists, or contains
@ -300,12 +352,22 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
/* Otherwise try to insert the new element with the neighbors
* collected in background. If we fail, do it synchronously again
* from scratch. */
// First: allocate the dual-ported value for the node.
struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
nv->item = val;
nv->attrib = attrib;
// Then: insert the node in the HNSW data structure.
hnswNode *newnode;
if ((newnode = hnsw_try_commit_insert(vset->hnsw, ic)) == NULL) {
newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, val, ef);
newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef);
} else {
newnode->value = nv;
}
RedisModule_DictSet(vset->dict,val,newnode);
val = NULL; // Don't free it later.
attrib = NULL; // Dont' free it later.
RedisModule_ReplicateVerbatim(ctx);
}
@ -313,6 +375,7 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
// Whatever happens is a success... :D
RedisModule_ReplyWithLongLong(ctx,1);
if (val) RedisModule_FreeString(ctx,val); // Not added? Free it.
if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it.
RedisModule_Free(vec);
return retval;
}
@ -330,6 +393,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
int cas = 0; // Threaded check-and-set style insert.
long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes.
float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args);
RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB.
if (!vec)
return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification");
@ -350,7 +414,10 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_Free(vec);
return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
}
j++; // skip EF argument.
j++; // skip argument.
} else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) {
attrib = argv[j+1];
j++; // skip argument.
} else if (!strcasecmp(opt, "NOQUANT")) {
quant_type = HNSW_QUANT_NONE;
} else if (!strcasecmp(opt, "BIN")) {
@ -464,8 +531,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
if (!cas) {
/* Insert vector synchronously. */
int added = vectorSetInsert(vset,vec,NULL,0,val,1,ef);
if (added) RedisModule_RetainString(ctx,val);
int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef);
RedisModule_Free(vec);
RedisModule_ReplyWithLongLong(ctx,added);
@ -478,7 +544,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0);
pthread_t tid;
void **targ = RedisModule_Alloc(sizeof(void*)*7);
void **targ = RedisModule_Alloc(sizeof(void*)*8);
targ[0] = bc;
targ[1] = vset;
targ[2] = (void*)(unsigned long)vset->id;
@ -486,7 +552,9 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
targ[4] = val;
targ[5] = NULL; // Used later for insertion context.
targ[6] = (void*)(unsigned long)ef;
targ[7] = attrib;
RedisModule_RetainString(ctx,val);
if (attrib) RedisModule_RetainString(ctx,attrib);
if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) {
pthread_rwlock_unlock(&vset->in_use_lock);
RedisModule_AbortBlock(bc);
@ -499,6 +567,17 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
}
}
/* HNSW callback to filter items according to a predicate function
* (our FILTER expression in this case). */
int vectorSetFilterCallback(void *value, void *privdata) {
exprstate *expr = privdata;
struct vsetNodeVal *nv = value;
if (nv->attrib == NULL) return 0; // No attributes? No match.
size_t json_len;
char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len);
return exprRun(expr,json,json_len);
}
/* Common path for the execution of the VSIM command both threaded and
* not threaded. Note that 'ctx' may be normal context of a thread safe
* context obtained from a blocked client. The locking that is specific
@ -506,7 +585,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
* handles the HNSW locking explicitly. */
void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
float *vec, unsigned long count, float epsilon, unsigned long withscores,
unsigned long ef)
unsigned long ef, exprstate *filter_expr, unsigned long filter_ef)
{
/* In our scan, we can't just collect 'count' elements as
* if count is small we would explore the graph in an insufficient
@ -523,7 +602,12 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef);
float *distances = RedisModule_Alloc(sizeof(float)*ef);
int slot = hnsw_acquire_read_slot(vset->hnsw);
unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
unsigned int found;
if (filter_expr == NULL) {
found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
} else {
found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, distances, slot, 0, vectorSetFilterCallback, filter_expr, filter_ef);
}
hnsw_release_read_slot(vset->hnsw,slot);
RedisModule_Free(vec);
@ -536,7 +620,8 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
for (unsigned int i = 0; i < found && i < count; i++) {
if (distances[i] > epsilon) break;
RedisModule_ReplyWithString(ctx, neighbors[i]->value);
struct vsetNodeVal *nv = neighbors[i]->value;
RedisModule_ReplyWithString(ctx, nv->item);
arraylen++;
if (withscores) {
/* The similarity score is provided in a 0-1 range. */
@ -551,6 +636,7 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
RedisModule_Free(neighbors);
RedisModule_Free(distances);
if (filter_expr) exprFree(filter_expr);
}
/* VSIM thread handling the blocked client request. */
@ -566,6 +652,8 @@ void *VSIM_thread(void *arg) {
float epsilon = *((float*)targ[4]);
unsigned long withscores = (unsigned long)targ[5];
unsigned long ef = (unsigned long)targ[6];
exprstate *filter_expr = targ[7];
unsigned long filter_ef = (unsigned long)targ[8];
RedisModule_Free(targ[4]);
RedisModule_Free(targ);
@ -573,7 +661,7 @@ void *VSIM_thread(void *arg) {
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc);
// Run the query.
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef);
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
// Cleanup.
RedisModule_FreeThreadSafeContext(ctx);
@ -582,7 +670,7 @@ void *VSIM_thread(void *arg) {
return NULL;
}
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */
int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx);
@ -596,6 +684,10 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
long long ef = 0; /* Exploration factor (see HNSW paper) */
double epsilon = 2.0; /* Max cosine distance */
/* Things computed later. */
long long filter_ef = 0;
exprstate *filter_expr = NULL;
/* Get key and vector type */
RedisModuleString *key = argv[1];
const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL);
@ -704,6 +796,28 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
}
j += 2;
} else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) {
if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) !=
REDISMODULE_OK || filter_ef <= 0)
{
RedisModule_Free(vec);
return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF");
}
j += 2;
} else if (!strcasecmp(opt, "FILTER") && j+1 < argc) {
RedisModuleString *exprarg = argv[j+1];
size_t exprlen;
char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen);
int errpos;
filter_expr = exprCompile(exprstr,&errpos);
if (filter_expr == NULL) {
if ((size_t)errpos >= exprlen) errpos = 0;
RedisModule_Free(vec);
return RedisModule_ReplyWithErrorFormat(ctx,
"ERR syntax error in FILTER expression near: %s",
exprstr+errpos);
}
j += 2;
} else {
RedisModule_Free(vec);
return RedisModule_ReplyWithError(ctx,
@ -712,6 +826,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
}
int threaded_request = 1; // Run on a thread, by default.
if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes.
// Disable threaded for MULTI/EXEC and Lua.
if (RedisModule_GetContextFlags(ctx) &
@ -737,7 +852,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0);
pthread_t tid;
void **targ = RedisModule_Alloc(sizeof(void*)*7);
void **targ = RedisModule_Alloc(sizeof(void*)*9);
targ[0] = bc;
targ[1] = vset;
targ[2] = vec;
@ -746,16 +861,18 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
*((float*)targ[4]) = epsilon;
targ[5] = (void*)(unsigned long)withscores;
targ[6] = (void*)(unsigned long)ef;
targ[7] = (void*)filter_expr;
targ[8] = (void*)(unsigned long)filter_ef;
if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) {
pthread_rwlock_unlock(&vset->in_use_lock);
RedisModule_AbortBlock(bc);
RedisModule_Free(vec);
RedisModule_Free(targ[4]);
RedisModule_Free(targ);
return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread");
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
}
} else {
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef);
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
}
return REDISMODULE_OK;
@ -839,6 +956,8 @@ int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
/* Remove from HNSW graph using the high-level API that handles
* locking and cleanup. We pass RedisModule_FreeString as the value
* free function since the strings were retained at insertion time. */
struct vsetNodeVal *nv = node->value;
if (nv->attrib != NULL) vset->numattribs--;
hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue);
/* Destroy empty vector set. */
@ -917,6 +1036,92 @@ int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
return REDISMODULE_OK;
}
/* VSETATTR key element json
* Set or remove the JSON attribute associated with an element.
* Setting an empty string removes the attribute.
* The command returns one if the attribute was actually updated or
* zero if there is no key or element. */
int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx);
if (argc != 4) return RedisModule_WrongArity(ctx);
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1],
REDISMODULE_READ|REDISMODULE_WRITE);
int type = RedisModule_KeyType(key);
if (type == REDISMODULE_KEYTYPE_EMPTY)
return RedisModule_ReplyWithLongLong(ctx, 0);
if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL);
if (!node)
return RedisModule_ReplyWithLongLong(ctx, 0);
struct vsetNodeVal *nv = node->value;
RedisModuleString *new_attr = argv[3];
/* Set or delete the attribute based on the fact it's an empty
* string or not. */
size_t attrlen;
RedisModule_StringPtrLen(new_attr, &attrlen);
if (attrlen == 0) {
// If we had an attribute before, decrease the count and free it.
if (nv->attrib) {
vset->numattribs--;
RedisModule_FreeString(NULL, nv->attrib);
nv->attrib = NULL;
}
} else {
// If we didn't have an attribute before, increase the count.
// Otherwise free the old one.
if (nv->attrib) {
RedisModule_FreeString(NULL, nv->attrib);
} else {
vset->numattribs++;
}
// Set new attribute.
RedisModule_RetainString(NULL, new_attr);
nv->attrib = new_attr;
}
RedisModule_ReplyWithLongLong(ctx, 1);
RedisModule_ReplicateVerbatim(ctx);
return REDISMODULE_OK;
}
/* VGETATTR key element
* Get the JSON attribute associated with an element.
* Returns NIL if the element has no attribute or doesn't exist. */
int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx);
if (argc != 3) return RedisModule_WrongArity(ctx);
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
int type = RedisModule_KeyType(key);
if (type == REDISMODULE_KEYTYPE_EMPTY)
return RedisModule_ReplyWithNull(ctx);
if (RedisModule_ModuleTypeGetType(key) != VectorSetType)
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL);
if (!node)
return RedisModule_ReplyWithNull(ctx);
struct vsetNodeVal *nv = node->value;
if (!nv->attrib)
return RedisModule_ReplyWithNull(ctx);
return RedisModule_ReplyWithString(ctx, nv->attrib);
}
/* ============================== Reflection ================================ */
/* VLINKS key element [WITHSCORES]
@ -971,7 +1176,8 @@ int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
/* Add each neighbor's element value to the array. */
for (uint32_t j = 0; j < node->layers[i].num_links; j++) {
RedisModule_ReplyWithString(ctx, node->layers[i].links[j]->value);
struct vsetNodeVal *nv = node->layers[i].links[j]->value;
RedisModule_ReplyWithString(ctx, nv->item);
if (withscores) {
float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]);
/* Convert distance to similarity score to match
@ -1035,6 +1241,9 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
/* ============================== vset type methods ========================= */
#define SAVE_FLAG_HAS_PROJMATRIX (1<<0)
#define SAVE_FLAG_HAS_ATTRIBS (1<<1)
/* Save object to RDB */
void VectorSetRdbSave(RedisModuleIO *rdb, void *value) {
struct vsetObject *vset = value;
@ -1042,9 +1251,13 @@ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) {
RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count);
RedisModule_SaveUnsigned(rdb, vset->hnsw->quant_type);
uint32_t save_flags = 0;
if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX;
if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS;
RedisModule_SaveUnsigned(rdb, save_flags);
/* Save projection matrix if present */
if (vset->proj_matrix) {
RedisModule_SaveUnsigned(rdb, 1); // has projection
uint32_t input_dim = vset->proj_input_size;
uint32_t output_dim = vset->hnsw->vector_dim;
RedisModule_SaveUnsigned(rdb, input_dim);
@ -1054,13 +1267,18 @@ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) {
// Save projection matrix as binary blob
size_t matrix_size = sizeof(float) * input_dim * output_dim;
RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size);
} else {
RedisModule_SaveUnsigned(rdb, 0); // no projection
}
hnswNode *node = vset->hnsw->head;
while(node) {
RedisModule_SaveString(rdb, node->value);
struct vsetNodeVal *nv = node->value;
RedisModule_SaveString(rdb, nv->item);
if (vset->numattribs) {
if (nv->attrib)
RedisModule_SaveString(rdb, nv->attrib);
else
RedisModule_SaveStringBuffer(rdb, "", 0);
}
hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node);
RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size);
RedisModule_SaveUnsigned(rdb, sn->params_count);
@ -1083,7 +1301,9 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
if (!vset) return NULL;
/* Load projection matrix if present */
uint32_t has_projection = RedisModule_LoadUnsigned(rdb);
uint32_t save_flags = RedisModule_LoadUnsigned(rdb);
int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX;
int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS;
if (has_projection) {
uint32_t input_dim = RedisModule_LoadUnsigned(rdb);
uint32_t output_dim = dim;
@ -1105,6 +1325,16 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
while(elements--) {
// Load associated string element.
RedisModuleString *ele = RedisModule_LoadString(rdb);
RedisModuleString *attrib = NULL;
if (has_attribs) {
attrib = RedisModule_LoadString(rdb);
size_t attrlen;
RedisModule_StringPtrLen(attrib,&attrlen);
if (attrlen == 0) {
RedisModule_FreeString(NULL,attrib);
attrib = NULL;
}
}
size_t vector_len;
void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len);
uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw);
@ -1120,12 +1350,16 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
for (uint32_t j = 0; j < params_count; j++)
params[j] = RedisModule_LoadUnsigned(rdb);
hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, ele);
struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv));
nv->item = ele;
nv->attrib = attrib;
hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv);
if (node == NULL) {
RedisModule_LogIOError(rdb,"warning",
"Vector set node index loading error");
return NULL; // Loading error.
}
if (nv->attrib) vset->numattribs++;
RedisModule_DictSet(vset->dict,ele,node);
RedisModule_Free(vector);
RedisModule_Free(params);
@ -1279,6 +1513,14 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx, "VSETATTR",
VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR)
return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx, "VGETATTR",
VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
return REDISMODULE_ERR;
hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
RedisModule_Realloc);