diff --git a/LICENSE b/LICENSE index df7a7a7cd..79fb7e399 100644 --- a/LICENSE +++ b/LICENSE @@ -1,2 +1,2 @@ -This code is Copyright (C) 2024-2025 Salvatore Sanfilippo. +This code is Copyright (c) 2024-Present, Redis Ltd. All Rights Reserved. diff --git a/Makefile b/Makefile index 4478ac5d2..ed069801d 100644 --- a/Makefile +++ b/Makefile @@ -53,9 +53,9 @@ all: vset.so .c.xo: $(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ -vset.xo: redismodule.h +vset.xo: redismodule.h expr.c -vset.so: vset.xo hnsw.xo +vset.so: vset.xo hnsw.xo cJSON.xo $(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc # Example sources / objects diff --git a/README.md b/README.md index 0702c1f5b..b7b621c11 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,14 @@ This module implements Vector Sets for Redis, a new Redis data type similar to Sorted Sets but having string elements associated to a vector instead of a score. The fundamental goal of Vector Sets is to make possible adding items, and later get a subset of the added items that are the most similar to a -specified vector (often a learned embedding) of the most similar to the vector +specified vector (often a learned embedding), or the most similar to the vector of an element that is already part of the Vector Set. +Moreover, Vector sets implement optional hybrid search capabilities: it is possible to associate attributes to all or to a subset of elements in the set, and then, using the `FILTER` option of the `VSIM` command, to ask for items similar to a given vector but also passing a filter specified as a simple mathematical expression (Like `".year > 1950"` or similar). + ## Installation -Buil with: +Build with: make @@ -27,7 +29,8 @@ The execute the tests with: **VADD: add items into a vector set** - VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN] [EF build-exploration-factor] + VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT | Q8 | BIN] + [EF build-exploration-factor] [SETATTR ] Add a new element into the vector set specified by the key. The vector can be provided as FP32 blob of values, or as floating point @@ -35,27 +38,31 @@ numbers as strings, prefixed by the number of elements (3 in the example): VADD mykey VALUES 3 0.1 1.2 0.5 my-element -The `REDUCE` option implements random projection, in order to reduce the +Meaning of the options: + +`REDUCE` implements random projection, in order to reduce the dimensionality of the vector. The projection matrix is saved and reloaded along with the vector set. -The `CAS` option performs the operation partially using threads, in a +`CAS` performs the operation partially using threads, in a check-and-set style. The neighbor candidates collection, which is slow, is performed in the background, while the command is executed in the main thread. -The `NOQUANT` option forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default. +`NOQUANT` forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default. -The `BIN` option forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality. +`BIN` forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality. -The `Q8` option forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format. +`Q8` forces the vector to use signed 8 bit quantization. This is the default, and the option only exists in order to make sure to check at insertion time if the vector set is of the same format. -The `EF` option plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches. +`EF` plays a role in the effort made to find good candidates when connecting the new node to the existing HNSW graph. The default is 200. Using a larger value, may help to have a better recall. To improve the recall it is also possible to increase `EF` during `VSIM` searches. + +`SETATTR` associates attributes to the newly created entry or update the entry attributes (if it already exists). It is the same as calling the `VSETATTR` attribute separately, so please check the documentation of that command in the hybrid search section of this documentation. **VSIM: return elements by vector similarity** - VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EF search-exploration-factor] + VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] -The command returns similar vectors, in the example instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements associated with a vector similar to a given element already in the sorted set: +The command returns similar vectors, for simplicity (and verbosity) in the following example, instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements having a vector similar to a given element already in the sorted set: > VSIM word_embeddings ELE apple 1) "apple" @@ -81,6 +88,8 @@ It is possible to specify a `COUNT` and also to get the similarity score (from 1 The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the better the index is explored to find nodes that are near to our query. Sensible values are from 50 to 1000. +For `FILTER` and `FILTER-EF` options, please check the hybrid search section of this documentation. + **VDIM: return the dimension of the vectors inside the vector set** VDIM keyname @@ -180,17 +189,218 @@ Example: 11) hnsw-max-node-uid 12) (integer) 3000000 -## Known bugs +**VSETATTR: associate or remove the JSON attributes of elements** + + VSETATTR key element "{... json ...}" + +Each element of a vector set can be optionally associated with a JSON string +in order to use the `FILTER` option of `VSIM` to filter elements by scalars +(see the hybrid search section for more information). This command can set, +update (if already set) or delete (if you set to an empty string) the +associated JSON attributes of an element. + +The command returns 0 if the element or the key don't exist, without +raising an error, otherwise 1 is returned, and the element attributes +are set or updated. + +**VGETATTR: retrieve the JSON attributes of elements** + + VGET key element + +The command returns the JSON attribute associated with an element, or +null if there is no element associated, or no element at all, or no key. + +# Hybrid search + +Each element of the vector set can be associated with a set of attributes specified as a JSON blob: + + > VADD vset VALUES 3 1 1 1 a SETATTR '{"year": 1950}' + (integer) 1 + > VADD vset VALUES 3 -1 -1 -1 b SETATTR '{"year": 1951}' + (integer) 1 + +Specifying an attribute with the `SETATTR` option of `VADD` is exactly equivalent to adding an element and then setting (or updating, if already set) the attributes JSON string. Also the symmetrical `VGETATTR` command returns the attribute associated to a given element. + + > VAD vset VALUES 3 0 1 0 c + (integer) 1 + > VSETATTR vset c '{"year": 1952}' + (integer) 1 + > VGETATTR vset c + "{\"year\": 1952}" + +At this point, I may use the FILTER option of VSIM to only ask for the subset of elements that are verified by my expression: + + > VSIM vset VALUES 3 0 0 0 FILTER '.year > 1950' + 1) "c" + 2) "b" + +The items will be returned again in order of similarity (most similar first), but only the items with the year field matching the expression is returned. + +The expressions are similar to what you would write inside the `if` statement of JavaScript or other familiar programming languages: you can use `and`, `or`, the obvious math operators like `+`, `-`, `/`, `>=`, `<`, ... and so forth (see the expressions section for more info). The selectors of the JSON object attributes start with a dot followed by the name of the key inside the JSON objects. + +Elements with invalid JSON or not having a given specified field **are considered as not matching** the expression, but will not generate any error at runtime. + +I'll draft the missing sections for the README following the style and format of the existing content. + +## FILTER expressions capabilities + +FILTER expressions allow you to perform complex filtering on vector similarity results using a JavaScript-like syntax. The expression is evaluated against each element's JSON attributes, with only elements that satisfy the expression being included in the results. + +### Expression Syntax + +Expressions support the following operators and capabilities: + +1. **Arithmetic operators**: `+`, `-`, `*`, `/`, `%` (modulo), `**` (exponentiation) +2. **Comparison operators**: `>`, `>=`, `<`, `<=`, `==`, `!=` +3. **Logical operators**: `and`/`&&`, `or`/`||`, `!`/`not` +4. **Containment operator**: `in` +5. **Parentheses** for grouping: `(...)` + +### Selector Notation + +Attributes are accessed using dot notation: + +- `.year` references the "year" attribute +- `.movie.year` would **NOT** reference the "year" field inside a "movie" object, only keys that are at the first level of the JSON object are accessible. + +### JSON and expressions data types + +Expressions can work with: + +- Numbers (dobule precision floats) +- Strings (enclosed in single or double quotes) +- Booleans (no native type: they are represented as 1 for true, 0 for false) +- Arrays (for use with the `in` operator: `value in [1, 2, 3]`) + +JSON attributes are converted in this way: + +- Numbers will be converted to numbers. +- Strings to strings. +- Booleans to 0 or 1 number. +- Arrays to tuples (for "in" operator), but only if composed of just numbers and strings. + +Any other type is ignored, and accessig it will make the expression evaluate to false. + +### Examples + +``` +# Find items from the 1980s +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.year >= 1980 and .year < 1990' + +# Find action movies with high ratings +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.genre == "action" and .rating > 8.0' + +# Find movies directed by either Spielberg or Nolan +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '.director in ["Spielberg", "Nolan"]' + +# Complex condition with numerical operations +VSIM movies VALUES 3 0.5 0.8 0.2 FILTER '(.year - 2000) ** 2 < 100 and .rating / 2 > 4' +``` + +### Error Handling + +Elements with any of the following conditions are considered not matching: +- Missing the queried JSON attribute +- Having invalid JSON in their attributes +- Having a JSON value that cannot be converted to the expected type + +This behavior allows you to safely filter on optional attributes without generating errors. + +### FILTER effort + +The `FILTER-EF` option controls the maximum effort spent when filtering vector search results. + +When performing vector similarity search with filtering, Vector Sets perform the standard similarity search as they apply the filter expression to each node. Since many results might be filtered out, Vector Sets may need to examine a lot more candidates than the requested `COUNT` to ensure sufficient matching results are returned. Actually, if the elements matching the filter are very rare or if there are less than elements matching than the specified count, this would trigger a full scan of the HNSW graph. + +For this reason, by default, the maximum effort is limited to a reasonable amount of nodes explored. + +### Modifying the FILTER effort + +1. By default, Vector Sets will explore up to `COUNT * 100` candidates to find matching results. +2. You can control this exploration with the `FILTER-EF` parameter. +3. A higher `FILTER-EF` value increases the chances of finding all relevant matches at the cost of increased processing time. +4. A `FILTER-EF` of zero will explore as many nodes as needed in order to actually return the number of elements specified by `COUNT`. +5. Even when a high `FILTER-EF` value is specified **the implementation will do a lot less work** if the elements passing the filter are very common, because of the early stop conditions of the HNSW implementation (once the specified amount of elements is reached and the quality check of the other candidates trigger an early stop). + +``` +VSIM key [ELE|FP32|VALUES] COUNT 10 FILTER '.year > 2000' FILTER-EF 500 +``` + +In this example, Vector Sets will examine up to 500 potential nodes. Of course if count is reached before exploring 500 nodes, and the quality checks show that it is not possible to make progresses on similarity, the search is ended sooner. + +### Performance Considerations + +- If you have highly selective filters (few items match), use a higher `FILTER-EF`, or just design your application in order to handle a result set that is smaller than the requested count. Note that anyway the additional elements may be too distant than the query vector. +- For less selective filters, the default should be sufficient. +- Very selective filters with low `FILTER-EF` values may return fewer items than requested. +- Extremely high values may impact performance without significantly improving results. + +The optimal `FILTER-EF` value depends on: +1. The selectivity of your filter. +2. The distribution of your data. +3. The required recall quality. + +A good practice is to start with the default and increase if needed when you observe fewer results than expected. + +### Testing a larg-ish data set + +To really see how things work at scale, you can [download](https://antirez.com/word2vec_with_attribs.rdb) the following dataset: + + wget https://antirez.com/word2vec_with_attribs.rdb + +It contains the 3 million words in Word2Vec having as attribute a JSON with just the length of the word. Because of the length distribution of words in large amounts of texts, where longer words become less and less common, this is ideal to check how filtering behaves with a filter verifying as true with less and less elements in a vector set. + +For instance: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 6" + 1) "pastas" + 2) "rotini" + 3) "gnocci" + 4) "panino" + 5) "salads" + 6) "breads" + 7) "salame" + 8) "sauces" + 9) "cheese" + 10) "fritti" + +This will easily retrieve the desired amount of items (`COUNT` is 10 by default) since there are many items of length 6. However: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" + 1) "skinless_boneless_chicken_breasts" + 2) "boneless_skinless_chicken_breasts" + 3) "Boneless_skinless_chicken_breasts" + +This time even if we asked for 10 items, we only get 3, since the default filter effort will be `10*100 = 1000`. We can tune this giving the effort in an explicit way, with the risk of our query being slower, of course: + + > VSIM word_embeddings_bin ele "pasta" FILTER ".len == 33" FILTER-EF 10000 + 1) "skinless_boneless_chicken_breasts" + 2) "boneless_skinless_chicken_breasts" + 3) "Boneless_skinless_chicken_breasts" + 4) "mozzarella_feta_provolone_cheddar" + 5) "Greatfood.com_R_www.greatfood.com" + 6) "Pepperidge_Farm_Goldfish_crackers" + 7) "Prosecuted_Mobsters_Rebuilt_Dying" + 8) "Crispy_Snacker_Sandwiches_Popcorn" + 9) "risultati_delle_partite_disputate" + 10) "Peppermint_Mocha_Twist_Gingersnap" + +This time we get all the ten items, even if the last one will be quite far from our query vector. We encourage to experiment with this test dataset in order to understand better the dynamics of the implementation and the natural tradeoffs of hybrid search. + +**Keep in mind** that by default, Redis Vector Sets will try to avoid a likely very useless huge scan of the HNSW graph, and will be more happy to return few or no elements at all, since this is almost always what the user actually wants in the context of retrieving *similar* items to the query. + +# Known bugs * When VADD with REDUCE is replicated, we should probably send the replicas the random matrix, in order for VEMB to read the same things. This is not critical, because the behavior of VADD / VSIM should be transparent if you don't look at the transformed vectors, but still probably worth doing. * Replication code is pretty much untested, and very vanilla (replicating the commands verbatim). -## Implementation details +# Implementation details Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality. The main features are: * Proper nodes deletion with relinking. -* 8 bits quantization. +* 8 bits and binary quantization. * Threaded queries. +* Hybrid search with predicate callback. diff --git a/cJSON.c b/cJSON.c new file mode 100644 index 000000000..030311ce5 --- /dev/null +++ b/cJSON.c @@ -0,0 +1,3110 @@ +/* + Copyright (c) 2009-2017 Dave Gamble and cJSON contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +*/ + +/* cJSON */ +/* JSON parser in C. */ + +/* disable warnings about old C89 functions in MSVC */ +#if !defined(_CRT_SECURE_NO_DEPRECATE) && defined(_MSC_VER) +#define _CRT_SECURE_NO_DEPRECATE +#endif + +#ifdef __GNUC__ +#pragma GCC visibility push(default) +#endif +#if defined(_MSC_VER) +#pragma warning (push) +/* disable warning about single line comments in system headers */ +#pragma warning (disable : 4001) +#endif + +#include +#include +#include +#include +#include +#include +#include + +#ifdef ENABLE_LOCALES +#include +#endif + +#if defined(_MSC_VER) +#pragma warning (pop) +#endif +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +#include "cJSON.h" + +/* define our own boolean type */ +#ifdef true +#undef true +#endif +#define true ((cJSON_bool)1) + +#ifdef false +#undef false +#endif +#define false ((cJSON_bool)0) + +/* define isnan and isinf for ANSI C, if in C99 or above, isnan and isinf has been defined in math.h */ +#ifndef isinf +#define isinf(d) (isnan((d - d)) && !isnan(d)) +#endif +#ifndef isnan +#define isnan(d) (d != d) +#endif + +#ifndef NAN +#ifdef _WIN32 +#define NAN sqrt(-1.0) +#else +#define NAN 0.0/0.0 +#endif +#endif + +typedef struct { + const unsigned char *json; + size_t position; +} error; +static error global_error = { NULL, 0 }; + +CJSON_PUBLIC(const char *) cJSON_GetErrorPtr(void) +{ + return (const char*) (global_error.json + global_error.position); +} + +CJSON_PUBLIC(char *) cJSON_GetStringValue(const cJSON * const item) +{ + if (!cJSON_IsString(item)) + { + return NULL; + } + + return item->valuestring; +} + +CJSON_PUBLIC(double) cJSON_GetNumberValue(const cJSON * const item) +{ + if (!cJSON_IsNumber(item)) + { + return (double) NAN; + } + + return item->valuedouble; +} + +/* This is a safeguard to prevent copy-pasters from using incompatible C and header files */ +#if (CJSON_VERSION_MAJOR != 1) || (CJSON_VERSION_MINOR != 7) || (CJSON_VERSION_PATCH != 14) + #error cJSON.h and cJSON.c have different versions. Make sure that both have the same. +#endif + +CJSON_PUBLIC(const char*) cJSON_Version(void) +{ + static char version[15]; + sprintf(version, "%i.%i.%i", CJSON_VERSION_MAJOR, CJSON_VERSION_MINOR, CJSON_VERSION_PATCH); + + return version; +} + +/* Case insensitive string comparison, doesn't consider two NULL pointers equal though */ +static int case_insensitive_strcmp(const unsigned char *string1, const unsigned char *string2) +{ + if ((string1 == NULL) || (string2 == NULL)) + { + return 1; + } + + if (string1 == string2) + { + return 0; + } + + for(; tolower(*string1) == tolower(*string2); (void)string1++, string2++) + { + if (*string1 == '\0') + { + return 0; + } + } + + return tolower(*string1) - tolower(*string2); +} + +typedef struct internal_hooks +{ + void *(CJSON_CDECL *allocate)(size_t size); + void (CJSON_CDECL *deallocate)(void *pointer); + void *(CJSON_CDECL *reallocate)(void *pointer, size_t size); +} internal_hooks; + +#if defined(_MSC_VER) +/* work around MSVC error C2322: '...' address of dllimport '...' is not static */ +static void * CJSON_CDECL internal_malloc(size_t size) +{ + return malloc(size); +} +static void CJSON_CDECL internal_free(void *pointer) +{ + free(pointer); +} +static void * CJSON_CDECL internal_realloc(void *pointer, size_t size) +{ + return realloc(pointer, size); +} +#else +#define internal_malloc malloc +#define internal_free free +#define internal_realloc realloc +#endif + +/* strlen of character literals resolved at compile time */ +#define static_strlen(string_literal) (sizeof(string_literal) - sizeof("")) + +static internal_hooks global_hooks = { internal_malloc, internal_free, internal_realloc }; + +static unsigned char* cJSON_strdup(const unsigned char* string, const internal_hooks * const hooks) +{ + size_t length = 0; + unsigned char *copy = NULL; + + if (string == NULL) + { + return NULL; + } + + length = strlen((const char*)string) + sizeof(""); + copy = (unsigned char*)hooks->allocate(length); + if (copy == NULL) + { + return NULL; + } + memcpy(copy, string, length); + + return copy; +} + +CJSON_PUBLIC(void) cJSON_InitHooks(cJSON_Hooks* hooks) +{ + if (hooks == NULL) + { + /* Reset hooks */ + global_hooks.allocate = malloc; + global_hooks.deallocate = free; + global_hooks.reallocate = realloc; + return; + } + + global_hooks.allocate = malloc; + if (hooks->malloc_fn != NULL) + { + global_hooks.allocate = hooks->malloc_fn; + } + + global_hooks.deallocate = free; + if (hooks->free_fn != NULL) + { + global_hooks.deallocate = hooks->free_fn; + } + + /* use realloc only if both free and malloc are used */ + global_hooks.reallocate = NULL; + if ((global_hooks.allocate == malloc) && (global_hooks.deallocate == free)) + { + global_hooks.reallocate = realloc; + } +} + +/* Internal constructor. */ +static cJSON *cJSON_New_Item(const internal_hooks * const hooks) +{ + cJSON* node = (cJSON*)hooks->allocate(sizeof(cJSON)); + if (node) + { + memset(node, '\0', sizeof(cJSON)); + } + + return node; +} + +/* Delete a cJSON structure. */ +CJSON_PUBLIC(void) cJSON_Delete(cJSON *item) +{ + cJSON *next = NULL; + while (item != NULL) + { + next = item->next; + if (!(item->type & cJSON_IsReference) && (item->child != NULL)) + { + cJSON_Delete(item->child); + } + if (!(item->type & cJSON_IsReference) && (item->valuestring != NULL)) + { + global_hooks.deallocate(item->valuestring); + } + if (!(item->type & cJSON_StringIsConst) && (item->string != NULL)) + { + global_hooks.deallocate(item->string); + } + global_hooks.deallocate(item); + item = next; + } +} + +/* get the decimal point character of the current locale */ +static unsigned char get_decimal_point(void) +{ +#ifdef ENABLE_LOCALES + struct lconv *lconv = localeconv(); + return (unsigned char) lconv->decimal_point[0]; +#else + return '.'; +#endif +} + +typedef struct +{ + const unsigned char *content; + size_t length; + size_t offset; + size_t depth; /* How deeply nested (in arrays/objects) is the input at the current offset. */ + internal_hooks hooks; +} parse_buffer; + +/* check if the given size is left to read in a given parse buffer (starting with 1) */ +#define can_read(buffer, size) ((buffer != NULL) && (((buffer)->offset + size) <= (buffer)->length)) +/* check if the buffer can be accessed at the given index (starting with 0) */ +#define can_access_at_index(buffer, index) ((buffer != NULL) && (((buffer)->offset + index) < (buffer)->length)) +#define cannot_access_at_index(buffer, index) (!can_access_at_index(buffer, index)) +/* get a pointer to the buffer at the position */ +#define buffer_at_offset(buffer) ((buffer)->content + (buffer)->offset) + +/* Parse the input text to generate a number, and populate the result into item. */ +static cJSON_bool parse_number(cJSON * const item, parse_buffer * const input_buffer) +{ + double number = 0; + unsigned char *after_end = NULL; + unsigned char number_c_string[64]; + unsigned char decimal_point = get_decimal_point(); + size_t i = 0; + + if ((input_buffer == NULL) || (input_buffer->content == NULL)) + { + return false; + } + + /* copy the number into a temporary buffer and replace '.' with the decimal point + * of the current locale (for strtod) + * This also takes care of '\0' not necessarily being available for marking the end of the input */ + for (i = 0; (i < (sizeof(number_c_string) - 1)) && can_access_at_index(input_buffer, i); i++) + { + switch (buffer_at_offset(input_buffer)[i]) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '+': + case '-': + case 'e': + case 'E': + number_c_string[i] = buffer_at_offset(input_buffer)[i]; + break; + + case '.': + number_c_string[i] = decimal_point; + break; + + default: + goto loop_end; + } + } +loop_end: + number_c_string[i] = '\0'; + + number = strtod((const char*)number_c_string, (char**)&after_end); + if (number_c_string == after_end) + { + return false; /* parse_error */ + } + + item->valuedouble = number; + + /* use saturation in case of overflow */ + if (number >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (number <= (double)INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)number; + } + + item->type = cJSON_Number; + + input_buffer->offset += (size_t)(after_end - number_c_string); + return true; +} + +/* don't ask me, but the original cJSON_SetNumberValue returns an integer or double */ +CJSON_PUBLIC(double) cJSON_SetNumberHelper(cJSON *object, double number) +{ + if (number >= INT_MAX) + { + object->valueint = INT_MAX; + } + else if (number <= (double)INT_MIN) + { + object->valueint = INT_MIN; + } + else + { + object->valueint = (int)number; + } + + return object->valuedouble = number; +} + +CJSON_PUBLIC(char*) cJSON_SetValuestring(cJSON *object, const char *valuestring) +{ + char *copy = NULL; + /* if object's type is not cJSON_String or is cJSON_IsReference, it should not set valuestring */ + if (!(object->type & cJSON_String) || (object->type & cJSON_IsReference)) + { + return NULL; + } + if (strlen(valuestring) <= strlen(object->valuestring)) + { + strcpy(object->valuestring, valuestring); + return object->valuestring; + } + copy = (char*) cJSON_strdup((const unsigned char*)valuestring, &global_hooks); + if (copy == NULL) + { + return NULL; + } + if (object->valuestring != NULL) + { + cJSON_free(object->valuestring); + } + object->valuestring = copy; + + return copy; +} + +typedef struct +{ + unsigned char *buffer; + size_t length; + size_t offset; + size_t depth; /* current nesting depth (for formatted printing) */ + cJSON_bool noalloc; + cJSON_bool format; /* is this print a formatted print */ + internal_hooks hooks; +} printbuffer; + +/* realloc printbuffer if necessary to have at least "needed" bytes more */ +static unsigned char* ensure(printbuffer * const p, size_t needed) +{ + unsigned char *newbuffer = NULL; + size_t newsize = 0; + + if ((p == NULL) || (p->buffer == NULL)) + { + return NULL; + } + + if ((p->length > 0) && (p->offset >= p->length)) + { + /* make sure that offset is valid */ + return NULL; + } + + if (needed > INT_MAX) + { + /* sizes bigger than INT_MAX are currently not supported */ + return NULL; + } + + needed += p->offset + 1; + if (needed <= p->length) + { + return p->buffer + p->offset; + } + + if (p->noalloc) { + return NULL; + } + + /* calculate new buffer size */ + if (needed > (INT_MAX / 2)) + { + /* overflow of int, use INT_MAX if possible */ + if (needed <= INT_MAX) + { + newsize = INT_MAX; + } + else + { + return NULL; + } + } + else + { + newsize = needed * 2; + } + + if (p->hooks.reallocate != NULL) + { + /* reallocate with realloc if available */ + newbuffer = (unsigned char*)p->hooks.reallocate(p->buffer, newsize); + if (newbuffer == NULL) + { + p->hooks.deallocate(p->buffer); + p->length = 0; + p->buffer = NULL; + + return NULL; + } + } + else + { + /* otherwise reallocate manually */ + newbuffer = (unsigned char*)p->hooks.allocate(newsize); + if (!newbuffer) + { + p->hooks.deallocate(p->buffer); + p->length = 0; + p->buffer = NULL; + + return NULL; + } + + memcpy(newbuffer, p->buffer, p->offset + 1); + p->hooks.deallocate(p->buffer); + } + p->length = newsize; + p->buffer = newbuffer; + + return newbuffer + p->offset; +} + +/* calculate the new length of the string in a printbuffer and update the offset */ +static void update_offset(printbuffer * const buffer) +{ + const unsigned char *buffer_pointer = NULL; + if ((buffer == NULL) || (buffer->buffer == NULL)) + { + return; + } + buffer_pointer = buffer->buffer + buffer->offset; + + buffer->offset += strlen((const char*)buffer_pointer); +} + +/* securely comparison of floating-point variables */ +static cJSON_bool compare_double(double a, double b) +{ + double maxVal = fabs(a) > fabs(b) ? fabs(a) : fabs(b); + return (fabs(a - b) <= maxVal * DBL_EPSILON); +} + +/* Render the number nicely from the given item into a string. */ +static cJSON_bool print_number(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + double d = item->valuedouble; + int length = 0; + size_t i = 0; + unsigned char number_buffer[26] = {0}; /* temporary buffer to print the number into */ + unsigned char decimal_point = get_decimal_point(); + double test = 0.0; + + if (output_buffer == NULL) + { + return false; + } + + /* This checks for NaN and Infinity */ + if (isnan(d) || isinf(d)) + { + length = sprintf((char*)number_buffer, "null"); + } + else + { + /* Try 15 decimal places of precision to avoid nonsignificant nonzero digits */ + length = sprintf((char*)number_buffer, "%1.15g", d); + + /* Check whether the original double can be recovered */ + if ((sscanf((char*)number_buffer, "%lg", &test) != 1) || !compare_double((double)test, d)) + { + /* If not, print with 17 decimal places of precision */ + length = sprintf((char*)number_buffer, "%1.17g", d); + } + } + + /* sprintf failed or buffer overrun occurred */ + if ((length < 0) || (length > (int)(sizeof(number_buffer) - 1))) + { + return false; + } + + /* reserve appropriate space in the output */ + output_pointer = ensure(output_buffer, (size_t)length + sizeof("")); + if (output_pointer == NULL) + { + return false; + } + + /* copy the printed number to the output and replace locale + * dependent decimal point with '.' */ + for (i = 0; i < ((size_t)length); i++) + { + if (number_buffer[i] == decimal_point) + { + output_pointer[i] = '.'; + continue; + } + + output_pointer[i] = number_buffer[i]; + } + output_pointer[i] = '\0'; + + output_buffer->offset += (size_t)length; + + return true; +} + +/* parse 4 digit hexadecimal number */ +static unsigned parse_hex4(const unsigned char * const input) +{ + unsigned int h = 0; + size_t i = 0; + + for (i = 0; i < 4; i++) + { + /* parse digit */ + if ((input[i] >= '0') && (input[i] <= '9')) + { + h += (unsigned int) input[i] - '0'; + } + else if ((input[i] >= 'A') && (input[i] <= 'F')) + { + h += (unsigned int) 10 + input[i] - 'A'; + } + else if ((input[i] >= 'a') && (input[i] <= 'f')) + { + h += (unsigned int) 10 + input[i] - 'a'; + } + else /* invalid */ + { + return 0; + } + + if (i < 3) + { + /* shift left to make place for the next nibble */ + h = h << 4; + } + } + + return h; +} + +/* converts a UTF-16 literal to UTF-8 + * A literal can be one or two sequences of the form \uXXXX */ +static unsigned char utf16_literal_to_utf8(const unsigned char * const input_pointer, const unsigned char * const input_end, unsigned char **output_pointer) +{ + long unsigned int codepoint = 0; + unsigned int first_code = 0; + const unsigned char *first_sequence = input_pointer; + unsigned char utf8_length = 0; + unsigned char utf8_position = 0; + unsigned char sequence_length = 0; + unsigned char first_byte_mark = 0; + + if ((input_end - first_sequence) < 6) + { + /* input ends unexpectedly */ + goto fail; + } + + /* get the first utf16 sequence */ + first_code = parse_hex4(first_sequence + 2); + + /* check that the code is valid */ + if (((first_code >= 0xDC00) && (first_code <= 0xDFFF))) + { + goto fail; + } + + /* UTF16 surrogate pair */ + if ((first_code >= 0xD800) && (first_code <= 0xDBFF)) + { + const unsigned char *second_sequence = first_sequence + 6; + unsigned int second_code = 0; + sequence_length = 12; /* \uXXXX\uXXXX */ + + if ((input_end - second_sequence) < 6) + { + /* input ends unexpectedly */ + goto fail; + } + + if ((second_sequence[0] != '\\') || (second_sequence[1] != 'u')) + { + /* missing second half of the surrogate pair */ + goto fail; + } + + /* get the second utf16 sequence */ + second_code = parse_hex4(second_sequence + 2); + /* check that the code is valid */ + if ((second_code < 0xDC00) || (second_code > 0xDFFF)) + { + /* invalid second half of the surrogate pair */ + goto fail; + } + + + /* calculate the unicode codepoint from the surrogate pair */ + codepoint = 0x10000 + (((first_code & 0x3FF) << 10) | (second_code & 0x3FF)); + } + else + { + sequence_length = 6; /* \uXXXX */ + codepoint = first_code; + } + + /* encode as UTF-8 + * takes at maximum 4 bytes to encode: + * 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */ + if (codepoint < 0x80) + { + /* normal ascii, encoding 0xxxxxxx */ + utf8_length = 1; + } + else if (codepoint < 0x800) + { + /* two bytes, encoding 110xxxxx 10xxxxxx */ + utf8_length = 2; + first_byte_mark = 0xC0; /* 11000000 */ + } + else if (codepoint < 0x10000) + { + /* three bytes, encoding 1110xxxx 10xxxxxx 10xxxxxx */ + utf8_length = 3; + first_byte_mark = 0xE0; /* 11100000 */ + } + else if (codepoint <= 0x10FFFF) + { + /* four bytes, encoding 1110xxxx 10xxxxxx 10xxxxxx 10xxxxxx */ + utf8_length = 4; + first_byte_mark = 0xF0; /* 11110000 */ + } + else + { + /* invalid unicode codepoint */ + goto fail; + } + + /* encode as utf8 */ + for (utf8_position = (unsigned char)(utf8_length - 1); utf8_position > 0; utf8_position--) + { + /* 10xxxxxx */ + (*output_pointer)[utf8_position] = (unsigned char)((codepoint | 0x80) & 0xBF); + codepoint >>= 6; + } + /* encode first byte */ + if (utf8_length > 1) + { + (*output_pointer)[0] = (unsigned char)((codepoint | first_byte_mark) & 0xFF); + } + else + { + (*output_pointer)[0] = (unsigned char)(codepoint & 0x7F); + } + + *output_pointer += utf8_length; + + return sequence_length; + +fail: + return 0; +} + +/* Parse the input text into an unescaped cinput, and populate item. */ +static cJSON_bool parse_string(cJSON * const item, parse_buffer * const input_buffer) +{ + const unsigned char *input_pointer = buffer_at_offset(input_buffer) + 1; + const unsigned char *input_end = buffer_at_offset(input_buffer) + 1; + unsigned char *output_pointer = NULL; + unsigned char *output = NULL; + + /* not a string */ + if (buffer_at_offset(input_buffer)[0] != '\"') + { + goto fail; + } + + { + /* calculate approximate size of the output (overestimate) */ + size_t allocation_length = 0; + size_t skipped_bytes = 0; + while (((size_t)(input_end - input_buffer->content) < input_buffer->length) && (*input_end != '\"')) + { + /* is escape sequence */ + if (input_end[0] == '\\') + { + if ((size_t)(input_end + 1 - input_buffer->content) >= input_buffer->length) + { + /* prevent buffer overflow when last input character is a backslash */ + goto fail; + } + skipped_bytes++; + input_end++; + } + input_end++; + } + if (((size_t)(input_end - input_buffer->content) >= input_buffer->length) || (*input_end != '\"')) + { + goto fail; /* string ended unexpectedly */ + } + + /* This is at most how much we need for the output */ + allocation_length = (size_t) (input_end - buffer_at_offset(input_buffer)) - skipped_bytes; + output = (unsigned char*)input_buffer->hooks.allocate(allocation_length + sizeof("")); + if (output == NULL) + { + goto fail; /* allocation failure */ + } + } + + output_pointer = output; + /* loop through the string literal */ + while (input_pointer < input_end) + { + if (*input_pointer != '\\') + { + *output_pointer++ = *input_pointer++; + } + /* escape sequence */ + else + { + unsigned char sequence_length = 2; + if ((input_end - input_pointer) < 1) + { + goto fail; + } + + switch (input_pointer[1]) + { + case 'b': + *output_pointer++ = '\b'; + break; + case 'f': + *output_pointer++ = '\f'; + break; + case 'n': + *output_pointer++ = '\n'; + break; + case 'r': + *output_pointer++ = '\r'; + break; + case 't': + *output_pointer++ = '\t'; + break; + case '\"': + case '\\': + case '/': + *output_pointer++ = input_pointer[1]; + break; + + /* UTF-16 literal */ + case 'u': + sequence_length = utf16_literal_to_utf8(input_pointer, input_end, &output_pointer); + if (sequence_length == 0) + { + /* failed to convert UTF16-literal to UTF-8 */ + goto fail; + } + break; + + default: + goto fail; + } + input_pointer += sequence_length; + } + } + + /* zero terminate the output */ + *output_pointer = '\0'; + + item->type = cJSON_String; + item->valuestring = (char*)output; + + input_buffer->offset = (size_t) (input_end - input_buffer->content); + input_buffer->offset++; + + return true; + +fail: + if (output != NULL) + { + input_buffer->hooks.deallocate(output); + } + + if (input_pointer != NULL) + { + input_buffer->offset = (size_t)(input_pointer - input_buffer->content); + } + + return false; +} + +/* Render the cstring provided to an escaped version that can be printed. */ +static cJSON_bool print_string_ptr(const unsigned char * const input, printbuffer * const output_buffer) +{ + const unsigned char *input_pointer = NULL; + unsigned char *output = NULL; + unsigned char *output_pointer = NULL; + size_t output_length = 0; + /* numbers of additional characters needed for escaping */ + size_t escape_characters = 0; + + if (output_buffer == NULL) + { + return false; + } + + /* empty string */ + if (input == NULL) + { + output = ensure(output_buffer, sizeof("\"\"")); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "\"\""); + + return true; + } + + /* set "flag" to 1 if something needs to be escaped */ + for (input_pointer = input; *input_pointer; input_pointer++) + { + switch (*input_pointer) + { + case '\"': + case '\\': + case '\b': + case '\f': + case '\n': + case '\r': + case '\t': + /* one character escape sequence */ + escape_characters++; + break; + default: + if (*input_pointer < 32) + { + /* UTF-16 escape sequence uXXXX */ + escape_characters += 5; + } + break; + } + } + output_length = (size_t)(input_pointer - input) + escape_characters; + + output = ensure(output_buffer, output_length + sizeof("\"\"")); + if (output == NULL) + { + return false; + } + + /* no characters have to be escaped */ + if (escape_characters == 0) + { + output[0] = '\"'; + memcpy(output + 1, input, output_length); + output[output_length + 1] = '\"'; + output[output_length + 2] = '\0'; + + return true; + } + + output[0] = '\"'; + output_pointer = output + 1; + /* copy the string */ + for (input_pointer = input; *input_pointer != '\0'; (void)input_pointer++, output_pointer++) + { + if ((*input_pointer > 31) && (*input_pointer != '\"') && (*input_pointer != '\\')) + { + /* normal character, copy */ + *output_pointer = *input_pointer; + } + else + { + /* character needs to be escaped */ + *output_pointer++ = '\\'; + switch (*input_pointer) + { + case '\\': + *output_pointer = '\\'; + break; + case '\"': + *output_pointer = '\"'; + break; + case '\b': + *output_pointer = 'b'; + break; + case '\f': + *output_pointer = 'f'; + break; + case '\n': + *output_pointer = 'n'; + break; + case '\r': + *output_pointer = 'r'; + break; + case '\t': + *output_pointer = 't'; + break; + default: + /* escape and print as unicode codepoint */ + sprintf((char*)output_pointer, "u%04x", *input_pointer); + output_pointer += 4; + break; + } + } + } + output[output_length + 1] = '\"'; + output[output_length + 2] = '\0'; + + return true; +} + +/* Invoke print_string_ptr (which is useful) on an item. */ +static cJSON_bool print_string(const cJSON * const item, printbuffer * const p) +{ + return print_string_ptr((unsigned char*)item->valuestring, p); +} + +/* Predeclare these prototypes. */ +static cJSON_bool parse_value(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_value(const cJSON * const item, printbuffer * const output_buffer); +static cJSON_bool parse_array(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_array(const cJSON * const item, printbuffer * const output_buffer); +static cJSON_bool parse_object(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_object(const cJSON * const item, printbuffer * const output_buffer); + +/* Utility to jump whitespace and cr/lf */ +static parse_buffer *buffer_skip_whitespace(parse_buffer * const buffer) +{ + if ((buffer == NULL) || (buffer->content == NULL)) + { + return NULL; + } + + if (cannot_access_at_index(buffer, 0)) + { + return buffer; + } + + while (can_access_at_index(buffer, 0) && (buffer_at_offset(buffer)[0] <= 32)) + { + buffer->offset++; + } + + if (buffer->offset == buffer->length) + { + buffer->offset--; + } + + return buffer; +} + +/* skip the UTF-8 BOM (byte order mark) if it is at the beginning of a buffer */ +static parse_buffer *skip_utf8_bom(parse_buffer * const buffer) +{ + if ((buffer == NULL) || (buffer->content == NULL) || (buffer->offset != 0)) + { + return NULL; + } + + if (can_access_at_index(buffer, 4) && (strncmp((const char*)buffer_at_offset(buffer), "\xEF\xBB\xBF", 3) == 0)) + { + buffer->offset += 3; + } + + return buffer; +} + +CJSON_PUBLIC(cJSON *) cJSON_ParseWithOpts(const char *value, const char **return_parse_end, cJSON_bool require_null_terminated) +{ + size_t buffer_length; + + if (NULL == value) + { + return NULL; + } + + /* Adding null character size due to require_null_terminated. */ + buffer_length = strlen(value) + sizeof(""); + + return cJSON_ParseWithLengthOpts(value, buffer_length, return_parse_end, require_null_terminated); +} + +/* Parse an object - create a new root, and populate. */ +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLengthOpts(const char *value, size_t buffer_length, const char **return_parse_end, cJSON_bool require_null_terminated) +{ + parse_buffer buffer = { 0, 0, 0, 0, { 0, 0, 0 } }; + cJSON *item = NULL; + + /* reset error position */ + global_error.json = NULL; + global_error.position = 0; + + if (value == NULL || 0 == buffer_length) + { + goto fail; + } + + buffer.content = (const unsigned char*)value; + buffer.length = buffer_length; + buffer.offset = 0; + buffer.hooks = global_hooks; + + item = cJSON_New_Item(&global_hooks); + if (item == NULL) /* memory fail */ + { + goto fail; + } + + if (!parse_value(item, buffer_skip_whitespace(skip_utf8_bom(&buffer)))) + { + /* parse failure. ep is set. */ + goto fail; + } + + /* if we require null-terminated JSON without appended garbage, skip and then check for a null terminator */ + if (require_null_terminated) + { + buffer_skip_whitespace(&buffer); + if ((buffer.offset >= buffer.length) || buffer_at_offset(&buffer)[0] != '\0') + { + goto fail; + } + } + if (return_parse_end) + { + *return_parse_end = (const char*)buffer_at_offset(&buffer); + } + + return item; + +fail: + if (item != NULL) + { + cJSON_Delete(item); + } + + if (value != NULL) + { + error local_error; + local_error.json = (const unsigned char*)value; + local_error.position = 0; + + if (buffer.offset < buffer.length) + { + local_error.position = buffer.offset; + } + else if (buffer.length > 0) + { + local_error.position = buffer.length - 1; + } + + if (return_parse_end != NULL) + { + *return_parse_end = (const char*)local_error.json + local_error.position; + } + + global_error = local_error; + } + + return NULL; +} + +/* Default options for cJSON_Parse */ +CJSON_PUBLIC(cJSON *) cJSON_Parse(const char *value) +{ + return cJSON_ParseWithOpts(value, 0, 0); +} + +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLength(const char *value, size_t buffer_length) +{ + return cJSON_ParseWithLengthOpts(value, buffer_length, 0, 0); +} + +#define cjson_min(a, b) (((a) < (b)) ? (a) : (b)) + +static unsigned char *print(const cJSON * const item, cJSON_bool format, const internal_hooks * const hooks) +{ + static const size_t default_buffer_size = 256; + printbuffer buffer[1]; + unsigned char *printed = NULL; + + memset(buffer, 0, sizeof(buffer)); + + /* create buffer */ + buffer->buffer = (unsigned char*) hooks->allocate(default_buffer_size); + buffer->length = default_buffer_size; + buffer->format = format; + buffer->hooks = *hooks; + if (buffer->buffer == NULL) + { + goto fail; + } + + /* print the value */ + if (!print_value(item, buffer)) + { + goto fail; + } + update_offset(buffer); + + /* check if reallocate is available */ + if (hooks->reallocate != NULL) + { + printed = (unsigned char*) hooks->reallocate(buffer->buffer, buffer->offset + 1); + if (printed == NULL) { + goto fail; + } + buffer->buffer = NULL; + } + else /* otherwise copy the JSON over to a new buffer */ + { + printed = (unsigned char*) hooks->allocate(buffer->offset + 1); + if (printed == NULL) + { + goto fail; + } + memcpy(printed, buffer->buffer, cjson_min(buffer->length, buffer->offset + 1)); + printed[buffer->offset] = '\0'; /* just to be sure */ + + /* free the buffer */ + hooks->deallocate(buffer->buffer); + } + + return printed; + +fail: + if (buffer->buffer != NULL) + { + hooks->deallocate(buffer->buffer); + } + + if (printed != NULL) + { + hooks->deallocate(printed); + } + + return NULL; +} + +/* Render a cJSON item/entity/structure to text. */ +CJSON_PUBLIC(char *) cJSON_Print(const cJSON *item) +{ + return (char*)print(item, true, &global_hooks); +} + +CJSON_PUBLIC(char *) cJSON_PrintUnformatted(const cJSON *item) +{ + return (char*)print(item, false, &global_hooks); +} + +CJSON_PUBLIC(char *) cJSON_PrintBuffered(const cJSON *item, int prebuffer, cJSON_bool fmt) +{ + printbuffer p = { 0, 0, 0, 0, 0, 0, { 0, 0, 0 } }; + + if (prebuffer < 0) + { + return NULL; + } + + p.buffer = (unsigned char*)global_hooks.allocate((size_t)prebuffer); + if (!p.buffer) + { + return NULL; + } + + p.length = (size_t)prebuffer; + p.offset = 0; + p.noalloc = false; + p.format = fmt; + p.hooks = global_hooks; + + if (!print_value(item, &p)) + { + global_hooks.deallocate(p.buffer); + return NULL; + } + + return (char*)p.buffer; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_PrintPreallocated(cJSON *item, char *buffer, const int length, const cJSON_bool format) +{ + printbuffer p = { 0, 0, 0, 0, 0, 0, { 0, 0, 0 } }; + + if ((length < 0) || (buffer == NULL)) + { + return false; + } + + p.buffer = (unsigned char*)buffer; + p.length = (size_t)length; + p.offset = 0; + p.noalloc = true; + p.format = format; + p.hooks = global_hooks; + + return print_value(item, &p); +} + +/* Parser core - when encountering text, process appropriately. */ +static cJSON_bool parse_value(cJSON * const item, parse_buffer * const input_buffer) +{ + if ((input_buffer == NULL) || (input_buffer->content == NULL)) + { + return false; /* no input */ + } + + /* parse the different types of values */ + /* null */ + if (can_read(input_buffer, 4) && (strncmp((const char*)buffer_at_offset(input_buffer), "null", 4) == 0)) + { + item->type = cJSON_NULL; + input_buffer->offset += 4; + return true; + } + /* false */ + if (can_read(input_buffer, 5) && (strncmp((const char*)buffer_at_offset(input_buffer), "false", 5) == 0)) + { + item->type = cJSON_False; + input_buffer->offset += 5; + return true; + } + /* true */ + if (can_read(input_buffer, 4) && (strncmp((const char*)buffer_at_offset(input_buffer), "true", 4) == 0)) + { + item->type = cJSON_True; + item->valueint = 1; + input_buffer->offset += 4; + return true; + } + /* string */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '\"')) + { + return parse_string(item, input_buffer); + } + /* number */ + if (can_access_at_index(input_buffer, 0) && ((buffer_at_offset(input_buffer)[0] == '-') || ((buffer_at_offset(input_buffer)[0] >= '0') && (buffer_at_offset(input_buffer)[0] <= '9')))) + { + return parse_number(item, input_buffer); + } + /* array */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '[')) + { + return parse_array(item, input_buffer); + } + /* object */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '{')) + { + return parse_object(item, input_buffer); + } + + return false; +} + +/* Render a value to text. */ +static cJSON_bool print_value(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output = NULL; + + if ((item == NULL) || (output_buffer == NULL)) + { + return false; + } + + switch ((item->type) & 0xFF) + { + case cJSON_NULL: + output = ensure(output_buffer, 5); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "null"); + return true; + + case cJSON_False: + output = ensure(output_buffer, 6); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "false"); + return true; + + case cJSON_True: + output = ensure(output_buffer, 5); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "true"); + return true; + + case cJSON_Number: + return print_number(item, output_buffer); + + case cJSON_Raw: + { + size_t raw_length = 0; + if (item->valuestring == NULL) + { + return false; + } + + raw_length = strlen(item->valuestring) + sizeof(""); + output = ensure(output_buffer, raw_length); + if (output == NULL) + { + return false; + } + memcpy(output, item->valuestring, raw_length); + return true; + } + + case cJSON_String: + return print_string(item, output_buffer); + + case cJSON_Array: + return print_array(item, output_buffer); + + case cJSON_Object: + return print_object(item, output_buffer); + + default: + return false; + } +} + +/* Build an array from input text. */ +static cJSON_bool parse_array(cJSON * const item, parse_buffer * const input_buffer) +{ + cJSON *head = NULL; /* head of the linked list */ + cJSON *current_item = NULL; + + if (input_buffer->depth >= CJSON_NESTING_LIMIT) + { + return false; /* to deeply nested */ + } + input_buffer->depth++; + + if (buffer_at_offset(input_buffer)[0] != '[') + { + /* not an array */ + goto fail; + } + + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ']')) + { + /* empty array */ + goto success; + } + + /* check if we skipped to the end of the buffer */ + if (cannot_access_at_index(input_buffer, 0)) + { + input_buffer->offset--; + goto fail; + } + + /* step back to character in front of the first element */ + input_buffer->offset--; + /* loop through the comma separated array elements */ + do + { + /* allocate next item */ + cJSON *new_item = cJSON_New_Item(&(input_buffer->hooks)); + if (new_item == NULL) + { + goto fail; /* allocation failure */ + } + + /* attach next item to list */ + if (head == NULL) + { + /* start the linked list */ + current_item = head = new_item; + } + else + { + /* add to the end and advance */ + current_item->next = new_item; + new_item->prev = current_item; + current_item = new_item; + } + + /* parse next value */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_value(current_item, input_buffer)) + { + goto fail; /* failed to parse value */ + } + buffer_skip_whitespace(input_buffer); + } + while (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ',')); + + if (cannot_access_at_index(input_buffer, 0) || buffer_at_offset(input_buffer)[0] != ']') + { + goto fail; /* expected end of array */ + } + +success: + input_buffer->depth--; + + if (head != NULL) { + head->prev = current_item; + } + + item->type = cJSON_Array; + item->child = head; + + input_buffer->offset++; + + return true; + +fail: + if (head != NULL) + { + cJSON_Delete(head); + } + + return false; +} + +/* Render an array to text */ +static cJSON_bool print_array(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + size_t length = 0; + cJSON *current_element = item->child; + + if (output_buffer == NULL) + { + return false; + } + + /* Compose the output array. */ + /* opening square bracket */ + output_pointer = ensure(output_buffer, 1); + if (output_pointer == NULL) + { + return false; + } + + *output_pointer = '['; + output_buffer->offset++; + output_buffer->depth++; + + while (current_element != NULL) + { + if (!print_value(current_element, output_buffer)) + { + return false; + } + update_offset(output_buffer); + if (current_element->next) + { + length = (size_t) (output_buffer->format ? 2 : 1); + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ','; + if(output_buffer->format) + { + *output_pointer++ = ' '; + } + *output_pointer = '\0'; + output_buffer->offset += length; + } + current_element = current_element->next; + } + + output_pointer = ensure(output_buffer, 2); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ']'; + *output_pointer = '\0'; + output_buffer->depth--; + + return true; +} + +/* Build an object from the text. */ +static cJSON_bool parse_object(cJSON * const item, parse_buffer * const input_buffer) +{ + cJSON *head = NULL; /* linked list head */ + cJSON *current_item = NULL; + + if (input_buffer->depth >= CJSON_NESTING_LIMIT) + { + return false; /* to deeply nested */ + } + input_buffer->depth++; + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != '{')) + { + goto fail; /* not an object */ + } + + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '}')) + { + goto success; /* empty object */ + } + + /* check if we skipped to the end of the buffer */ + if (cannot_access_at_index(input_buffer, 0)) + { + input_buffer->offset--; + goto fail; + } + + /* step back to character in front of the first element */ + input_buffer->offset--; + /* loop through the comma separated array elements */ + do + { + /* allocate next item */ + cJSON *new_item = cJSON_New_Item(&(input_buffer->hooks)); + if (new_item == NULL) + { + goto fail; /* allocation failure */ + } + + /* attach next item to list */ + if (head == NULL) + { + /* start the linked list */ + current_item = head = new_item; + } + else + { + /* add to the end and advance */ + current_item->next = new_item; + new_item->prev = current_item; + current_item = new_item; + } + + /* parse the name of the child */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_string(current_item, input_buffer)) + { + goto fail; /* failed to parse name */ + } + buffer_skip_whitespace(input_buffer); + + /* swap valuestring and string, because we parsed the name */ + current_item->string = current_item->valuestring; + current_item->valuestring = NULL; + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != ':')) + { + goto fail; /* invalid object */ + } + + /* parse the value */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_value(current_item, input_buffer)) + { + goto fail; /* failed to parse value */ + } + buffer_skip_whitespace(input_buffer); + } + while (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ',')); + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != '}')) + { + goto fail; /* expected end of object */ + } + +success: + input_buffer->depth--; + + if (head != NULL) { + head->prev = current_item; + } + + item->type = cJSON_Object; + item->child = head; + + input_buffer->offset++; + return true; + +fail: + if (head != NULL) + { + cJSON_Delete(head); + } + + return false; +} + +/* Render an object to text. */ +static cJSON_bool print_object(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + size_t length = 0; + cJSON *current_item = item->child; + + if (output_buffer == NULL) + { + return false; + } + + /* Compose the output: */ + length = (size_t) (output_buffer->format ? 2 : 1); /* fmt: {\n */ + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + + *output_pointer++ = '{'; + output_buffer->depth++; + if (output_buffer->format) + { + *output_pointer++ = '\n'; + } + output_buffer->offset += length; + + while (current_item) + { + if (output_buffer->format) + { + size_t i; + output_pointer = ensure(output_buffer, output_buffer->depth); + if (output_pointer == NULL) + { + return false; + } + for (i = 0; i < output_buffer->depth; i++) + { + *output_pointer++ = '\t'; + } + output_buffer->offset += output_buffer->depth; + } + + /* print key */ + if (!print_string_ptr((unsigned char*)current_item->string, output_buffer)) + { + return false; + } + update_offset(output_buffer); + + length = (size_t) (output_buffer->format ? 2 : 1); + output_pointer = ensure(output_buffer, length); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ':'; + if (output_buffer->format) + { + *output_pointer++ = '\t'; + } + output_buffer->offset += length; + + /* print value */ + if (!print_value(current_item, output_buffer)) + { + return false; + } + update_offset(output_buffer); + + /* print comma if not last */ + length = ((size_t)(output_buffer->format ? 1 : 0) + (size_t)(current_item->next ? 1 : 0)); + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + if (current_item->next) + { + *output_pointer++ = ','; + } + + if (output_buffer->format) + { + *output_pointer++ = '\n'; + } + *output_pointer = '\0'; + output_buffer->offset += length; + + current_item = current_item->next; + } + + output_pointer = ensure(output_buffer, output_buffer->format ? (output_buffer->depth + 1) : 2); + if (output_pointer == NULL) + { + return false; + } + if (output_buffer->format) + { + size_t i; + for (i = 0; i < (output_buffer->depth - 1); i++) + { + *output_pointer++ = '\t'; + } + } + *output_pointer++ = '}'; + *output_pointer = '\0'; + output_buffer->depth--; + + return true; +} + +/* Get Array size/item / object item. */ +CJSON_PUBLIC(int) cJSON_GetArraySize(const cJSON *array) +{ + cJSON *child = NULL; + size_t size = 0; + + if (array == NULL) + { + return 0; + } + + child = array->child; + + while(child != NULL) + { + size++; + child = child->next; + } + + /* FIXME: Can overflow here. Cannot be fixed without breaking the API */ + + return (int)size; +} + +static cJSON* get_array_item(const cJSON *array, size_t index) +{ + cJSON *current_child = NULL; + + if (array == NULL) + { + return NULL; + } + + current_child = array->child; + while ((current_child != NULL) && (index > 0)) + { + index--; + current_child = current_child->next; + } + + return current_child; +} + +CJSON_PUBLIC(cJSON *) cJSON_GetArrayItem(const cJSON *array, int index) +{ + if (index < 0) + { + return NULL; + } + + return get_array_item(array, (size_t)index); +} + +static cJSON *get_object_item(const cJSON * const object, const char * const name, const cJSON_bool case_sensitive) +{ + cJSON *current_element = NULL; + + if ((object == NULL) || (name == NULL)) + { + return NULL; + } + + current_element = object->child; + if (case_sensitive) + { + while ((current_element != NULL) && (current_element->string != NULL) && (strcmp(name, current_element->string) != 0)) + { + current_element = current_element->next; + } + } + else + { + while ((current_element != NULL) && (case_insensitive_strcmp((const unsigned char*)name, (const unsigned char*)(current_element->string)) != 0)) + { + current_element = current_element->next; + } + } + + if ((current_element == NULL) || (current_element->string == NULL)) { + return NULL; + } + + return current_element; +} + +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItem(const cJSON * const object, const char * const string) +{ + return get_object_item(object, string, false); +} + +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItemCaseSensitive(const cJSON * const object, const char * const string) +{ + return get_object_item(object, string, true); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_HasObjectItem(const cJSON *object, const char *string) +{ + return cJSON_GetObjectItem(object, string) ? 1 : 0; +} + +/* Utility for array list handling. */ +static void suffix_object(cJSON *prev, cJSON *item) +{ + prev->next = item; + item->prev = prev; +} + +/* Utility for handling references. */ +static cJSON *create_reference(const cJSON *item, const internal_hooks * const hooks) +{ + cJSON *reference = NULL; + if (item == NULL) + { + return NULL; + } + + reference = cJSON_New_Item(hooks); + if (reference == NULL) + { + return NULL; + } + + memcpy(reference, item, sizeof(cJSON)); + reference->string = NULL; + reference->type |= cJSON_IsReference; + reference->next = reference->prev = NULL; + return reference; +} + +static cJSON_bool add_item_to_array(cJSON *array, cJSON *item) +{ + cJSON *child = NULL; + + if ((item == NULL) || (array == NULL) || (array == item)) + { + return false; + } + + child = array->child; + /* + * To find the last item in array quickly, we use prev in array + */ + if (child == NULL) + { + /* list is empty, start new one */ + array->child = item; + item->prev = item; + item->next = NULL; + } + else + { + /* append to the end */ + if (child->prev) + { + suffix_object(child->prev, item); + array->child->prev = item; + } + } + + return true; +} + +/* Add item to array/object. */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToArray(cJSON *array, cJSON *item) +{ + return add_item_to_array(array, item); +} + +#if defined(__clang__) || (defined(__GNUC__) && ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 5)))) + #pragma GCC diagnostic push +#endif +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Wcast-qual" +#endif +/* helper function to cast away const */ +static void* cast_away_const(const void* string) +{ + return (void*)string; +} +#if defined(__clang__) || (defined(__GNUC__) && ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 5)))) + #pragma GCC diagnostic pop +#endif + + +static cJSON_bool add_item_to_object(cJSON * const object, const char * const string, cJSON * const item, const internal_hooks * const hooks, const cJSON_bool constant_key) +{ + char *new_key = NULL; + int new_type = cJSON_Invalid; + + if ((object == NULL) || (string == NULL) || (item == NULL) || (object == item)) + { + return false; + } + + if (constant_key) + { + new_key = (char*)cast_away_const(string); + new_type = item->type | cJSON_StringIsConst; + } + else + { + new_key = (char*)cJSON_strdup((const unsigned char*)string, hooks); + if (new_key == NULL) + { + return false; + } + + new_type = item->type & ~cJSON_StringIsConst; + } + + if (!(item->type & cJSON_StringIsConst) && (item->string != NULL)) + { + hooks->deallocate(item->string); + } + + item->string = new_key; + item->type = new_type; + + return add_item_to_array(object, item); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObject(cJSON *object, const char *string, cJSON *item) +{ + return add_item_to_object(object, string, item, &global_hooks, false); +} + +/* Add an item to an object with constant string as key */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObjectCS(cJSON *object, const char *string, cJSON *item) +{ + return add_item_to_object(object, string, item, &global_hooks, true); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item) +{ + if (array == NULL) + { + return false; + } + + return add_item_to_array(array, create_reference(item, &global_hooks)); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToObject(cJSON *object, const char *string, cJSON *item) +{ + if ((object == NULL) || (string == NULL)) + { + return false; + } + + return add_item_to_object(object, string, create_reference(item, &global_hooks), &global_hooks, false); +} + +CJSON_PUBLIC(cJSON*) cJSON_AddNullToObject(cJSON * const object, const char * const name) +{ + cJSON *null = cJSON_CreateNull(); + if (add_item_to_object(object, name, null, &global_hooks, false)) + { + return null; + } + + cJSON_Delete(null); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddTrueToObject(cJSON * const object, const char * const name) +{ + cJSON *true_item = cJSON_CreateTrue(); + if (add_item_to_object(object, name, true_item, &global_hooks, false)) + { + return true_item; + } + + cJSON_Delete(true_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddFalseToObject(cJSON * const object, const char * const name) +{ + cJSON *false_item = cJSON_CreateFalse(); + if (add_item_to_object(object, name, false_item, &global_hooks, false)) + { + return false_item; + } + + cJSON_Delete(false_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddBoolToObject(cJSON * const object, const char * const name, const cJSON_bool boolean) +{ + cJSON *bool_item = cJSON_CreateBool(boolean); + if (add_item_to_object(object, name, bool_item, &global_hooks, false)) + { + return bool_item; + } + + cJSON_Delete(bool_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddNumberToObject(cJSON * const object, const char * const name, const double number) +{ + cJSON *number_item = cJSON_CreateNumber(number); + if (add_item_to_object(object, name, number_item, &global_hooks, false)) + { + return number_item; + } + + cJSON_Delete(number_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddStringToObject(cJSON * const object, const char * const name, const char * const string) +{ + cJSON *string_item = cJSON_CreateString(string); + if (add_item_to_object(object, name, string_item, &global_hooks, false)) + { + return string_item; + } + + cJSON_Delete(string_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddRawToObject(cJSON * const object, const char * const name, const char * const raw) +{ + cJSON *raw_item = cJSON_CreateRaw(raw); + if (add_item_to_object(object, name, raw_item, &global_hooks, false)) + { + return raw_item; + } + + cJSON_Delete(raw_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddObjectToObject(cJSON * const object, const char * const name) +{ + cJSON *object_item = cJSON_CreateObject(); + if (add_item_to_object(object, name, object_item, &global_hooks, false)) + { + return object_item; + } + + cJSON_Delete(object_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddArrayToObject(cJSON * const object, const char * const name) +{ + cJSON *array = cJSON_CreateArray(); + if (add_item_to_object(object, name, array, &global_hooks, false)) + { + return array; + } + + cJSON_Delete(array); + return NULL; +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemViaPointer(cJSON *parent, cJSON * const item) +{ + if ((parent == NULL) || (item == NULL)) + { + return NULL; + } + + if (item != parent->child) + { + /* not the first element */ + item->prev->next = item->next; + } + if (item->next != NULL) + { + /* not the last element */ + item->next->prev = item->prev; + } + + if (item == parent->child) + { + /* first element */ + parent->child = item->next; + } + else if (item->next == NULL) + { + /* last element */ + parent->child->prev = item->prev; + } + + /* make sure the detached item doesn't point anywhere anymore */ + item->prev = NULL; + item->next = NULL; + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromArray(cJSON *array, int which) +{ + if (which < 0) + { + return NULL; + } + + return cJSON_DetachItemViaPointer(array, get_array_item(array, (size_t)which)); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromArray(cJSON *array, int which) +{ + cJSON_Delete(cJSON_DetachItemFromArray(array, which)); +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObject(cJSON *object, const char *string) +{ + cJSON *to_detach = cJSON_GetObjectItem(object, string); + + return cJSON_DetachItemViaPointer(object, to_detach); +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObjectCaseSensitive(cJSON *object, const char *string) +{ + cJSON *to_detach = cJSON_GetObjectItemCaseSensitive(object, string); + + return cJSON_DetachItemViaPointer(object, to_detach); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromObject(cJSON *object, const char *string) +{ + cJSON_Delete(cJSON_DetachItemFromObject(object, string)); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromObjectCaseSensitive(cJSON *object, const char *string) +{ + cJSON_Delete(cJSON_DetachItemFromObjectCaseSensitive(object, string)); +} + +/* Replace array/object items with new ones. */ +CJSON_PUBLIC(cJSON_bool) cJSON_InsertItemInArray(cJSON *array, int which, cJSON *newitem) +{ + cJSON *after_inserted = NULL; + + if (which < 0) + { + return false; + } + + after_inserted = get_array_item(array, (size_t)which); + if (after_inserted == NULL) + { + return add_item_to_array(array, newitem); + } + + newitem->next = after_inserted; + newitem->prev = after_inserted->prev; + after_inserted->prev = newitem; + if (after_inserted == array->child) + { + array->child = newitem; + } + else + { + newitem->prev->next = newitem; + } + return true; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemViaPointer(cJSON * const parent, cJSON * const item, cJSON * replacement) +{ + if ((parent == NULL) || (replacement == NULL) || (item == NULL)) + { + return false; + } + + if (replacement == item) + { + return true; + } + + replacement->next = item->next; + replacement->prev = item->prev; + + if (replacement->next != NULL) + { + replacement->next->prev = replacement; + } + if (parent->child == item) + { + if (parent->child->prev == parent->child) + { + replacement->prev = replacement; + } + parent->child = replacement; + } + else + { /* + * To find the last item in array quickly, we use prev in array. + * We can't modify the last item's next pointer where this item was the parent's child + */ + if (replacement->prev != NULL) + { + replacement->prev->next = replacement; + } + if (replacement->next == NULL) + { + parent->child->prev = replacement; + } + } + + item->next = NULL; + item->prev = NULL; + cJSON_Delete(item); + + return true; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInArray(cJSON *array, int which, cJSON *newitem) +{ + if (which < 0) + { + return false; + } + + return cJSON_ReplaceItemViaPointer(array, get_array_item(array, (size_t)which), newitem); +} + +static cJSON_bool replace_item_in_object(cJSON *object, const char *string, cJSON *replacement, cJSON_bool case_sensitive) +{ + if ((replacement == NULL) || (string == NULL)) + { + return false; + } + + /* replace the name in the replacement */ + if (!(replacement->type & cJSON_StringIsConst) && (replacement->string != NULL)) + { + cJSON_free(replacement->string); + } + replacement->string = (char*)cJSON_strdup((const unsigned char*)string, &global_hooks); + replacement->type &= ~cJSON_StringIsConst; + + return cJSON_ReplaceItemViaPointer(object, get_object_item(object, string, case_sensitive), replacement); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObject(cJSON *object, const char *string, cJSON *newitem) +{ + return replace_item_in_object(object, string, newitem, false); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObjectCaseSensitive(cJSON *object, const char *string, cJSON *newitem) +{ + return replace_item_in_object(object, string, newitem, true); +} + +/* Create basic types: */ +CJSON_PUBLIC(cJSON *) cJSON_CreateNull(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_NULL; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateTrue(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_True; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateFalse(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_False; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateBool(cJSON_bool boolean) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = boolean ? cJSON_True : cJSON_False; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateNumber(double num) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Number; + item->valuedouble = num; + + /* use saturation in case of overflow */ + if (num >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (num <= (double)INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)num; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateString(const char *string) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_String; + item->valuestring = (char*)cJSON_strdup((const unsigned char*)string, &global_hooks); + if(!item->valuestring) + { + cJSON_Delete(item); + return NULL; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateStringReference(const char *string) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) + { + item->type = cJSON_String | cJSON_IsReference; + item->valuestring = (char*)cast_away_const(string); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateObjectReference(const cJSON *child) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) { + item->type = cJSON_Object | cJSON_IsReference; + item->child = (cJSON*)cast_away_const(child); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateArrayReference(const cJSON *child) { + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) { + item->type = cJSON_Array | cJSON_IsReference; + item->child = (cJSON*)cast_away_const(child); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateRaw(const char *raw) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Raw; + item->valuestring = (char*)cJSON_strdup((const unsigned char*)raw, &global_hooks); + if(!item->valuestring) + { + cJSON_Delete(item); + return NULL; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateArray(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type=cJSON_Array; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateObject(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item) + { + item->type = cJSON_Object; + } + + return item; +} + +/* Create Arrays: */ +CJSON_PUBLIC(cJSON *) cJSON_CreateIntArray(const int *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber(numbers[i]); + if (!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateFloatArray(const float *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber((double)numbers[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateDoubleArray(const double *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber(numbers[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateStringArray(const char *const *strings, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (strings == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for (i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateString(strings[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p,n); + } + p = n; + } + + if (a && a->child) { + a->child->prev = n; + } + + return a; +} + +/* Duplication */ +CJSON_PUBLIC(cJSON *) cJSON_Duplicate(const cJSON *item, cJSON_bool recurse) +{ + cJSON *newitem = NULL; + cJSON *child = NULL; + cJSON *next = NULL; + cJSON *newchild = NULL; + + /* Bail on bad ptr */ + if (!item) + { + goto fail; + } + /* Create new item */ + newitem = cJSON_New_Item(&global_hooks); + if (!newitem) + { + goto fail; + } + /* Copy over all vars */ + newitem->type = item->type & (~cJSON_IsReference); + newitem->valueint = item->valueint; + newitem->valuedouble = item->valuedouble; + if (item->valuestring) + { + newitem->valuestring = (char*)cJSON_strdup((unsigned char*)item->valuestring, &global_hooks); + if (!newitem->valuestring) + { + goto fail; + } + } + if (item->string) + { + newitem->string = (item->type&cJSON_StringIsConst) ? item->string : (char*)cJSON_strdup((unsigned char*)item->string, &global_hooks); + if (!newitem->string) + { + goto fail; + } + } + /* If non-recursive, then we're done! */ + if (!recurse) + { + return newitem; + } + /* Walk the ->next chain for the child. */ + child = item->child; + while (child != NULL) + { + newchild = cJSON_Duplicate(child, true); /* Duplicate (with recurse) each item in the ->next chain */ + if (!newchild) + { + goto fail; + } + if (next != NULL) + { + /* If newitem->child already set, then crosswire ->prev and ->next and move on */ + next->next = newchild; + newchild->prev = next; + next = newchild; + } + else + { + /* Set newitem->child and move to it */ + newitem->child = newchild; + next = newchild; + } + child = child->next; + } + if (newitem && newitem->child) + { + newitem->child->prev = newchild; + } + + return newitem; + +fail: + if (newitem != NULL) + { + cJSON_Delete(newitem); + } + + return NULL; +} + +static void skip_oneline_comment(char **input) +{ + *input += static_strlen("//"); + + for (; (*input)[0] != '\0'; ++(*input)) + { + if ((*input)[0] == '\n') { + *input += static_strlen("\n"); + return; + } + } +} + +static void skip_multiline_comment(char **input) +{ + *input += static_strlen("/*"); + + for (; (*input)[0] != '\0'; ++(*input)) + { + if (((*input)[0] == '*') && ((*input)[1] == '/')) + { + *input += static_strlen("*/"); + return; + } + } +} + +static void minify_string(char **input, char **output) { + (*output)[0] = (*input)[0]; + *input += static_strlen("\""); + *output += static_strlen("\""); + + + for (; (*input)[0] != '\0'; (void)++(*input), ++(*output)) { + (*output)[0] = (*input)[0]; + + if ((*input)[0] == '\"') { + (*output)[0] = '\"'; + *input += static_strlen("\""); + *output += static_strlen("\""); + return; + } else if (((*input)[0] == '\\') && ((*input)[1] == '\"')) { + (*output)[1] = (*input)[1]; + *input += static_strlen("\""); + *output += static_strlen("\""); + } + } +} + +CJSON_PUBLIC(void) cJSON_Minify(char *json) +{ + char *into = json; + + if (json == NULL) + { + return; + } + + while (json[0] != '\0') + { + switch (json[0]) + { + case ' ': + case '\t': + case '\r': + case '\n': + json++; + break; + + case '/': + if (json[1] == '/') + { + skip_oneline_comment(&json); + } + else if (json[1] == '*') + { + skip_multiline_comment(&json); + } else { + json++; + } + break; + + case '\"': + minify_string(&json, (char**)&into); + break; + + default: + into[0] = json[0]; + json++; + into++; + } + } + + /* and null-terminate. */ + *into = '\0'; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsInvalid(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Invalid; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsFalse(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_False; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsTrue(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xff) == cJSON_True; +} + + +CJSON_PUBLIC(cJSON_bool) cJSON_IsBool(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & (cJSON_True | cJSON_False)) != 0; +} +CJSON_PUBLIC(cJSON_bool) cJSON_IsNull(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_NULL; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsNumber(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Number; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsString(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_String; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsArray(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Array; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsObject(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Object; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsRaw(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Raw; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_Compare(const cJSON * const a, const cJSON * const b, const cJSON_bool case_sensitive) +{ + if ((a == NULL) || (b == NULL) || ((a->type & 0xFF) != (b->type & 0xFF)) || cJSON_IsInvalid(a)) + { + return false; + } + + /* check if type is valid */ + switch (a->type & 0xFF) + { + case cJSON_False: + case cJSON_True: + case cJSON_NULL: + case cJSON_Number: + case cJSON_String: + case cJSON_Raw: + case cJSON_Array: + case cJSON_Object: + break; + + default: + return false; + } + + /* identical objects are equal */ + if (a == b) + { + return true; + } + + switch (a->type & 0xFF) + { + /* in these cases and equal type is enough */ + case cJSON_False: + case cJSON_True: + case cJSON_NULL: + return true; + + case cJSON_Number: + if (compare_double(a->valuedouble, b->valuedouble)) + { + return true; + } + return false; + + case cJSON_String: + case cJSON_Raw: + if ((a->valuestring == NULL) || (b->valuestring == NULL)) + { + return false; + } + if (strcmp(a->valuestring, b->valuestring) == 0) + { + return true; + } + + return false; + + case cJSON_Array: + { + cJSON *a_element = a->child; + cJSON *b_element = b->child; + + for (; (a_element != NULL) && (b_element != NULL);) + { + if (!cJSON_Compare(a_element, b_element, case_sensitive)) + { + return false; + } + + a_element = a_element->next; + b_element = b_element->next; + } + + /* one of the arrays is longer than the other */ + if (a_element != b_element) { + return false; + } + + return true; + } + + case cJSON_Object: + { + cJSON *a_element = NULL; + cJSON *b_element = NULL; + cJSON_ArrayForEach(a_element, a) + { + /* TODO This has O(n^2) runtime, which is horrible! */ + b_element = get_object_item(b, a_element->string, case_sensitive); + if (b_element == NULL) + { + return false; + } + + if (!cJSON_Compare(a_element, b_element, case_sensitive)) + { + return false; + } + } + + /* doing this twice, once on a and b to prevent true comparison if a subset of b + * TODO: Do this the proper way, this is just a fix for now */ + cJSON_ArrayForEach(b_element, b) + { + a_element = get_object_item(a, b_element->string, case_sensitive); + if (a_element == NULL) + { + return false; + } + + if (!cJSON_Compare(b_element, a_element, case_sensitive)) + { + return false; + } + } + + return true; + } + + default: + return false; + } +} + +CJSON_PUBLIC(void *) cJSON_malloc(size_t size) +{ + return global_hooks.allocate(size); +} + +CJSON_PUBLIC(void) cJSON_free(void *object) +{ + global_hooks.deallocate(object); +} diff --git a/cJSON.h b/cJSON.h new file mode 100644 index 000000000..e97e5f4cd --- /dev/null +++ b/cJSON.h @@ -0,0 +1,293 @@ +/* + Copyright (c) 2009-2017 Dave Gamble and cJSON contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +*/ + +#ifndef cJSON__h +#define cJSON__h + +#ifdef __cplusplus +extern "C" +{ +#endif + +#if !defined(__WINDOWS__) && (defined(WIN32) || defined(WIN64) || defined(_MSC_VER) || defined(_WIN32)) +#define __WINDOWS__ +#endif + +#ifdef __WINDOWS__ + +/* When compiling for windows, we specify a specific calling convention to avoid issues where we are being called from a project with a different default calling convention. For windows you have 3 define options: + +CJSON_HIDE_SYMBOLS - Define this in the case where you don't want to ever dllexport symbols +CJSON_EXPORT_SYMBOLS - Define this on library build when you want to dllexport symbols (default) +CJSON_IMPORT_SYMBOLS - Define this if you want to dllimport symbol + +For *nix builds that support visibility attribute, you can define similar behavior by + +setting default visibility to hidden by adding +-fvisibility=hidden (for gcc) +or +-xldscope=hidden (for sun cc) +to CFLAGS + +then using the CJSON_API_VISIBILITY flag to "export" the same symbols the way CJSON_EXPORT_SYMBOLS does + +*/ + +#define CJSON_CDECL __cdecl +#define CJSON_STDCALL __stdcall + +/* export symbols by default, this is necessary for copy pasting the C and header file */ +#if !defined(CJSON_HIDE_SYMBOLS) && !defined(CJSON_IMPORT_SYMBOLS) && !defined(CJSON_EXPORT_SYMBOLS) +#define CJSON_EXPORT_SYMBOLS +#endif + +#if defined(CJSON_HIDE_SYMBOLS) +#define CJSON_PUBLIC(type) type CJSON_STDCALL +#elif defined(CJSON_EXPORT_SYMBOLS) +#define CJSON_PUBLIC(type) __declspec(dllexport) type CJSON_STDCALL +#elif defined(CJSON_IMPORT_SYMBOLS) +#define CJSON_PUBLIC(type) __declspec(dllimport) type CJSON_STDCALL +#endif +#else /* !__WINDOWS__ */ +#define CJSON_CDECL +#define CJSON_STDCALL + +#if (defined(__GNUC__) || defined(__SUNPRO_CC) || defined (__SUNPRO_C)) && defined(CJSON_API_VISIBILITY) +#define CJSON_PUBLIC(type) __attribute__((visibility("default"))) type +#else +#define CJSON_PUBLIC(type) type +#endif +#endif + +/* project version */ +#define CJSON_VERSION_MAJOR 1 +#define CJSON_VERSION_MINOR 7 +#define CJSON_VERSION_PATCH 14 + +#include + +/* cJSON Types: */ +#define cJSON_Invalid (0) +#define cJSON_False (1 << 0) +#define cJSON_True (1 << 1) +#define cJSON_NULL (1 << 2) +#define cJSON_Number (1 << 3) +#define cJSON_String (1 << 4) +#define cJSON_Array (1 << 5) +#define cJSON_Object (1 << 6) +#define cJSON_Raw (1 << 7) /* raw json */ + +#define cJSON_IsReference 256 +#define cJSON_StringIsConst 512 + +/* The cJSON structure: */ +typedef struct cJSON +{ + /* next/prev allow you to walk array/object chains. Alternatively, use GetArraySize/GetArrayItem/GetObjectItem */ + struct cJSON *next; + struct cJSON *prev; + /* An array or object item will have a child pointer pointing to a chain of the items in the array/object. */ + struct cJSON *child; + + /* The type of the item, as above. */ + int type; + + /* The item's string, if type==cJSON_String and type == cJSON_Raw */ + char *valuestring; + /* writing to valueint is DEPRECATED, use cJSON_SetNumberValue instead */ + int valueint; + /* The item's number, if type==cJSON_Number */ + double valuedouble; + + /* The item's name string, if this item is the child of, or is in the list of subitems of an object. */ + char *string; +} cJSON; + +typedef struct cJSON_Hooks +{ + /* malloc/free are CDECL on Windows regardless of the default calling convention of the compiler, so ensure the hooks allow passing those functions directly. */ + void *(CJSON_CDECL *malloc_fn)(size_t sz); + void (CJSON_CDECL *free_fn)(void *ptr); +} cJSON_Hooks; + +typedef int cJSON_bool; + +/* Limits how deeply nested arrays/objects can be before cJSON rejects to parse them. + * This is to prevent stack overflows. */ +#ifndef CJSON_NESTING_LIMIT +#define CJSON_NESTING_LIMIT 1000 +#endif + +/* returns the version of cJSON as a string */ +CJSON_PUBLIC(const char*) cJSON_Version(void); + +/* Supply malloc, realloc and free functions to cJSON */ +CJSON_PUBLIC(void) cJSON_InitHooks(cJSON_Hooks* hooks); + +/* Memory Management: the caller is always responsible to free the results from all variants of cJSON_Parse (with cJSON_Delete) and cJSON_Print (with stdlib free, cJSON_Hooks.free_fn, or cJSON_free as appropriate). The exception is cJSON_PrintPreallocated, where the caller has full responsibility of the buffer. */ +/* Supply a block of JSON, and this returns a cJSON object you can interrogate. */ +CJSON_PUBLIC(cJSON *) cJSON_Parse(const char *value); +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLength(const char *value, size_t buffer_length); +/* ParseWithOpts allows you to require (and check) that the JSON is null terminated, and to retrieve the pointer to the final byte parsed. */ +/* If you supply a ptr in return_parse_end and parsing fails, then return_parse_end will contain a pointer to the error so will match cJSON_GetErrorPtr(). */ +CJSON_PUBLIC(cJSON *) cJSON_ParseWithOpts(const char *value, const char **return_parse_end, cJSON_bool require_null_terminated); +CJSON_PUBLIC(cJSON *) cJSON_ParseWithLengthOpts(const char *value, size_t buffer_length, const char **return_parse_end, cJSON_bool require_null_terminated); + +/* Render a cJSON entity to text for transfer/storage. */ +CJSON_PUBLIC(char *) cJSON_Print(const cJSON *item); +/* Render a cJSON entity to text for transfer/storage without any formatting. */ +CJSON_PUBLIC(char *) cJSON_PrintUnformatted(const cJSON *item); +/* Render a cJSON entity to text using a buffered strategy. prebuffer is a guess at the final size. guessing well reduces reallocation. fmt=0 gives unformatted, =1 gives formatted */ +CJSON_PUBLIC(char *) cJSON_PrintBuffered(const cJSON *item, int prebuffer, cJSON_bool fmt); +/* Render a cJSON entity to text using a buffer already allocated in memory with given length. Returns 1 on success and 0 on failure. */ +/* NOTE: cJSON is not always 100% accurate in estimating how much memory it will use, so to be safe allocate 5 bytes more than you actually need */ +CJSON_PUBLIC(cJSON_bool) cJSON_PrintPreallocated(cJSON *item, char *buffer, const int length, const cJSON_bool format); +/* Delete a cJSON entity and all subentities. */ +CJSON_PUBLIC(void) cJSON_Delete(cJSON *item); + +/* Returns the number of items in an array (or object). */ +CJSON_PUBLIC(int) cJSON_GetArraySize(const cJSON *array); +/* Retrieve item number "index" from array "array". Returns NULL if unsuccessful. */ +CJSON_PUBLIC(cJSON *) cJSON_GetArrayItem(const cJSON *array, int index); +/* Get item "string" from object. Case insensitive. */ +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItem(const cJSON * const object, const char * const string); +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItemCaseSensitive(const cJSON * const object, const char * const string); +CJSON_PUBLIC(cJSON_bool) cJSON_HasObjectItem(const cJSON *object, const char *string); +/* For analysing failed parses. This returns a pointer to the parse error. You'll probably need to look a few chars back to make sense of it. Defined when cJSON_Parse() returns 0. 0 when cJSON_Parse() succeeds. */ +CJSON_PUBLIC(const char *) cJSON_GetErrorPtr(void); + +/* Check item type and return its value */ +CJSON_PUBLIC(char *) cJSON_GetStringValue(const cJSON * const item); +CJSON_PUBLIC(double) cJSON_GetNumberValue(const cJSON * const item); + +/* These functions check the type of an item */ +CJSON_PUBLIC(cJSON_bool) cJSON_IsInvalid(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsFalse(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsTrue(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsBool(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsNull(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsNumber(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsString(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsArray(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsObject(const cJSON * const item); +CJSON_PUBLIC(cJSON_bool) cJSON_IsRaw(const cJSON * const item); + +/* These calls create a cJSON item of the appropriate type. */ +CJSON_PUBLIC(cJSON *) cJSON_CreateNull(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateTrue(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateFalse(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateBool(cJSON_bool boolean); +CJSON_PUBLIC(cJSON *) cJSON_CreateNumber(double num); +CJSON_PUBLIC(cJSON *) cJSON_CreateString(const char *string); +/* raw json */ +CJSON_PUBLIC(cJSON *) cJSON_CreateRaw(const char *raw); +CJSON_PUBLIC(cJSON *) cJSON_CreateArray(void); +CJSON_PUBLIC(cJSON *) cJSON_CreateObject(void); + +/* Create a string where valuestring references a string so + * it will not be freed by cJSON_Delete */ +CJSON_PUBLIC(cJSON *) cJSON_CreateStringReference(const char *string); +/* Create an object/array that only references it's elements so + * they will not be freed by cJSON_Delete */ +CJSON_PUBLIC(cJSON *) cJSON_CreateObjectReference(const cJSON *child); +CJSON_PUBLIC(cJSON *) cJSON_CreateArrayReference(const cJSON *child); + +/* These utilities create an Array of count items. + * The parameter count cannot be greater than the number of elements in the number array, otherwise array access will be out of bounds.*/ +CJSON_PUBLIC(cJSON *) cJSON_CreateIntArray(const int *numbers, int count); +CJSON_PUBLIC(cJSON *) cJSON_CreateFloatArray(const float *numbers, int count); +CJSON_PUBLIC(cJSON *) cJSON_CreateDoubleArray(const double *numbers, int count); +CJSON_PUBLIC(cJSON *) cJSON_CreateStringArray(const char *const *strings, int count); + +/* Append item to the specified array/object. */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToArray(cJSON *array, cJSON *item); +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObject(cJSON *object, const char *string, cJSON *item); +/* Use this when string is definitely const (i.e. a literal, or as good as), and will definitely survive the cJSON object. + * WARNING: When this function was used, make sure to always check that (item->type & cJSON_StringIsConst) is zero before + * writing to `item->string` */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemToObjectCS(cJSON *object, const char *string, cJSON *item); +/* Append reference to item to the specified array/object. Use this when you want to add an existing cJSON to a new cJSON, but don't want to corrupt your existing cJSON. */ +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item); +CJSON_PUBLIC(cJSON_bool) cJSON_AddItemReferenceToObject(cJSON *object, const char *string, cJSON *item); + +/* Remove/Detach items from Arrays/Objects. */ +CJSON_PUBLIC(cJSON *) cJSON_DetachItemViaPointer(cJSON *parent, cJSON * const item); +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromArray(cJSON *array, int which); +CJSON_PUBLIC(void) cJSON_DeleteItemFromArray(cJSON *array, int which); +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObject(cJSON *object, const char *string); +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObjectCaseSensitive(cJSON *object, const char *string); +CJSON_PUBLIC(void) cJSON_DeleteItemFromObject(cJSON *object, const char *string); +CJSON_PUBLIC(void) cJSON_DeleteItemFromObjectCaseSensitive(cJSON *object, const char *string); + +/* Update array items. */ +CJSON_PUBLIC(cJSON_bool) cJSON_InsertItemInArray(cJSON *array, int which, cJSON *newitem); /* Shifts pre-existing items to the right. */ +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemViaPointer(cJSON * const parent, cJSON * const item, cJSON * replacement); +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInArray(cJSON *array, int which, cJSON *newitem); +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObject(cJSON *object,const char *string,cJSON *newitem); +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemInObjectCaseSensitive(cJSON *object,const char *string,cJSON *newitem); + +/* Duplicate a cJSON item */ +CJSON_PUBLIC(cJSON *) cJSON_Duplicate(const cJSON *item, cJSON_bool recurse); +/* Duplicate will create a new, identical cJSON item to the one you pass, in new memory that will + * need to be released. With recurse!=0, it will duplicate any children connected to the item. + * The item->next and ->prev pointers are always zero on return from Duplicate. */ +/* Recursively compare two cJSON items for equality. If either a or b is NULL or invalid, they will be considered unequal. + * case_sensitive determines if object keys are treated case sensitive (1) or case insensitive (0) */ +CJSON_PUBLIC(cJSON_bool) cJSON_Compare(const cJSON * const a, const cJSON * const b, const cJSON_bool case_sensitive); + +/* Minify a strings, remove blank characters(such as ' ', '\t', '\r', '\n') from strings. + * The input pointer json cannot point to a read-only address area, such as a string constant, + * but should point to a readable and writable adress area. */ +CJSON_PUBLIC(void) cJSON_Minify(char *json); + +/* Helper functions for creating and adding items to an object at the same time. + * They return the added item or NULL on failure. */ +CJSON_PUBLIC(cJSON*) cJSON_AddNullToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddTrueToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddFalseToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddBoolToObject(cJSON * const object, const char * const name, const cJSON_bool boolean); +CJSON_PUBLIC(cJSON*) cJSON_AddNumberToObject(cJSON * const object, const char * const name, const double number); +CJSON_PUBLIC(cJSON*) cJSON_AddStringToObject(cJSON * const object, const char * const name, const char * const string); +CJSON_PUBLIC(cJSON*) cJSON_AddRawToObject(cJSON * const object, const char * const name, const char * const raw); +CJSON_PUBLIC(cJSON*) cJSON_AddObjectToObject(cJSON * const object, const char * const name); +CJSON_PUBLIC(cJSON*) cJSON_AddArrayToObject(cJSON * const object, const char * const name); + +/* When assigning an integer value, it needs to be propagated to valuedouble too. */ +#define cJSON_SetIntValue(object, number) ((object) ? (object)->valueint = (object)->valuedouble = (number) : (number)) +/* helper for the cJSON_SetNumberValue macro */ +CJSON_PUBLIC(double) cJSON_SetNumberHelper(cJSON *object, double number); +#define cJSON_SetNumberValue(object, number) ((object != NULL) ? cJSON_SetNumberHelper(object, (double)number) : (number)) +/* Change the valuestring of a cJSON_String object, only takes effect when type of object is cJSON_String */ +CJSON_PUBLIC(char*) cJSON_SetValuestring(cJSON *object, const char *valuestring); + +/* Macro for iterating over an array or object */ +#define cJSON_ArrayForEach(element, array) for(element = (array != NULL) ? (array)->child : NULL; element != NULL; element = element->next) + +/* malloc/free objects using the malloc/free functions that have been set with cJSON_InitHooks */ +CJSON_PUBLIC(void *) cJSON_malloc(size_t size); +CJSON_PUBLIC(void) cJSON_free(void *object); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/expr.c b/expr.c new file mode 100644 index 000000000..157539d32 --- /dev/null +++ b/expr.c @@ -0,0 +1,994 @@ +/* Filtering of objects based on simple expressions. + * This powers the FILTER option of Vector Sets, but it is otherwise + * general code to be used when we want to tell if a given object (with fields) + * passes or fails a given test for scalars, strings, ... + * + * Copyright(C) 2024-2025 Salvatore Sanfilippo. All Rights Reserved. + */ + +#include +#include +#include +#include +#include +#include "cJSON.h" + +#ifdef TEST_MAIN +#define RedisModule_Alloc malloc +#define RedisModule_Realloc realloc +#define RedisModule_Free free +#define RedisModule_Strdup strdup +#endif + +#define EXPR_TOKEN_EOF 0 +#define EXPR_TOKEN_NUM 1 +#define EXPR_TOKEN_STR 2 +#define EXPR_TOKEN_TUPLE 3 +#define EXPR_TOKEN_SELECTOR 4 +#define EXPR_TOKEN_OP 5 + +#define EXPR_OP_OPAREN 0 /* ( */ +#define EXPR_OP_CPAREN 1 /* ) */ +#define EXPR_OP_NOT 2 /* ! */ +#define EXPR_OP_POW 3 /* ** */ +#define EXPR_OP_MULT 4 /* * */ +#define EXPR_OP_DIV 5 /* / */ +#define EXPR_OP_MOD 6 /* % */ +#define EXPR_OP_SUM 7 /* + */ +#define EXPR_OP_DIFF 8 /* - */ +#define EXPR_OP_GT 9 /* > */ +#define EXPR_OP_GTE 10 /* >= */ +#define EXPR_OP_LT 11 /* < */ +#define EXPR_OP_LTE 12 /* <= */ +#define EXPR_OP_EQ 13 /* == */ +#define EXPR_OP_NEQ 14 /* != */ +#define EXPR_OP_IN 15 /* in */ +#define EXPR_OP_AND 16 /* and */ +#define EXPR_OP_OR 17 /* or */ + +/* This structure represents a token in our expression. It's either + * literals like 4, "foo", or operators like "+", "-", "and", or + * json selectors, that start with a dot: ".age", ".properties.somearray[1]" */ +typedef struct exprtoken { + int refcount; // Reference counting for memory reclaiming. + int token_type; // Token type of the just parsed token. + int offset; // Chars offset in expression. + union { + double num; // Value for EXPR_TOKEN_NUM. + struct { + char *start; // String pointer for EXPR_TOKEN_STR / SELECTOR. + size_t len; // String len for EXPR_TOKEN_STR / SELECTOR. + char *heapstr; // True if we have a private allocation for this + // string. When possible, it just references to the + // string expression we compiled, exprstate->expr. + } str; + int opcode; // Opcode ID for EXPR_TOKEN_OP. + struct { + struct exprtoken **ele; + size_t len; + } tuple; // Tuples are like [1, 2, 3] for "in" operator. + }; +} exprtoken; + +/* Simple stack of expr tokens. This is used both to represent the stack + * of values and the stack of operands during VM execution. */ +typedef struct exprstack { + exprtoken **items; + int numitems; + int allocsize; +} exprstack; + +typedef struct exprstate { + char *expr; /* Expression string to compile. Note that + * expression token strings point directly to this + * string. */ + char *p; // Currnet position inside 'expr', while parsing. + + // Virtual machine state. + exprstack values_stack; + exprstack ops_stack; // Operator stack used during compilation. + exprstack tokens; // Expression processed into a sequence of tokens. + exprstack program; // Expression compiled into opcodes and values. +} exprstate; + +/* Valid operators. */ +struct { + char *opname; + int oplen; + int opcode; + int precedence; + int arity; +} ExprOptable[] = { + {"(", 1, EXPR_OP_OPAREN, 7, 0}, + {")", 1, EXPR_OP_CPAREN, 7, 0}, + {"!", 1, EXPR_OP_NOT, 6, 1}, + {"not", 3, EXPR_OP_NOT, 6, 1}, + {"**", 2, EXPR_OP_POW, 5, 2}, + {"*", 1, EXPR_OP_MULT, 4, 2}, + {"/", 1, EXPR_OP_DIV, 4, 2}, + {"%", 1, EXPR_OP_MOD, 4, 2}, + {"+", 1, EXPR_OP_SUM, 3, 2}, + {"-", 1, EXPR_OP_DIFF, 3, 2}, + {">", 1, EXPR_OP_GT, 2, 2}, + {">=", 2, EXPR_OP_GTE, 2, 2}, + {"<", 1, EXPR_OP_LT, 2, 2}, + {"<=", 2, EXPR_OP_LTE, 2, 2}, + {"==", 2, EXPR_OP_EQ, 2, 2}, + {"!=", 2, EXPR_OP_NEQ, 2, 2}, + {"in", 2, EXPR_OP_IN, 2, 2}, + {"and", 3, EXPR_OP_AND, 1, 2}, + {"&&", 2, EXPR_OP_AND, 1, 2}, + {"or", 2, EXPR_OP_OR, 0, 2}, + {"||", 2, EXPR_OP_OR, 0, 2}, + {NULL, 0, 0, 0, 0} // Terminator. +}; + +#define EXPR_OP_SPECIALCHARS "+-*%/!()<>=|&" +#define EXPR_SELECTOR_SPECIALCHARS "_-" + +/* ================================ Expr token ============================== */ + +/* Return an heap allocated token of the specified type, setting the + * reference count to 1. */ +exprtoken *exprNewToken(int type) { + exprtoken *t = RedisModule_Alloc(sizeof(exprtoken)); + memset(t,0,sizeof(*t)); + t->token_type = type; + t->refcount = 1; + return t; +} + +/* Generic free token function, can be used to free stack allocated + * objects (in this case the pointer itself will not be freed) or + * heap allocated objects. See the wrappers below. */ +void exprTokenRelease(exprtoken *t) { + if (t == NULL) return; + + if (t->refcount <= 0) { + printf("exprTokenRelease() against a token with refcount %d!\n" + "Aborting program execution\n", + t->refcount); + exit(1); + } + t->refcount--; + if (t->refcount > 0) return; + + // We reached refcount 0: free the object. + if (t->token_type == EXPR_TOKEN_STR) { + if (t->str.heapstr != NULL) RedisModule_Free(t->str.heapstr); + } else if (t->token_type == EXPR_TOKEN_TUPLE) { + for (size_t j = 0; j < t->tuple.len; j++) + exprTokenRelease(t->tuple.ele[j]); + if (t->tuple.ele) RedisModule_Free(t->tuple.ele); + } + RedisModule_Free(t); +} + +void exprTokenRetain(exprtoken *t) { + t->refcount++; +} + +/* ============================== Stack handling ============================ */ + +#include +#include + +#define EXPR_STACK_INITIAL_SIZE 16 + +/* Initialize a new expression stack. */ +void exprStackInit(exprstack *stack) { + stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE); + stack->numitems = 0; + stack->allocsize = EXPR_STACK_INITIAL_SIZE; +} + +/* Push a token pointer onto the stack. Does not increment the refcount + * of the token: it is up to the caller doing this. */ +void exprStackPush(exprstack *stack, exprtoken *token) { + /* Check if we need to grow the stack. */ + if (stack->numitems == stack->allocsize) { + size_t newsize = stack->allocsize * 2; + exprtoken **newitems = + RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize); + stack->items = newitems; + stack->allocsize = newsize; + } + stack->items[stack->numitems] = token; + stack->numitems++; +} + +/* Pop a token pointer from the stack. Return NULL if the stack is + * empty. Does NOT recrement the refcount of the token, it's up to the + * caller to do so, as the new owner of the reference. */ +exprtoken *exprStackPop(exprstack *stack) { + if (stack->numitems == 0) return NULL; + stack->numitems--; + return stack->items[stack->numitems]; +} + +/* Just return the last element pushed, without consuming it nor altering + * the reference count. */ +exprtoken *exprStackPeek(exprstack *stack) { + if (stack->numitems == 0) return NULL; + return stack->items[stack->numitems-1]; +} + +/* Free the stack structure state, including the items it contains, that are + * assumed to be heap allocated. The passed pointer itself is not freed. */ +void exprStackFree(exprstack *stack) { + for (int j = 0; j < stack->numitems; j++) + exprTokenRelease(stack->items[j]); + RedisModule_Free(stack->items); +} + +/* Just reset the stack removing all the items, but leaving it in a state + * that makes it still usable for new elements. */ +void exprStackReset(exprstack *stack) { + for (int j = 0; j < stack->numitems; j++) + exprTokenRelease(stack->items[j]); + stack->numitems = 0; +} + +/* =========================== Expression compilation ======================= */ + +void exprConsumeSpaces(exprstate *es) { + while(es->p[0] && isspace(es->p[0])) es->p++; +} + +/* Parse an operator, trying to match the longer match in the + * operators table. */ +exprtoken *exprParseOperator(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_OP); + char *start = es->p; + + while(es->p[0] && + (isalpha(es->p[0]) || + strchr(EXPR_OP_SPECIALCHARS,es->p[0]) != NULL)) + { + es->p++; + } + + int matchlen = es->p - start; + int bestlen = 0; + int j; + + // Find the longest matching operator. + for (j = 0; ExprOptable[j].opname != NULL; j++) { + if (ExprOptable[j].oplen > matchlen) continue; + if (memcmp(ExprOptable[j].opname, start, ExprOptable[j].oplen) != 0) + { + continue; + } + if (ExprOptable[j].oplen > bestlen) { + t->opcode = ExprOptable[j].opcode; + bestlen = ExprOptable[j].oplen; + } + } + if (bestlen == 0) { + exprTokenRelease(t); + return NULL; + } else { + es->p = start + bestlen; + } + return t; +} + +// Valid selector charset. +static int is_selector_char(int c) { + return (isalpha(c) || + isdigit(c) || + strchr(EXPR_SELECTOR_SPECIALCHARS,c) != NULL); +} + +/* Parse selectors, they start with a dot and can have alphanumerical + * or few special chars. */ +exprtoken *exprParseSelector(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_SELECTOR); + es->p++; // Skip dot. + char *start = es->p; + + while(es->p[0] && is_selector_char(es->p[0])) es->p++; + int matchlen = es->p - start; + t->str.start = start; + t->str.len = matchlen; + return t; +} + +exprtoken *exprParseNumber(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_NUM); + char num[64]; + int idx = 0; + while(isdigit(es->p[0]) || es->p[0] == '.' || es->p[0] == 'e' || + es->p[0] == 'E' || (idx == 0 && es->p[0] == '-')) + { + if (idx >= (int)sizeof(num)-1) { + exprTokenRelease(t); + return NULL; + } + num[idx++] = es->p[0]; + es->p++; + } + num[idx] = 0; + + char *endptr; + t->num = strtod(num, &endptr); + if (*endptr != '\0') { + exprTokenRelease(t); + return NULL; + } + return t; +} + +exprtoken *exprParseString(exprstate *es) { + char quote = es->p[0]; /* Store the quote type (' or "). */ + es->p++; /* Skip opening quote. */ + + exprtoken *t = exprNewToken(EXPR_TOKEN_STR); + t->str.start = es->p; + + while(es->p[0] != '\0') { + if (es->p[0] == '\\' && es->p[1] != '\0') { + es->p += 2; // Skip escaped char. + continue; + } + if (es->p[0] == quote) { + t->str.len = es->p - t->str.start; + es->p++; // Skip closing quote. + return t; + } + es->p++; + } + /* If we reach here, string was not terminated. */ + exprTokenRelease(t); + return NULL; +} + +/* Parse a tuple of the form [1, "foo", 42]. No nested tuples are + * supported. This type is useful mostly to be used with the "IN" + * operator. */ +exprtoken *exprParseTuple(exprstate *es) { + exprtoken *t = exprNewToken(EXPR_TOKEN_TUPLE); + t->tuple.ele = NULL; + t->tuple.len = 0; + es->p++; /* Skip opening '['. */ + + size_t allocated = 0; + while(1) { + exprConsumeSpaces(es); + + /* Check for empty tuple or end. */ + if (es->p[0] == ']') { + es->p++; + break; + } + + /* Grow tuple array if needed. */ + if (t->tuple.len == allocated) { + size_t newsize = allocated == 0 ? 4 : allocated * 2; + exprtoken **newele = RedisModule_Realloc(t->tuple.ele, + sizeof(exprtoken*) * newsize); + t->tuple.ele = newele; + allocated = newsize; + } + + /* Parse tuple element. */ + exprtoken *ele = NULL; + if (isdigit(es->p[0]) || es->p[0] == '-') { + ele = exprParseNumber(es); + } else if (es->p[0] == '"' || es->p[0] == '\'') { + ele = exprParseString(es); + } else { + exprTokenRelease(t); + return NULL; + } + + /* Error parsing number/string? */ + if (ele == NULL) { + exprTokenRelease(t); + return NULL; + } + + /* Store element if no error was detected. */ + t->tuple.ele[t->tuple.len] = ele; + t->tuple.len++; + + /* Check for next element. */ + exprConsumeSpaces(es); + if (es->p[0] == ']') { + es->p++; + break; + } + if (es->p[0] != ',') { + exprTokenRelease(t); + return NULL; + } + es->p++; /* Skip comma. */ + } + return t; +} + +/* Deallocate the object returned by exprCompile(). */ +void exprFree(exprstate *es) { + if (es == NULL) return; + + /* Free the original expression string. */ + if (es->expr) RedisModule_Free(es->expr); + + /* Free all stacks. */ + exprStackFree(&es->values_stack); + exprStackFree(&es->ops_stack); + exprStackFree(&es->tokens); + exprStackFree(&es->program); + + /* Free the state object itself. */ + RedisModule_Free(es); +} + +/* Split the provided expression into a stack of tokens. Returns + * 0 on success, 1 on error. */ +int exprTokenize(exprstate *es, int *errpos) { + /* Main parsing loop. */ + while(1) { + exprConsumeSpaces(es); + + /* Set a flag to see if we can consider the - part of the + * number, or an operator. */ + int minus_is_number = 0; // By default is an operator. + + exprtoken *last = exprStackPeek(&es->tokens); + if (last == NULL) { + /* If we are at the start of an expression, the minus is + * considered a number. */ + minus_is_number = 1; + } else if (last->token_type == EXPR_TOKEN_OP && + last->opcode != EXPR_OP_CPAREN) + { + /* Also, if the previous token was an operator, the minus + * is considered a number, unless the previous operator is + * a closing parens. In such case it's like (...) -5, or alike + * and we want to emit an operator. */ + minus_is_number = 1; + } + + /* Parse based on the current character. */ + exprtoken *current = NULL; + if (*es->p == '\0') { + current = exprNewToken(EXPR_TOKEN_EOF); + } else if (isdigit(*es->p) || + (minus_is_number && *es->p == '-' && isdigit(es->p[1]))) + { + current = exprParseNumber(es); + } else if (*es->p == '"' || *es->p == '\'') { + current = exprParseString(es); + } else if (*es->p == '.' && is_selector_char(es->p[1])) { + current = exprParseSelector(es); + } else if (isalpha(*es->p) || strchr(EXPR_OP_SPECIALCHARS, *es->p)) { + current = exprParseOperator(es); + } else if (*es->p == '[') { + current = exprParseTuple(es); + } + + if (current == NULL) { + if (errpos) *errpos = es->p - es->expr; + return 1; // Syntax Error. + } + + /* Push the current token to tokens stack. */ + exprStackPush(&es->tokens, current); + if (current->token_type == EXPR_TOKEN_EOF) break; + } + return 0; +} + +/* Helper function to get operator precedence from the operator table. */ +int exprGetOpPrecedence(int opcode) { + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == opcode) + return ExprOptable[i].precedence; + } + return -1; +} + +/* Helper function to get operator arity from the operator table. */ +int exprGetOpArity(int opcode) { + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == opcode) + return ExprOptable[i].arity; + } + return -1; +} + +/* Process an operator during compilation. Returns 0 on success, 1 on error. + * This function will retain a reference of the operator 'op' in case it + * is pushed on the operators stack. */ +int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *errpos) { + if (op->opcode == EXPR_OP_OPAREN) { + // This is just a marker for us. Do nothing. + exprStackPush(&es->ops_stack, op); + exprTokenRetain(op); + return 0; + } + + if (op->opcode == EXPR_OP_CPAREN) { + /* Process operators until we find the matching opening parenthesis. */ + while (1) { + exprtoken *top_op = exprStackPop(&es->ops_stack); + if (top_op == NULL) { + if (errpos) *errpos = op->offset; + return 1; + } + + if (top_op->opcode == EXPR_OP_OPAREN) { + /* Open parethesis found. Our work finished. */ + exprTokenRelease(top_op); + return 0; + } + + int arity = exprGetOpArity(top_op->opcode); + if (*stack_items < arity) { + exprTokenRelease(top_op); + if (errpos) *errpos = top_op->offset; + return 1; + } + + /* Move the operator on the program stack. */ + exprStackPush(&es->program, top_op); + *stack_items = *stack_items - arity + 1; + } + } + + int curr_prec = exprGetOpPrecedence(op->opcode); + + /* Process operators with higher or equal precedence. */ + while (1) { + exprtoken *top_op = exprStackPeek(&es->ops_stack); + if (top_op == NULL || top_op->opcode == EXPR_OP_OPAREN) break; + + int top_prec = exprGetOpPrecedence(top_op->opcode); + if (top_prec < curr_prec) break; + /* Special case for **: only pop if precedence is strictly higher + * so that the operator is right associative, that is: + * 2 ** 3 ** 2 is evaluated as 2 ** (3 ** 2) == 512 instead + * of (2 ** 3) ** 2 == 64. */ + if (op->opcode == EXPR_OP_POW && top_prec <= curr_prec) break; + + /* Pop and add to program. */ + top_op = exprStackPop(&es->ops_stack); + int arity = exprGetOpArity(top_op->opcode); + if (*stack_items < arity) { + exprTokenRelease(top_op); + if (errpos) *errpos = top_op->offset; + return 1; + } + + /* Move to the program stack. */ + exprStackPush(&es->program, top_op); + *stack_items = *stack_items - arity + 1; + } + + /* Push current operator. */ + exprStackPush(&es->ops_stack, op); + exprTokenRetain(op); + return 0; +} + +/* Compile the expression into a set of push-value and exec-operator + * that exprRun() can execute. The function returns an expstate object + * that can be used for execution of the program. On error, NULL + * is returned, and optionally the position of the error into the + * expression is returned by reference. */ +exprstate *exprCompile(char *expr, int *errpos) { + /* Initialize expression state. */ + exprstate *es = RedisModule_Alloc(sizeof(exprstate)); + es->expr = RedisModule_Strdup(expr); + es->p = es->expr; + + /* Initialize all stacks. */ + exprStackInit(&es->values_stack); + exprStackInit(&es->ops_stack); + exprStackInit(&es->tokens); + exprStackInit(&es->program); + + /* Tokenization. */ + if (exprTokenize(es, errpos)) { + exprFree(es); + return NULL; + } + + /* Compile the expression into a sequence of operations. */ + int stack_items = 0; // Track # of items that would be on the stack + // during execution. This way we can detect arity + // issues at compile time. + + /* Process each token. */ + for (int i = 0; i < es->tokens.numitems; i++) { + exprtoken *token = es->tokens.items[i]; + + if (token->token_type == EXPR_TOKEN_EOF) break; + + /* Handle values (numbers, strings, selectors). */ + if (token->token_type == EXPR_TOKEN_NUM || + token->token_type == EXPR_TOKEN_STR || + token->token_type == EXPR_TOKEN_TUPLE || + token->token_type == EXPR_TOKEN_SELECTOR) + { + exprStackPush(&es->program, token); + exprTokenRetain(token); + stack_items++; + continue; + } + + /* Handle operators. */ + if (token->token_type == EXPR_TOKEN_OP) { + if (exprProcessOperator(es, token, &stack_items, errpos)) { + exprFree(es); + return NULL; + } + continue; + } + } + + /* Process remaining operators on the stack. */ + while (es->ops_stack.numitems > 0) { + exprtoken *op = exprStackPop(&es->ops_stack); + if (op->opcode == EXPR_OP_OPAREN) { + if (errpos) *errpos = op->offset; + exprTokenRelease(op); + exprFree(es); + return NULL; + } + + int arity = exprGetOpArity(op->opcode); + if (stack_items < arity) { + if (errpos) *errpos = op->offset; + exprTokenRelease(op); + exprFree(es); + return NULL; + } + + exprStackPush(&es->program, op); + stack_items = stack_items - arity + 1; + } + + /* Verify that exactly one value would remain on the stack after + * execution. We could also check that such value is a number, but this + * would make the code more complex without much gains. */ + if (stack_items != 1) { + if (errpos) { + /* Point to the last token's offset for error reporting. */ + exprtoken *last = es->tokens.items[es->tokens.numitems - 1]; + *errpos = last->offset; + } + exprFree(es); + return NULL; + } + return es; +} + +/* ============================ Expression execution ======================== */ + +/* Convert a token to its numeric value. For strings we attempt to parse them + * as numbers, returning 0 if conversion fails. */ +double exprTokenToNum(exprtoken *t) { + char buf[128]; + if (t->token_type == EXPR_TOKEN_NUM) { + return t->num; + } else if (t->token_type == EXPR_TOKEN_STR && t->str.len < sizeof(buf)) { + memcpy(buf, t->str.start, t->str.len); + buf[t->str.len] = '\0'; + char *endptr; + double val = strtod(buf, &endptr); + return *endptr == '\0' ? val : 0; + } else { + return 0; + } +} + +/* Conver obejct to true/false (0 or 1) */ +double exprTokenToBool(exprtoken *t) { + if (t->token_type == EXPR_TOKEN_NUM) { + return t->num != 0; + } else if (t->token_type == EXPR_TOKEN_STR && t->str.len == 0) { + return 0; // Empty string are false, like in Javascript. + } else { + return 1; // Every non numerical type is true. + } +} + +/* Compare two tokens. Returns true if they are equal. */ +int exprTokensEqual(exprtoken *a, exprtoken *b) { + // If both are strings, do string comparison. + if (a->token_type == EXPR_TOKEN_STR && b->token_type == EXPR_TOKEN_STR) { + return a->str.len == b->str.len && + memcmp(a->str.start, b->str.start, a->str.len) == 0; + } + + // If both are numbers, do numeric comparison. + if (a->token_type == EXPR_TOKEN_NUM && b->token_type == EXPR_TOKEN_NUM) { + return a->num == b->num; + } + + // Mixed types - convert to numbers and compare. + return exprTokenToNum(a) == exprTokenToNum(b); +} + +/* Convert a json object to an expression token. There is only + * limited support for JSON arrays: they must be composed of + * just numbers and strings. Returns NULL if the JSON object + * cannot be converted. */ +exprtoken *exprJsonToToken(cJSON *js) { + if (cJSON_IsNumber(js)) { + exprtoken *obj = exprNewToken(EXPR_TOKEN_NUM); + obj->num = cJSON_GetNumberValue(js); + return obj; + } else if (cJSON_IsString(js)) { + exprtoken *obj = exprNewToken(EXPR_TOKEN_STR); + char *strval = cJSON_GetStringValue(js); + obj->str.heapstr = RedisModule_Strdup(strval); + obj->str.start = obj->str.heapstr; + obj->str.len = strlen(obj->str.heapstr); + return obj; + } else if (cJSON_IsBool(js)) { + exprtoken *obj = exprNewToken(EXPR_TOKEN_NUM); + obj->num = cJSON_IsTrue(js); + return obj; + } else if (cJSON_IsArray(js)) { + // First, scan the array to ensure it only + // contains strings and numbers. Otherwise the + // expression will evaluate to false. + int array_size = cJSON_GetArraySize(js); + + for (int j = 0; j < array_size; j++) { + cJSON *item = cJSON_GetArrayItem(js, j); + if (!cJSON_IsNumber(item) && !cJSON_IsString(item)) return NULL; + } + + // Create a tuple token for the array. + exprtoken *obj = exprNewToken(EXPR_TOKEN_TUPLE); + obj->tuple.len = array_size; + obj->tuple.ele = NULL; + if (obj->tuple.len == 0) return obj; // No elements, already ok. + + obj->tuple.ele = + RedisModule_Alloc(sizeof(exprtoken*) * obj->tuple.len); + + // Convert each array element to a token. + for (size_t j = 0; j < obj->tuple.len; j++) { + cJSON *item = cJSON_GetArrayItem(js, j); + if (cJSON_IsNumber(item)) { + exprtoken *eleToken = exprNewToken(EXPR_TOKEN_NUM); + eleToken->num = cJSON_GetNumberValue(item); + obj->tuple.ele[j] = eleToken; + } else if (cJSON_IsString(item)) { + exprtoken *eleToken = exprNewToken(EXPR_TOKEN_STR); + char *strval = cJSON_GetStringValue(item); + eleToken->str.heapstr = RedisModule_Strdup(strval); + eleToken->str.start = eleToken->str.heapstr; + eleToken->str.len = strlen(eleToken->str.heapstr); + obj->tuple.ele[j] = eleToken; + } + } + return obj; + } + return NULL; // No conversion possible for this type. +} + +/* Execute the compiled expression program. Returns 1 if the final stack value + * evaluates to true, 0 otherwise. Also returns 0 if any selector callback + * fails. */ +int exprRun(exprstate *es, char *json, size_t json_len) { + exprStackReset(&es->values_stack); + cJSON *parsed_json = NULL; + + // Execute each instruction in the program. + for (int i = 0; i < es->program.numitems; i++) { + exprtoken *t = es->program.items[i]; + + // Handle selectors by calling the callback. + if (t->token_type == EXPR_TOKEN_SELECTOR) { + if (json != NULL) { + cJSON *attrib = NULL; + if (parsed_json == NULL) { + parsed_json = cJSON_ParseWithLength(json,json_len); + // Will be left to NULL if the above fails. + } + if (parsed_json) { + char item_name[128]; + if (t->str.len > 0 && t->str.len < sizeof(item_name)) { + memcpy(item_name,t->str.start,t->str.len); + item_name[t->str.len] = 0; + attrib = cJSON_GetObjectItem(parsed_json,item_name); + } + /* Fill the token according to the JSON type stored + * at the attribute. */ + if (attrib) { + exprtoken *obj = exprJsonToToken(attrib); + if (obj) { + exprStackPush(&es->values_stack, obj); + continue; + } + } + } + } + + // Selector not found or JSON object not convertible to + // expression tokens. Evaluate the expression to false. + if (parsed_json) cJSON_Delete(parsed_json); + return 0; + } + + // Push non-operator values directly onto the stack. + if (t->token_type != EXPR_TOKEN_OP) { + exprStackPush(&es->values_stack, t); + exprTokenRetain(t); + continue; + } + + // Handle operators. + exprtoken *result = exprNewToken(EXPR_TOKEN_NUM); + + // Pop operands - we know we have enough from compile-time checks. + exprtoken *b = exprStackPop(&es->values_stack); + exprtoken *a = NULL; + if (exprGetOpArity(t->opcode) == 2) { + a = exprStackPop(&es->values_stack); + } + + switch(t->opcode) { + case EXPR_OP_NOT: + result->num = exprTokenToBool(b) == 0 ? 1 : 0; + break; + case EXPR_OP_POW: { + double base = exprTokenToNum(a); + double exp = exprTokenToNum(b); + result->num = pow(base, exp); + break; + } + case EXPR_OP_MULT: + result->num = exprTokenToNum(a) * exprTokenToNum(b); + break; + case EXPR_OP_DIV: + result->num = exprTokenToNum(a) / exprTokenToNum(b); + break; + case EXPR_OP_MOD: { + double va = exprTokenToNum(a); + double vb = exprTokenToNum(b); + result->num = fmod(va, vb); + break; + } + case EXPR_OP_SUM: + result->num = exprTokenToNum(a) + exprTokenToNum(b); + break; + case EXPR_OP_DIFF: + result->num = exprTokenToNum(a) - exprTokenToNum(b); + break; + case EXPR_OP_GT: + result->num = exprTokenToNum(a) > exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_GTE: + result->num = exprTokenToNum(a) >= exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_LT: + result->num = exprTokenToNum(a) < exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_LTE: + result->num = exprTokenToNum(a) <= exprTokenToNum(b) ? 1 : 0; + break; + case EXPR_OP_EQ: + result->num = exprTokensEqual(a, b) ? 1 : 0; + break; + case EXPR_OP_NEQ: + result->num = !exprTokensEqual(a, b) ? 1 : 0; + break; + case EXPR_OP_IN: { + // For 'in' operator, b must be a tuple. + result->num = 0; // Default to false. + if (b->token_type == EXPR_TOKEN_TUPLE) { + for (size_t j = 0; j < b->tuple.len; j++) { + if (exprTokensEqual(a, b->tuple.ele[j])) { + result->num = 1; // Found a match. + break; + } + } + } + break; + } + case EXPR_OP_AND: + result->num = + exprTokenToBool(a) != 0 && exprTokenToBool(b) != 0 ? 1 : 0; + break; + case EXPR_OP_OR: + result->num = + exprTokenToBool(a) != 0 || exprTokenToBool(b) != 0 ? 1 : 0; + break; + default: + // Do nothing: we don't want runtime errors. + break; + } + + // Free operands and push result. + if (a) exprTokenRelease(a); + exprTokenRelease(b); + exprStackPush(&es->values_stack, result); + } + + if (parsed_json) cJSON_Delete(parsed_json); + + // Get final result from stack. + exprtoken *final = exprStackPop(&es->values_stack); + if (final == NULL) return 0; + + // Convert result to boolean. + int retval = exprTokenToBool(final); + exprTokenRelease(final); + return retval; +} + +/* ============================ Simple test main ============================ */ + +#ifdef TEST_MAIN +void exprPrintToken(exprtoken *t) { + switch(t->token_type) { + case EXPR_TOKEN_EOF: + printf("EOF"); + break; + case EXPR_TOKEN_NUM: + printf("NUM:%g", t->num); + break; + case EXPR_TOKEN_STR: + printf("STR:\"%.*s\"", (int)t->str.len, t->str.start); + break; + case EXPR_TOKEN_SELECTOR: + printf("SEL:%.*s", (int)t->str.len, t->str.start); + break; + case EXPR_TOKEN_OP: + printf("OP:"); + for (int i = 0; ExprOptable[i].opname != NULL; i++) { + if (ExprOptable[i].opcode == t->opcode) { + printf("%s", ExprOptable[i].opname); + break; + } + } + break; + default: + printf("UNKNOWN"); + break; + } +} + +void exprPrintStack(exprstack *stack, const char *name) { + printf("%s (%d items):", name, stack->numitems); + for (int j = 0; j < stack->numitems; j++) { + printf(" "); + exprPrintToken(stack->items[j]); + } + printf("\n"); +} + +int main(int argc, char **argv) { + char *testexpr = "(5+2)*3 and .year > 1980 and 'foo' == 'foo'"; + char *testjson = "{\"year\": 1984, \"name\": \"The Matrix\"}"; + if (argc >= 2) testexpr = argv[1]; + if (argc >= 3) testjson = argv[2]; + + printf("Compiling expression: %s\n", testexpr); + + int errpos = 0; + exprstate *es = exprCompile(testexpr,&errpos); + if (es == NULL) { + printf("Compilation failed near \"...%s\"\n", testexpr+errpos); + return 1; + } + + exprPrintStack(&es->tokens, "Tokens"); + exprPrintStack(&es->program, "Program"); + printf("Running against object: %s\n", testjson); + int result = exprRun(es,testjson,strlen(testjson)); + printf("Result1: %s\n", result ? "True" : "False"); + result = exprRun(es,testjson,strlen(testjson)); + printf("Result2: %s\n", result ? "True" : "False"); + + exprFree(es); + return 0; +} +#endif diff --git a/hnsw.c b/hnsw.c index 2a6ac4ff7..fd284e635 100644 --- a/hnsw.c +++ b/hnsw.c @@ -70,9 +70,9 @@ * orphaned of one link. */ -void (*hfree)(void *p) = free; -void *(*hmalloc)(size_t s) = malloc; -void *(*hrealloc)(void *old, size_t s) = realloc; +static void (*hfree)(void *p) = free; +static void *(*hmalloc)(size_t s) = malloc; +static void *(*hrealloc)(void *old, size_t s) = realloc; void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), void *(*realloc_ptr)(void*, size_t)) @@ -618,9 +618,19 @@ void hnsw_add_node(HNSW *index, hnswNode *node) { } /* Search the specified layer starting from the specified entry point - * to collect 'ef' nodes that are near to 'query'. */ -pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, - uint32_t ef, uint32_t layer, uint32_t slot) + * to collect 'ef' nodes that are near to 'query'. + * + * This function implements optional hybrid search, so that each node + * can be accepted or not based on its associated value. In this case + * a callback 'filter_callback' should be passed, together with a maximum + * effort for the search (number of candidates to evaluate), since even + * with a a low "EF" value we risk that there are too few nodes that satisfy + * the provided filter, and we could trigger a full scan. */ +pqueue *search_layer_with_filter( + HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) { // Mark visited nodes with a never seen epoch. index->current_epoch[slot]++; @@ -633,27 +643,33 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, return NULL; } + // Take track of the total effort: only used when filtering via + // a callback to have a bound effort. + uint32_t evaluated_candidates = 1; + // Add entry point. float dist = hnsw_distance(index, query, entry_point); pq_push(candidates, entry_point, dist); - pq_push(results, entry_point, dist); + if (filter_callback == NULL || + filter_callback(entry_point->value, filter_privdata)) + { + pq_push(results, entry_point, dist); + } entry_point->visited_epoch[slot] = index->current_epoch[slot]; // Process candidates. while (candidates->count > 0) { - // Pop closest element and use its saved distance. + // Max effort. If zero, we keep scanning. + if (filter_callback && + max_candidates && + evaluated_candidates >= max_candidates) break; + float cur_dist; hnswNode *current = pq_pop(candidates, &cur_dist); + evaluated_candidates++; - /* Stop if we can't get better results. Note that this can - * be true only if we already collected 'ef' elements in - * the priority queue. This is why: if we have less than EF - * elements, later in the for loop that checks the neighbors we - * add new elements BOTH in the results and candidates pqueue: this - * means that before accumulating EF elements, the worst candidate - * can be as bad as the worst result, but not worse. */ float furthest = pq_max_distance(results); - if (cur_dist > furthest) break; + if (results->count >= ef && cur_dist > furthest) break; /* Check neighbors. */ for (uint32_t i = 0; i < current->layers[layer].num_links; i++) { @@ -664,11 +680,29 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, neighbor->visited_epoch[slot] = index->current_epoch[slot]; float neighbor_dist = hnsw_distance(index, query, neighbor); - // Add to results if better than current max or results not full. + furthest = pq_max_distance(results); - if (neighbor_dist < furthest || results->count < ef) { - pq_push(candidates, neighbor, neighbor_dist); - pq_push(results, neighbor, neighbor_dist); + if (filter_callback == NULL) { + /* Original HNSW logic when no filtering: + * Add to results if better than current max or + * results not full. */ + if (neighbor_dist < furthest || results->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + pq_push(results, neighbor, neighbor_dist); + } + } else { + /* With filtering: we add candidates even if doesn't match + * the filter, in order to continue to explore the graph. */ + if (neighbor_dist < furthest || candidates->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + } + + /* Add results only if passes filter. */ + if (filter_callback(neighbor->value, filter_privdata)) { + if (neighbor_dist < furthest || results->count < ef) { + pq_push(results, neighbor, neighbor_dist); + } + } } } } @@ -677,6 +711,14 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, return results; } +/* Just a wrapper without hybrid search callback. */ +pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot) +{ + return search_layer_with_filter(index, query, entry_point, ef, layer, slot, + NULL, NULL, 0); +} + /* This function is used in order to initialize a node allocated in the * function stack with the specified vector. The idea is that we can * easily use hnsw_distance() from a vector and the HNSW nodes this way: @@ -738,10 +780,21 @@ void hnsw_free_tmp_node(hnswNode *node, const float *vector) { * arrays must have space for at least 'k' items. * norm_query should be set to 1 if the query vector is already * normalized, otherwise, if 0, the function will copy the vector, - * L2-normalize the copy and search using the normalized version. */ -int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + * L2-normalize the copy and search using the normalized version. + * + * If the filter_privdata callback is passed, only elements passing the + * specified filter (invoked with privdata and the value associated + * to the node as arguments) are returned. In such case, if max_candidates + * is not NULL, it represents the maximum number of nodes to explore, since + * the search may be otherwise unbound if few or no elements pass the + * filter. */ +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, hnswNode **neighbors, float *distances, uint32_t slot, - int query_vector_is_normalized) + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) + { if (!index || !query_vector || !neighbors || k == 0) return -1; if (!index->enter_point) return 0; // Empty index. @@ -769,7 +822,9 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, } /* Search bottom layer (the most densely populated) with ef = k */ - pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot); + pqueue *results = search_layer_with_filter( + index, &query, curr_ep, k, 0, slot, filter_callback, + filter_privdata, max_candidates); if (!results) { hnsw_free_tmp_node(&query, query_vector); return -1; @@ -789,6 +844,16 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, return found; } +/* Wrapper to hnsw_search_with_filter() when no filter is needed. */ +int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized) +{ + return hnsw_search_with_filter(index,query_vector,k,neighbors, + distances,slot,query_vector_is_normalized, + NULL,NULL,0); +} + /* Rescan a node and update the wortst neighbor index. * The followinng two functions are variants of this function to be used * when links are added or removed: they may do less work than a full scan. */ diff --git a/hnsw.h b/hnsw.h index 3d104cc5e..5cc1b1cd2 100644 --- a/hnsw.h +++ b/hnsw.h @@ -119,6 +119,12 @@ hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, int hnsw_search(HNSW *index, const float *query, uint32_t k, hnswNode **neighbors, float *distances, uint32_t slot, int query_vector_is_normalized); +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates); void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec); void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)); diff --git a/tests/filter_expr.py b/tests/filter_expr.py new file mode 100644 index 000000000..13abf7b65 --- /dev/null +++ b/tests/filter_expr.py @@ -0,0 +1,177 @@ +from test import TestCase + +class VSIMFilterExpressions(TestCase): + def getname(self): + return "VSIM FILTER expressions basic functionality" + + def test(self): + # Create a small set of vectors with different attributes + + # Basic vectors for testing - all orthogonal for clear results + vec1 = [1, 0, 0, 0] + vec2 = [0, 1, 0, 0] + vec3 = [0, 0, 1, 0] + vec4 = [0, 0, 0, 1] + vec5 = [0.5, 0.5, 0, 0] + + # Add vectors with various attributes + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], f'{self.test_key}:item:1') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:1', + '{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec2], f'{self.test_key}:item:2') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:2', + '{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec3], f'{self.test_key}:item:3') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:3', + '{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}') + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec4], f'{self.test_key}:item:4') + # Item 4 has no attribute at all + + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec5], f'{self.test_key}:item:5') + self.redis.execute_command('VSETATTR', self.test_key, f'{self.test_key}:item:5', + 'invalid json') # Intentionally malformed JSON + + # Test 1: Basic equality with numbers + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age == 25') + assert len(result) == 1, "Expected 1 result for age == 25" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for age == 25" + + # Test 2: Greater than + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > 25') + assert len(result) == 2, "Expected 2 results for age > 25" + + # Test 3: Less than or equal + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age <= 30') + assert len(result) == 2, "Expected 2 results for age <= 30" + + # Test 4: String equality + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name == "Alice"') + assert len(result) == 1, "Expected 1 result for name == Alice" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for name == Alice" + + # Test 5: String inequality + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name != "Alice"') + assert len(result) == 2, "Expected 2 results for name != Alice" + + # Test 6: Boolean value + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.active') + assert len(result) == 1, "Expected 1 result for .active being true" + + # Test 7: Logical AND + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > 20 and .age < 30') + assert len(result) == 1, "Expected 1 result for 20 < age < 30" + assert result[0].decode() == f'{self.test_key}:item:1', "Expected item:1 for 20 < age < 30" + + # Test 8: Logical OR + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age < 30 or .age > 35') + assert len(result) == 1, "Expected 1 result for age < 30 or age > 35" + + # Test 9: Logical NOT + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '!(.age == 25)') + assert len(result) == 2, "Expected 2 results for NOT(age == 25)" + + # Test 10: The "in" operator with array + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age in [25, 35]') + assert len(result) == 2, "Expected 2 results for age in [25, 35]" + + # Test 11: The "in" operator with strings in array + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.name in ["Alice", "David"]') + assert len(result) == 1, "Expected 1 result for name in [Alice, David]" + + # Test 12: Arithmetic operations - addition + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age + 10 > 40') + assert len(result) == 1, "Expected 1 result for age + 10 > 40" + + # Test 13: Arithmetic operations - multiplication + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age * 2 > 60') + assert len(result) == 1, "Expected 1 result for age * 2 > 60" + + # Test 14: Arithmetic operations - division + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age / 5 == 5') + assert len(result) == 1, "Expected 1 result for age / 5 == 5" + + # Test 15: Arithmetic operations - modulo + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age % 2 == 0') + assert len(result) == 1, "Expected 1 result for age % 2 == 0" + + # Test 16: Power operator + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age ** 2 > 900') + assert len(result) == 1, "Expected 1 result for age^2 > 900" + + # Test 17: Missing attribute (should exclude items missing that attribute) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.missing_field == "value"') + assert len(result) == 0, "Expected 0 results for missing_field == value" + + # Test 18: No attribute set at all + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.any_field') + assert f'{self.test_key}:item:4' not in [item.decode() for item in result], "Item with no attribute should be excluded" + + # Test 19: Malformed JSON + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.any_field') + assert f'{self.test_key}:item:5' not in [item.decode() for item in result], "Item with malformed JSON should be excluded" + + # Test 20: Complex expression combining multiple operators + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '(.age > 20 and .age < 40) and (.city == "Boston" or .city == "New York")') + assert len(result) == 2, "Expected 2 results for the complex expression" + expected_items = [f'{self.test_key}:item:1', f'{self.test_key}:item:2'] + assert set([item.decode() for item in result]) == set(expected_items), "Expected item:1 and item:2 for the complex expression" + + # Test 21: Parentheses to control operator precedence + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.age > (20 + 10)') + assert len(result) == 1, "Expected 1 result for age > (20 + 10)" + + # Test 22: Array access (arrays evaluate to true) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], + 'FILTER', '.scores') + assert len(result) == 3, "Expected 3 results for .scores (arrays evaluate to true)" diff --git a/tests/filter_int.py b/tests/filter_int.py new file mode 100644 index 000000000..0fd1dc132 --- /dev/null +++ b/tests/filter_int.py @@ -0,0 +1,668 @@ +from test import TestCase, generate_random_vector +import struct +import random +import math +import json +import time + +class VSIMFilterAdvanced(TestCase): + def getname(self): + return "VSIM FILTER comprehensive functionality testing" + + def estimated_runtime(self): + return 15 # This test might take up to 15 seconds for the large dataset + + def setup(self): + super().setup() + self.dim = 32 # Vector dimension + self.count = 5000 # Number of vectors for large tests + self.small_count = 50 # Number of vectors for small/quick tests + + # Categories for attributes + self.categories = ["electronics", "furniture", "clothing", "books", "food"] + self.cities = ["New York", "London", "Tokyo", "Paris", "Berlin", "Sydney", "Toronto", "Singapore"] + self.price_ranges = [(10, 50), (50, 200), (200, 1000), (1000, 5000)] + self.years = list(range(2000, 2025)) + + def create_attributes(self, index): + """Create realistic attributes for a vector""" + category = random.choice(self.categories) + city = random.choice(self.cities) + min_price, max_price = random.choice(self.price_ranges) + price = round(random.uniform(min_price, max_price), 2) + year = random.choice(self.years) + in_stock = random.random() > 0.3 # 70% chance of being in stock + rating = round(random.uniform(1, 5), 1) + views = int(random.expovariate(1/1000)) # Exponential distribution for page views + tags = random.sample(["popular", "sale", "new", "limited", "exclusive", "clearance"], + k=random.randint(0, 3)) + + # Add some specific patterns for testing + # Every 10th item has a specific property combination for testing + is_premium = (index % 10 == 0) + + # Create attributes dictionary + attrs = { + "id": index, + "category": category, + "location": city, + "price": price, + "year": year, + "in_stock": in_stock, + "rating": rating, + "views": views, + "tags": tags + } + + if is_premium: + attrs["is_premium"] = True + attrs["special_features"] = ["premium", "warranty", "support"] + + # Add sub-categories for more complex filters + if category == "electronics": + attrs["subcategory"] = random.choice(["phones", "computers", "cameras", "audio"]) + elif category == "furniture": + attrs["subcategory"] = random.choice(["chairs", "tables", "sofas", "beds"]) + elif category == "clothing": + attrs["subcategory"] = random.choice(["shirts", "pants", "dresses", "shoes"]) + + # Add some intentionally missing fields for testing + if random.random() > 0.9: # 10% chance of missing price + del attrs["price"] + + # Some items have promotion field + if random.random() > 0.7: # 30% chance of having a promotion + attrs["promotion"] = random.choice(["discount", "bundle", "gift"]) + + # Create invalid JSON for a small percentage of vectors + if random.random() > 0.98: # 2% chance of having invalid JSON + return "{{invalid json}}" + + return json.dumps(attrs) + + def create_vectors_with_attributes(self, key, count): + """Create vectors and add attributes to them""" + vectors = [] + names = [] + attribute_map = {} # To store attributes for verification + + # Create vectors + for i in range(count): + vec = generate_random_vector(self.dim) + vectors.append(vec) + name = f"{key}:item:{i}" + names.append(name) + + # Add to Redis + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', key, 'FP32', vec_bytes, name) + + # Create and add attributes + attrs = self.create_attributes(i) + self.redis.execute_command('VSETATTR', key, name, attrs) + + # Store attributes for later verification + try: + attribute_map[name] = json.loads(attrs) if '{' in attrs else None + except json.JSONDecodeError: + attribute_map[name] = None + + return vectors, names, attribute_map + + def filter_linear_search(self, vectors, names, query_vector, filter_expr, attribute_map, k=10): + """Perform a linear search with filtering for verification""" + similarities = [] + query_norm = math.sqrt(sum(x*x for x in query_vector)) + + if query_norm == 0: + return [] + + for i, vec in enumerate(vectors): + name = names[i] + attributes = attribute_map.get(name) + + # Skip if doesn't match filter + if not self.matches_filter(attributes, filter_expr): + continue + + vec_norm = math.sqrt(sum(x*x for x in vec)) + if vec_norm == 0: + continue + + dot_product = sum(a*b for a,b in zip(query_vector, vec)) + cosine_sim = dot_product / (query_norm * vec_norm) + distance = 1.0 - cosine_sim + redis_similarity = 1.0 - (distance/2.0) + similarities.append((name, redis_similarity)) + + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:k] + + def matches_filter(self, attributes, filter_expr): + """Filter matching for verification - uses Python eval to handle complex expressions""" + if attributes is None: + return False # No attributes or invalid JSON + + # Replace JSON path selectors with Python dictionary access + py_expr = filter_expr + + # Handle `.field` notation (replace with attributes['field']) + i = 0 + while i < len(py_expr): + if py_expr[i] == '.' and (i == 0 or not py_expr[i-1].isalnum()): + # Find the end of the selector (stops at operators or whitespace) + j = i + 1 + while j < len(py_expr) and (py_expr[j].isalnum() or py_expr[j] == '_'): + j += 1 + + if j > i + 1: # Found a valid selector + field = py_expr[i+1:j] + # Use a safe access pattern that returns a default value based on context + py_expr = py_expr[:i] + f"attributes.get('{field}')" + py_expr[j:] + i = i + len(f"attributes.get('{field}')") + else: + i += 1 + else: + i += 1 + + # Convert not operator if needed + py_expr = py_expr.replace('!', ' not ') + + try: + # Custom evaluation that handles exceptions for missing fields + # by returning False for the entire expression + + # Split the expression on logical operators + parts = [] + for op in [' and ', ' or ']: + if op in py_expr: + parts = py_expr.split(op) + break + + if not parts: # No logical operators found + parts = [py_expr] + + # Try to evaluate each part - if any part fails, + # the whole expression should fail + try: + result = eval(py_expr, {"attributes": attributes}) + return bool(result) + except (TypeError, AttributeError): + # This typically happens when trying to compare None with + # numbers or other types, or when an attribute doesn't exist + return False + except Exception as e: + print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}") + return False + + except Exception as e: + print(f"Error evaluating filter expression '{filter_expr}' as '{py_expr}': {e}") + return False + + def safe_decode(self,item): + return item.decode() if isinstance(item, bytes) else item + + def calculate_recall(self, redis_results, linear_results, k=10): + """Calculate recall (percentage of correct results retrieved)""" + redis_set = set(self.safe_decode(item) for item in redis_results) + linear_set = set(item[0] for item in linear_results[:k]) + + if not linear_set: + return 1.0 # If no linear results, consider it perfect recall + + intersection = redis_set.intersection(linear_set) + return len(intersection) / len(linear_set) + + def test_recall_with_filter(self, filter_expr, ef=500, filter_ef=None): + """Test recall for a given filter expression""" + # Create query vector + query_vec = generate_random_vector(self.dim) + + # First, get ground truth using linear scan + linear_results = self.filter_linear_search( + self.vectors, self.names, query_vec, filter_expr, self.attribute_map, k=50) + + # Calculate true selectivity from ground truth + true_selectivity = len(linear_results) / len(self.names) if self.names else 0 + + # Perform Redis search with filter + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 50, 'WITHSCORES', 'EF', ef, 'FILTER', filter_expr]) + if filter_ef: + cmd_args.extend(['FILTER-EF', filter_ef]) + + start_time = time.time() + redis_results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + # Convert Redis results to dict + redis_items = {} + for i in range(0, len(redis_results), 2): + key = redis_results[i].decode() if isinstance(redis_results[i], bytes) else redis_results[i] + score = float(redis_results[i+1]) + redis_items[key] = score + + # Calculate metrics + recall = self.calculate_recall(redis_items.keys(), linear_results) + selectivity = len(redis_items) / len(self.names) if redis_items else 0 + + # Compare against the true selectivity from linear scan + assert abs(selectivity - true_selectivity) < 0.1, \ + f"Redis selectivity {selectivity:.3f} differs significantly from ground truth {true_selectivity:.3f}" + + # We expect high recall for standard parameters + if ef >= 500 and (filter_ef is None or filter_ef >= 1000): + try: + assert recall >= 0.7, \ + f"Low recall {recall:.2f} for filter '{filter_expr}'" + except AssertionError as e: + # Get items found in each set + redis_items_set = set(redis_items.keys()) + linear_items_set = set(item[0] for item in linear_results) + + # Find items in each set + only_in_redis = redis_items_set - linear_items_set + only_in_linear = linear_items_set - redis_items_set + in_both = redis_items_set & linear_items_set + + # Build comprehensive debug message + debug = f"\nGround Truth: {len(linear_results)} matching items (total vectors: {len(self.vectors)})" + debug += f"\nRedis Found: {len(redis_items)} items with FILTER-EF: {filter_ef or 'default'}" + debug += f"\nItems in both sets: {len(in_both)} (recall: {recall:.4f})" + debug += f"\nItems only in Redis: {len(only_in_redis)}" + debug += f"\nItems only in Ground Truth: {len(only_in_linear)}" + + # Show some example items from each set with their scores + if only_in_redis: + debug += "\n\nTOP 5 ITEMS ONLY IN REDIS:" + sorted_redis = sorted([(k, v) for k, v in redis_items.items()], key=lambda x: x[1], reverse=True) + for i, (item, score) in enumerate(sorted_redis[:5]): + if item in only_in_redis: + debug += f"\n {i+1}. {item} (Score: {score:.4f})" + + # Show attribute that should match filter + attr = self.attribute_map.get(item) + if attr: + debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}" + + if only_in_linear: + debug += "\n\nTOP 5 ITEMS ONLY IN GROUND TRUTH:" + for i, (item, score) in enumerate(linear_results[:5]): + if item in only_in_linear: + debug += f"\n {i+1}. {item} (Score: {score:.4f})" + + # Show attribute that should match filter + attr = self.attribute_map.get(item) + if attr: + debug += f" - Attrs: {attr.get('category', 'N/A')}, Price: {attr.get('price', 'N/A')}" + + # Help identify parsing issues + debug += "\n\nPARSING CHECK:" + debug += f"\nRedis command: VSIM {self.test_key} VALUES {self.dim} [...] FILTER '{filter_expr}'" + + # Check for WITHSCORES handling issues + if len(redis_results) > 0 and len(redis_results) % 2 == 0: + debug += f"\nRedis returned {len(redis_results)} items (looks like item,score pairs)" + debug += f"\nFirst few results: {redis_results[:4]}" + + # Check the filter implementation + debug += "\n\nFILTER IMPLEMENTATION CHECK:" + debug += f"\nFilter expression: '{filter_expr}'" + debug += "\nSample attribute matches from attribute_map:" + count_matching = 0 + for i, (name, attrs) in enumerate(self.attribute_map.items()): + if attrs and self.matches_filter(attrs, filter_expr): + count_matching += 1 + if i < 3: # Show first 3 matches + debug += f"\n - {name}: {attrs}" + debug += f"\nTotal items matching filter in attribute_map: {count_matching}" + + # Check if results array handling could be wrong + debug += "\n\nRESULT ARRAYS CHECK:" + if len(linear_results) >= 1: + debug += f"\nlinear_results[0]: {linear_results[0]}" + if isinstance(linear_results[0], tuple) and len(linear_results[0]) == 2: + debug += " (correct tuple format: (name, score))" + else: + debug += " (UNEXPECTED FORMAT!)" + + # Debug sort order + debug += "\n\nSORTING CHECK:" + if len(linear_results) >= 2: + debug += f"\nGround truth first item score: {linear_results[0][1]}" + debug += f"\nGround truth second item score: {linear_results[1][1]}" + debug += f"\nCorrectly sorted by similarity? {linear_results[0][1] >= linear_results[1][1]}" + + # Re-raise with detailed information + raise AssertionError(str(e) + debug) + + return recall, selectivity, query_time, len(redis_items) + + def test(self): + print(f"\nRunning comprehensive VSIM FILTER tests...") + + # Create a larger dataset for testing + print(f"Creating dataset with {self.count} vectors and attributes...") + self.vectors, self.names, self.attribute_map = self.create_vectors_with_attributes( + self.test_key, self.count) + + # ==== 1. Recall and Precision Testing ==== + print("Testing recall for various filters...") + + # Test basic filters with different selectivity + results = {} + results["category"] = self.test_recall_with_filter('.category == "electronics"') + results["price_high"] = self.test_recall_with_filter('.price > 1000') + results["in_stock"] = self.test_recall_with_filter('.in_stock') + results["rating"] = self.test_recall_with_filter('.rating >= 4') + results["complex1"] = self.test_recall_with_filter('.category == "electronics" and .price < 500') + + print("Filter | Recall | Selectivity | Time (ms) | Results") + print("----------------------------------------------------") + for name, (recall, selectivity, time_ms, count) in results.items(): + print(f"{name:7} | {recall:.3f} | {selectivity:.3f} | {time_ms*1000:.1f} | {count}") + + # ==== 2. Filter Selectivity Performance ==== + print("\nTesting filter selectivity performance...") + + # High selectivity (very few matches) + high_sel_recall, _, high_sel_time, _ = self.test_recall_with_filter('.is_premium') + + # Medium selectivity + med_sel_recall, _, med_sel_time, _ = self.test_recall_with_filter('.price > 100 and .price < 1000') + + # Low selectivity (many matches) + low_sel_recall, _, low_sel_time, _ = self.test_recall_with_filter('.year > 2000') + + print(f"High selectivity recall: {high_sel_recall:.3f}, time: {high_sel_time*1000:.1f}ms") + print(f"Med selectivity recall: {med_sel_recall:.3f}, time: {med_sel_time*1000:.1f}ms") + print(f"Low selectivity recall: {low_sel_recall:.3f}, time: {low_sel_time*1000:.1f}ms") + + # ==== 3. FILTER-EF Parameter Testing ==== + print("\nTesting FILTER-EF parameter...") + + # Test with different FILTER-EF values + filter_expr = '.category == "electronics" and .price > 200' + ef_values = [100, 500, 2000, 5000] + + print("FILTER-EF | Recall | Time (ms)") + print("-----------------------------") + for filter_ef in ef_values: + recall, _, query_time, _ = self.test_recall_with_filter( + filter_expr, ef=500, filter_ef=filter_ef) + print(f"{filter_ef:9} | {recall:.3f} | {query_time*1000:.1f}") + + # Assert that higher FILTER-EF generally gives better recall + low_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=100) + high_ef_recall, _, _, _ = self.test_recall_with_filter(filter_expr, filter_ef=5000) + + # This might not always be true due to randomness, but generally holds + # We use a softer assertion to avoid flaky tests + assert high_ef_recall >= low_ef_recall * 0.8, \ + f"Higher FILTER-EF should generally give better recall: {high_ef_recall:.3f} vs {low_ef_recall:.3f}" + + # ==== 4. Complex Filter Expressions ==== + print("\nTesting complex filter expressions...") + + # Test a variety of complex expressions + complex_filters = [ + '.price > 100 and (.category == "electronics" or .category == "furniture")', + '(.rating > 4 and .in_stock) or (.price < 50 and .views > 1000)', + '.category in ["electronics", "clothing"] and .price > 200 and .rating >= 3', + '(.category == "electronics" and .subcategory == "phones") or (.category == "furniture" and .price > 1000)', + '.year > 2010 and !(.price < 100) and .in_stock' + ] + + print("Expression | Results | Time (ms)") + print("-----------------------------") + for i, expr in enumerate(complex_filters): + try: + _, _, query_time, result_count = self.test_recall_with_filter(expr) + print(f"Complex {i+1} | {result_count:7} | {query_time*1000:.1f}") + except Exception as e: + print(f"Complex {i+1} | Error: {str(e)}") + + # ==== 5. Attribute Type Testing ==== + print("\nTesting different attribute types...") + + type_filters = [ + ('.price > 500', "Numeric"), + ('.category == "books"', "String equality"), + ('.in_stock', "Boolean"), + ('.tags in ["sale", "new"]', "Array membership"), + ('.rating * 2 > 8', "Arithmetic") + ] + + for expr, type_name in type_filters: + try: + _, _, query_time, result_count = self.test_recall_with_filter(expr) + print(f"{type_name:16} | {expr:30} | {result_count:5} results | {query_time*1000:.1f}ms") + except Exception as e: + print(f"{type_name:16} | {expr:30} | Error: {str(e)}") + + # ==== 6. Filter + Count Interaction ==== + print("\nTesting COUNT parameter with filters...") + + filter_expr = '.category == "electronics"' + counts = [5, 20, 100] + + for count in counts: + query_vec = generate_random_vector(self.dim) + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', count, 'WITHSCORES', 'FILTER', filter_expr]) + + results = self.redis.execute_command(*cmd_args) + result_count = len(results) // 2 # Divide by 2 because WITHSCORES returns pairs + + # We expect result count to be at most the requested count + assert result_count <= count, f"Got {result_count} results with COUNT {count}" + print(f"COUNT {count:3} | Got {result_count:3} results") + + # ==== 7. Edge Cases ==== + print("\nTesting edge cases...") + + # Test with no matching items + no_match_expr = '.category == "nonexistent_category"' + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', no_match_expr) + assert len(results) == 0, f"Expected 0 results for non-matching filter, got {len(results)}" + print(f"No matching items: {len(results)} results (expected 0)") + + # Test with invalid filter syntax + try: + self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', '.category === "books"') # Triple equals is invalid + assert False, "Expected error for invalid filter syntax" + except: + print("Invalid filter syntax correctly raised an error") + + # Test with extremely long complex expression + long_expr = ' and '.join([f'.rating > {i/10}' for i in range(10)]) + try: + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', self.dim, + *[str(x) for x in generate_random_vector(self.dim)], + 'FILTER', long_expr) + print(f"Long expression: {len(results)} results") + except Exception as e: + print(f"Long expression error: {str(e)}") + + print("\nComprehensive VSIM FILTER tests completed successfully") + + +class VSIMFilterSelectivityTest(TestCase): + def getname(self): + return "VSIM FILTER selectivity performance benchmark" + + def estimated_runtime(self): + return 8 # This test might take up to 8 seconds + + def setup(self): + super().setup() + self.dim = 32 + self.count = 10000 + self.test_key = f"{self.test_key}:selectivity" # Use a different key + + def create_vector_with_age_attribute(self, name, age): + """Create a vector with a specific age attribute""" + vec = generate_random_vector(self.dim) + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps({"age": age})) + + def test(self): + print("\nRunning VSIM FILTER selectivity benchmark...") + + # Create a dataset where we control the exact selectivity + print(f"Creating controlled dataset with {self.count} vectors...") + + # Create vectors with age attributes from 1 to 100 + for i in range(self.count): + age = (i % 100) + 1 # Ages from 1 to 100 + name = f"{self.test_key}:item:{i}" + self.create_vector_with_age_attribute(name, age) + + # Create a query vector + query_vec = generate_random_vector(self.dim) + + # Test filters with different selectivities + selectivities = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.99] + results = [] + + print("\nSelectivity | Filter | Results | Time (ms)") + print("--------------------------------------------------") + + for target_selectivity in selectivities: + # Calculate age threshold for desired selectivity + # For example, age <= 10 gives 10% selectivity + age_threshold = int(target_selectivity * 100) + filter_expr = f'.age <= {age_threshold}' + + # Run query and measure time + start_time = time.time() + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr]) + + results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + actual_selectivity = len(results) / min(100, int(target_selectivity * self.count)) + print(f"{target_selectivity:.2f} | {filter_expr:15} | {len(results):7} | {query_time*1000:.1f}") + + # Add assertion to ensure reasonable performance for different selectivities + # For very selective queries (1%), we might need more exploration + if target_selectivity <= 0.05: + # For very selective queries, ensure we can find some results + assert len(results) > 0, f"No results found for {filter_expr}" + else: + # For less selective queries, performance should be reasonable + assert query_time < 1.0, f"Query too slow: {query_time:.3f}s for {filter_expr}" + + print("\nSelectivity benchmark completed successfully") + + +class VSIMFilterComparisonTest(TestCase): + def getname(self): + return "VSIM FILTER EF parameter comparison" + + def estimated_runtime(self): + return 8 # This test might take up to 8 seconds + + def setup(self): + super().setup() + self.dim = 32 + self.count = 5000 + self.test_key = f"{self.test_key}:efparams" # Use a different key + + def create_dataset(self): + """Create a dataset with specific attribute patterns for testing FILTER-EF""" + vectors = [] + names = [] + + # Create vectors with category and quality score attributes + for i in range(self.count): + vec = generate_random_vector(self.dim) + name = f"{self.test_key}:item:{i}" + + # Add vector to Redis + vec_bytes = struct.pack(f'{self.dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, name) + + # Create attributes - we want a very selective filter + # Only 2% of items have category=premium AND quality>90 + category = "premium" if random.random() < 0.1 else random.choice(["standard", "economy", "basic"]) + quality = random.randint(1, 100) + + attrs = { + "id": i, + "category": category, + "quality": quality + } + + self.redis.execute_command('VSETATTR', self.test_key, name, json.dumps(attrs)) + vectors.append(vec) + names.append(name) + + return vectors, names + + def test(self): + print("\nRunning VSIM FILTER-EF parameter comparison...") + + # Create dataset + vectors, names = self.create_dataset() + + # Create a selective filter that matches ~2% of items + filter_expr = '.category == "premium" and .quality > 90' + + # Create query vector + query_vec = generate_random_vector(self.dim) + + # Test different FILTER-EF values + ef_values = [50, 100, 500, 1000, 5000] + results = [] + + print("\nFILTER-EF | Results | Time (ms) | Notes") + print("---------------------------------------") + + baseline_count = None + + for ef in ef_values: + # Run query and measure time + start_time = time.time() + cmd_args = ['VSIM', self.test_key, 'VALUES', self.dim] + cmd_args.extend([str(x) for x in query_vec]) + cmd_args.extend(['COUNT', 100, 'FILTER', filter_expr, 'FILTER-EF', ef]) + + query_results = self.redis.execute_command(*cmd_args) + query_time = time.time() - start_time + + # Set baseline for comparison + if baseline_count is None: + baseline_count = len(query_results) + + recall_rate = len(query_results) / max(1, baseline_count) if baseline_count > 0 else 1.0 + + notes = "" + if ef == 5000: + notes = "Baseline" + elif recall_rate < 0.5: + notes = "Low recall!" + + print(f"{ef:9} | {len(query_results):7} | {query_time*1000:.1f} | {notes}") + results.append((ef, len(query_results), query_time)) + + # If we have enough results at highest EF, check that recall improves with higher EF + if results[-1][1] >= 5: # At least 5 results for highest EF + # Extract result counts + result_counts = [r[1] for r in results] + + # The last result (highest EF) should typically find more results than the first (lowest EF) + # but we use a soft assertion to avoid flaky tests + assert result_counts[-1] >= result_counts[0], \ + f"Higher FILTER-EF should find at least as many results: {result_counts[-1]} vs {result_counts[0]}" + + print("\nFILTER-EF parameter comparison completed successfully") diff --git a/vset.c b/vset.c index 321fbf5c1..24dbaea36 100644 --- a/vset.c +++ b/vset.c @@ -20,6 +20,10 @@ #include #include "hnsw.h" +// We inline directly the expression implementation here so that building +// the module is trivial. +#include "expr.c" + static RedisModuleType *VectorSetType; static uint64_t VectorSetTypeNextId = 0; @@ -48,6 +52,15 @@ struct vsetObject { pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely. uint64_t id; // Unique ID used by threaded VADD to know the // object is still the same. + uint64_t numattribs; // Number of nodes associated with an attribute. +}; + +/* Each node has two associated values: the associated string (the item + * in the set) and potentially a JSON string, that is, the attributes, used + * for hybrid search with the VSIM FILTER option. */ +struct vsetNodeVal { + RedisModuleString *item; + RedisModuleString *attrib; }; /* Create a random projection matrix for dimensionality reduction. @@ -108,13 +121,16 @@ struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type) o->proj_matrix = NULL; o->proj_input_size = 0; + o->numattribs = 0; pthread_rwlock_init(&o->in_use_lock,NULL); - return o; } void vectorSetReleaseNodeValue(void *v) { - RedisModule_FreeString(NULL,v); + struct vsetNodeVal *nv = v; + RedisModule_FreeString(NULL,nv->item); + if (nv->attrib) RedisModule_FreeString(NULL,nv->attrib); + RedisModule_Free(nv); } /* Free the vector set object. */ @@ -142,24 +158,60 @@ const char *vectorSetGetQuantName(struct vsetObject *o) { * * Returns 1 if the element was added, or 0 if the element was already there * and was just updated. */ -int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, int update, int ef) +int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, RedisModuleString *attrib, int update, int ef) { hnswNode *node = RedisModule_DictGet(o->dict,val,NULL); if (node != NULL) { if (update) { - void *old_val = node->value; + struct vsetNodeVal *nv = node->value; /* Pass NULL as value-free function. We want to reuse * the old value. */ hnsw_delete_node(o->hnsw, node, NULL); - node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,old_val,ef); + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); RedisModule_DictReplace(o->dict,val,node); + + /* If attrib != NULL, the user wants that in case of an update we + * update the attribute as well (otherwise it reamins as it was). + * Note that the order of operations is conceinved so that it + * works in case the old attrib and the new attrib pointer is the + * same. */ + if (attrib) { + // Empty attribute string means: unset the attribute during + // the update. + size_t attrlen; + RedisModule_StringPtrLen(attrib,&attrlen); + if (attrlen != 0) { + RedisModule_RetainString(NULL,attrib); + o->numattribs++; + } else { + attrib = NULL; + } + + if (nv->attrib) { + o->numattribs--; + RedisModule_FreeString(NULL,nv->attrib); + } + nv->attrib = attrib; + } } return 0; } - node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,val,ef); - if (!node) return 0; + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = val; + nv->attrib = attrib; + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,nv,ef); + if (node == NULL) { + // XXX Technically in Redis-land we don't have out of memories as we + // crash. However the HNSW library may fail for error in the locking + // libc call. There is understand if this may actually happen or not. + RedisModule_Free(nv); + return 0; + } + if (attrib != NULL) o->numattribs++; RedisModule_DictSet(o->dict,val,node); + RedisModule_RetainString(NULL,val); + if (attrib) RedisModule_RetainString(NULL,attrib); return 1; } @@ -243,11 +295,10 @@ void *VADD_thread(void *arg) { RedisModuleBlockedClient *bc = targ[0]; struct vsetObject *vset = targ[1]; float *vec = targ[3]; - RedisModuleString *val = targ[4]; int ef = (uint64_t)targ[6]; /* Look for candidates... */ - InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, val, ef); + InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, NULL, ef); targ[5] = ic; // Pass the context to the reply callback. /* Unblock the client so that our read reply will be invoked. */ @@ -268,6 +319,7 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModuleString *val = targ[4]; InsertContext *ic = targ[5]; int ef = (uint64_t)targ[6]; + RedisModuleString *attrib = targ[7]; RedisModule_Free(targ); /* Open the key: there are no guarantees it still exists, or contains @@ -300,12 +352,22 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { /* Otherwise try to insert the new element with the neighbors * collected in background. If we fail, do it synchronously again * from scratch. */ + + // First: allocate the dual-ported value for the node. + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = val; + nv->attrib = attrib; + + // Then: insert the node in the HNSW data structure. hnswNode *newnode; if ((newnode = hnsw_try_commit_insert(vset->hnsw, ic)) == NULL) { - newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, val, ef); + newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, nv, ef); + } else { + newnode->value = nv; } RedisModule_DictSet(vset->dict,val,newnode); val = NULL; // Don't free it later. + attrib = NULL; // Dont' free it later. RedisModule_ReplicateVerbatim(ctx); } @@ -313,6 +375,7 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { // Whatever happens is a success... :D RedisModule_ReplyWithLongLong(ctx,1); if (val) RedisModule_FreeString(ctx,val); // Not added? Free it. + if (attrib) RedisModule_FreeString(ctx,attrib); // Not added? Free it. RedisModule_Free(vec); return retval; } @@ -330,6 +393,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { int cas = 0; // Threaded check-and-set style insert. long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes. float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args); + RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB. if (!vec) return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification"); @@ -350,7 +414,10 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModule_Free(vec); return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); } - j++; // skip EF argument. + j++; // skip argument. + } else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) { + attrib = argv[j+1]; + j++; // skip argument. } else if (!strcasecmp(opt, "NOQUANT")) { quant_type = HNSW_QUANT_NONE; } else if (!strcasecmp(opt, "BIN")) { @@ -464,8 +531,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (!cas) { /* Insert vector synchronously. */ - int added = vectorSetInsert(vset,vec,NULL,0,val,1,ef); - if (added) RedisModule_RetainString(ctx,val); + int added = vectorSetInsert(vset,vec,NULL,0,val,attrib,1,ef); RedisModule_Free(vec); RedisModule_ReplyWithLongLong(ctx,added); @@ -478,7 +544,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0); pthread_t tid; - void **targ = RedisModule_Alloc(sizeof(void*)*7); + void **targ = RedisModule_Alloc(sizeof(void*)*8); targ[0] = bc; targ[1] = vset; targ[2] = (void*)(unsigned long)vset->id; @@ -486,7 +552,9 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { targ[4] = val; targ[5] = NULL; // Used later for insertion context. targ[6] = (void*)(unsigned long)ef; + targ[7] = attrib; RedisModule_RetainString(ctx,val); + if (attrib) RedisModule_RetainString(ctx,attrib); if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) { pthread_rwlock_unlock(&vset->in_use_lock); RedisModule_AbortBlock(bc); @@ -499,6 +567,17 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } } +/* HNSW callback to filter items according to a predicate function + * (our FILTER expression in this case). */ +int vectorSetFilterCallback(void *value, void *privdata) { + exprstate *expr = privdata; + struct vsetNodeVal *nv = value; + if (nv->attrib == NULL) return 0; // No attributes? No match. + size_t json_len; + char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len); + return exprRun(expr,json,json_len); +} + /* Common path for the execution of the VSIM command both threaded and * not threaded. Note that 'ctx' may be normal context of a thread safe * context obtained from a blocked client. The locking that is specific @@ -506,7 +585,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { * handles the HNSW locking explicitly. */ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, float *vec, unsigned long count, float epsilon, unsigned long withscores, - unsigned long ef) + unsigned long ef, exprstate *filter_expr, unsigned long filter_ef) { /* In our scan, we can't just collect 'count' elements as * if count is small we would explore the graph in an insufficient @@ -523,7 +602,12 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); float *distances = RedisModule_Alloc(sizeof(float)*ef); int slot = hnsw_acquire_read_slot(vset->hnsw); - unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0); + unsigned int found; + if (filter_expr == NULL) { + found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0); + } else { + found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, distances, slot, 0, vectorSetFilterCallback, filter_expr, filter_ef); + } hnsw_release_read_slot(vset->hnsw,slot); RedisModule_Free(vec); @@ -536,7 +620,8 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, for (unsigned int i = 0; i < found && i < count; i++) { if (distances[i] > epsilon) break; - RedisModule_ReplyWithString(ctx, neighbors[i]->value); + struct vsetNodeVal *nv = neighbors[i]->value; + RedisModule_ReplyWithString(ctx, nv->item); arraylen++; if (withscores) { /* The similarity score is provided in a 0-1 range. */ @@ -551,6 +636,7 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, RedisModule_Free(neighbors); RedisModule_Free(distances); + if (filter_expr) exprFree(filter_expr); } /* VSIM thread handling the blocked client request. */ @@ -566,6 +652,8 @@ void *VSIM_thread(void *arg) { float epsilon = *((float*)targ[4]); unsigned long withscores = (unsigned long)targ[5]; unsigned long ef = (unsigned long)targ[6]; + exprstate *filter_expr = targ[7]; + unsigned long filter_ef = (unsigned long)targ[8]; RedisModule_Free(targ[4]); RedisModule_Free(targ); @@ -573,7 +661,7 @@ void *VSIM_thread(void *arg) { RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); // Run the query. - VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef); // Cleanup. RedisModule_FreeThreadSafeContext(ctx); @@ -582,7 +670,7 @@ void *VSIM_thread(void *arg) { return NULL; } -/* VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */ +/* VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModule_AutoMemory(ctx); @@ -596,6 +684,10 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { long long ef = 0; /* Exploration factor (see HNSW paper) */ double epsilon = 2.0; /* Max cosine distance */ + /* Things computed later. */ + long long filter_ef = 0; + exprstate *filter_expr = NULL; + /* Get key and vector type */ RedisModuleString *key = argv[1]; const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); @@ -704,6 +796,28 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); } j += 2; + } else if (!strcasecmp(opt, "FILTER-EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &filter_ef) != + REDISMODULE_OK || filter_ef <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid FILTER-EF"); + } + j += 2; + } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) { + RedisModuleString *exprarg = argv[j+1]; + size_t exprlen; + char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen); + int errpos; + filter_expr = exprCompile(exprstr,&errpos); + if (filter_expr == NULL) { + if ((size_t)errpos >= exprlen) errpos = 0; + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR syntax error in FILTER expression near: %s", + exprstr+errpos); + } + j += 2; } else { RedisModule_Free(vec); return RedisModule_ReplyWithError(ctx, @@ -712,6 +826,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } int threaded_request = 1; // Run on a thread, by default. + if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes. // Disable threaded for MULTI/EXEC and Lua. if (RedisModule_GetContextFlags(ctx) & @@ -737,7 +852,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); pthread_t tid; - void **targ = RedisModule_Alloc(sizeof(void*)*7); + void **targ = RedisModule_Alloc(sizeof(void*)*9); targ[0] = bc; targ[1] = vset; targ[2] = vec; @@ -746,16 +861,18 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { *((float*)targ[4]) = epsilon; targ[5] = (void*)(unsigned long)withscores; targ[6] = (void*)(unsigned long)ef; + targ[7] = (void*)filter_expr; + targ[8] = (void*)(unsigned long)filter_ef; if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { pthread_rwlock_unlock(&vset->in_use_lock); RedisModule_AbortBlock(bc); RedisModule_Free(vec); RedisModule_Free(targ[4]); RedisModule_Free(targ); - return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread"); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef); } } else { - VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef); } return REDISMODULE_OK; @@ -839,6 +956,8 @@ int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { /* Remove from HNSW graph using the high-level API that handles * locking and cleanup. We pass RedisModule_FreeString as the value * free function since the strings were retained at insertion time. */ + struct vsetNodeVal *nv = node->value; + if (nv->attrib != NULL) vset->numattribs--; hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue); /* Destroy empty vector set. */ @@ -917,6 +1036,92 @@ int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { return REDISMODULE_OK; } +/* VSETATTR key element json + * Set or remove the JSON attribute associated with an element. + * Setting an empty string removes the attribute. + * The command returns one if the attribute was actually updated or + * zero if there is no key or element. */ +int VSETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 4) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithLongLong(ctx, 0); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); + if (!node) + return RedisModule_ReplyWithLongLong(ctx, 0); + + struct vsetNodeVal *nv = node->value; + RedisModuleString *new_attr = argv[3]; + + /* Set or delete the attribute based on the fact it's an empty + * string or not. */ + size_t attrlen; + RedisModule_StringPtrLen(new_attr, &attrlen); + if (attrlen == 0) { + // If we had an attribute before, decrease the count and free it. + if (nv->attrib) { + vset->numattribs--; + RedisModule_FreeString(NULL, nv->attrib); + nv->attrib = NULL; + } + } else { + // If we didn't have an attribute before, increase the count. + // Otherwise free the old one. + if (nv->attrib) { + RedisModule_FreeString(NULL, nv->attrib); + } else { + vset->numattribs++; + } + // Set new attribute. + RedisModule_RetainString(NULL, new_attr); + nv->attrib = new_attr; + } + + RedisModule_ReplyWithLongLong(ctx, 1); + RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* VGETATTR key element + * Get the JSON attribute associated with an element. + * Returns NIL if the element has no attribute or doesn't exist. */ +int VGETATTR_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 3) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNull(ctx); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + hnswNode *node = RedisModule_DictGet(vset->dict, argv[2], NULL); + if (!node) + return RedisModule_ReplyWithNull(ctx); + + struct vsetNodeVal *nv = node->value; + if (!nv->attrib) + return RedisModule_ReplyWithNull(ctx); + + return RedisModule_ReplyWithString(ctx, nv->attrib); +} + /* ============================== Reflection ================================ */ /* VLINKS key element [WITHSCORES] @@ -971,7 +1176,8 @@ int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) /* Add each neighbor's element value to the array. */ for (uint32_t j = 0; j < node->layers[i].num_links; j++) { - RedisModule_ReplyWithString(ctx, node->layers[i].links[j]->value); + struct vsetNodeVal *nv = node->layers[i].links[j]->value; + RedisModule_ReplyWithString(ctx, nv->item); if (withscores) { float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]); /* Convert distance to similarity score to match @@ -1035,6 +1241,9 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) /* ============================== vset type methods ========================= */ +#define SAVE_FLAG_HAS_PROJMATRIX (1<<0) +#define SAVE_FLAG_HAS_ATTRIBS (1<<1) + /* Save object to RDB */ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { struct vsetObject *vset = value; @@ -1042,9 +1251,13 @@ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count); RedisModule_SaveUnsigned(rdb, vset->hnsw->quant_type); + uint32_t save_flags = 0; + if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX; + if (vset->numattribs != 0) save_flags |= SAVE_FLAG_HAS_ATTRIBS; + RedisModule_SaveUnsigned(rdb, save_flags); + /* Save projection matrix if present */ if (vset->proj_matrix) { - RedisModule_SaveUnsigned(rdb, 1); // has projection uint32_t input_dim = vset->proj_input_size; uint32_t output_dim = vset->hnsw->vector_dim; RedisModule_SaveUnsigned(rdb, input_dim); @@ -1054,13 +1267,18 @@ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { // Save projection matrix as binary blob size_t matrix_size = sizeof(float) * input_dim * output_dim; RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size); - } else { - RedisModule_SaveUnsigned(rdb, 0); // no projection } hnswNode *node = vset->hnsw->head; while(node) { - RedisModule_SaveString(rdb, node->value); + struct vsetNodeVal *nv = node->value; + RedisModule_SaveString(rdb, nv->item); + if (vset->numattribs) { + if (nv->attrib) + RedisModule_SaveString(rdb, nv->attrib); + else + RedisModule_SaveStringBuffer(rdb, "", 0); + } hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node); RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size); RedisModule_SaveUnsigned(rdb, sn->params_count); @@ -1083,7 +1301,9 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { if (!vset) return NULL; /* Load projection matrix if present */ - uint32_t has_projection = RedisModule_LoadUnsigned(rdb); + uint32_t save_flags = RedisModule_LoadUnsigned(rdb); + int has_projection = save_flags & SAVE_FLAG_HAS_PROJMATRIX; + int has_attribs = save_flags & SAVE_FLAG_HAS_ATTRIBS; if (has_projection) { uint32_t input_dim = RedisModule_LoadUnsigned(rdb); uint32_t output_dim = dim; @@ -1105,6 +1325,16 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { while(elements--) { // Load associated string element. RedisModuleString *ele = RedisModule_LoadString(rdb); + RedisModuleString *attrib = NULL; + if (has_attribs) { + attrib = RedisModule_LoadString(rdb); + size_t attrlen; + RedisModule_StringPtrLen(attrib,&attrlen); + if (attrlen == 0) { + RedisModule_FreeString(NULL,attrib); + attrib = NULL; + } + } size_t vector_len; void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len); uint32_t vector_bytes = hnsw_quants_bytes(vset->hnsw); @@ -1120,12 +1350,16 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { for (uint32_t j = 0; j < params_count; j++) params[j] = RedisModule_LoadUnsigned(rdb); - hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, ele); + struct vsetNodeVal *nv = RedisModule_Alloc(sizeof(*nv)); + nv->item = ele; + nv->attrib = attrib; + hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, nv); if (node == NULL) { RedisModule_LogIOError(rdb,"warning", "Vector set node index loading error"); return NULL; // Loading error. } + if (nv->attrib) vset->numattribs++; RedisModule_DictSet(vset->dict,ele,node); RedisModule_Free(vector); RedisModule_Free(params); @@ -1279,6 +1513,14 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) return REDISMODULE_ERR; + if (RedisModule_CreateCommand(ctx, "VSETATTR", + VSETATTR_RedisCommand, "write fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VGETATTR", + VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc, RedisModule_Realloc);