1 /*
2 * Simple trie implementation for key-value mapping storage
3 *
4 * Copyright (c) 2020-2021 Ákos Uzonyi <uzonyi.akos@gmail.com>
5 * All rights reserved.
6 *
7 * SPDX-License-Identifier: LGPL-2.1-or-later
8 */
9
10 #ifdef HAVE_CONFIG_H
11 # include "config.h"
12 #endif
13
14 #include <stdlib.h>
15 #include <stdio.h>
16
17 #include "trie.h"
18 #include "macros.h"
19 #include "xmalloc.h"
20
/** lg2 of the pointer width in bits: 6 for 8-byte pointers, 5 otherwise. */
static const uint8_t ptr_sz_lg = (sizeof(void *) != 8) ? 5 : 6;
22
23 /**
24 * Returns lg2 of node size in bits for the specific level of the trie.
25 */
26 static uint8_t
27 trie_get_node_size(struct trie *t, uint8_t depth)
28 {
29 /* Last level contains data and we allow it having a different size */
30 if (depth == t->max_depth)
31 return t->data_block_key_bits + t->item_size_lg;
32 /* Last level of the tree can be smaller */
33 if (depth == t->max_depth - 1)
34 return (t->key_size - t->data_block_key_bits - 1) %
35 t->node_key_bits + 1 + ptr_sz_lg;
36
37 return t->node_key_bits + ptr_sz_lg;
38 }
39
40 /**
41 * Provides starting offset of bits in key corresponding to the node index
42 * at the specific level.
43 */
44 static uint8_t
45 trie_get_node_bit_offs(struct trie *t, uint8_t depth)
46 {
47 uint8_t offs;
48
49 if (depth == t->max_depth)
50 return 0;
51
52 offs = t->data_block_key_bits;
53
54 if (depth == t->max_depth - 1)
55 return offs;
56
57 /* data_block_size + remainder */
58 offs += trie_get_node_size(t, t->max_depth - 1) - ptr_sz_lg;
59 offs += (t->max_depth - depth - 2) * t->node_key_bits;
60
61 return offs;
62 }
63
64 struct trie *
65 trie_create(uint8_t key_size, uint8_t item_size_lg, uint8_t node_key_bits,
66 uint8_t data_block_key_bits, uint64_t empty_value)
67 {
68 if (item_size_lg > 6)
69 return NULL;
70 if (key_size > 64)
71 return NULL;
72 if (node_key_bits < 1)
73 return NULL;
74 if (data_block_key_bits < 1 || data_block_key_bits > key_size)
75 return NULL;
76
77 struct trie *t = malloc(sizeof(*t));
78 if (!t)
79 return NULL;
80
81 t->fill_value = t->empty_value =
82 empty_value & MASK64_SAFE(BIT32(item_size_lg));
83 for (size_t i = 0; i < 6U - item_size_lg; i++)
84 t->fill_value |= t->fill_value << BIT32(item_size_lg + i);
85
86 t->data = NULL;
87 t->item_size_lg = item_size_lg;
88 t->node_key_bits = node_key_bits;
89 t->data_block_key_bits = data_block_key_bits;
90 t->key_size = key_size;
91 t->max_depth = (key_size - data_block_key_bits + node_key_bits - 1)
92 / t->node_key_bits;
93
94 return t;
95 }
96
97 static void *
98 trie_create_data_block(struct trie *t)
99 {
100 uint8_t sz = t->data_block_key_bits + t->item_size_lg;
101 if (sz < 6)
102 sz = 6;
103
104 size_t count = BIT32(sz - 6);
105 uint64_t *data_block = xcalloc(count, 8);
106
107 for (size_t i = 0; i < count; i++)
108 data_block[i] = t->fill_value;
109
110 return data_block;
111 }
112
113 static uint64_t *
114 trie_get_node(struct trie *t, uint64_t key, bool auto_create)
115 {
116 void **cur_node = &(t->data);
117
118 if (t->key_size < 64 && key > MASK64(t->key_size))
119 return NULL;
120
121 for (uint8_t cur_depth = 0; cur_depth <= t->max_depth; cur_depth++) {
122 uint8_t offs = trie_get_node_bit_offs(t, cur_depth);
123 uint8_t sz = trie_get_node_size(t, cur_depth);
124
125 if (!*cur_node) {
126 if (!auto_create)
127 return NULL;
128
129 if (cur_depth == t->max_depth)
130 *cur_node = trie_create_data_block(t);
131 else
132 *cur_node = xcalloc(BIT64(sz), 1);
133 }
134
135 if (cur_depth == t->max_depth)
136 break;
137
138 size_t pos = (key >> offs) & MASK64(sz - ptr_sz_lg);
139 cur_node = (((void **) (*cur_node)) + pos);
140 }
141
142 return (uint64_t *) (*cur_node);
143 }
144
145 static void
146 trie_data_block_calc_pos(struct trie *t, uint64_t key,
147 uint64_t *pos, uint64_t *mask, uint64_t *offs)
148 {
149 uint64_t key_mask;
150
151 key_mask = MASK64(t->data_block_key_bits);
152 *pos = (key & key_mask) >> (6 - t->item_size_lg);
153
154 if (t->item_size_lg == 6) {
155 *offs = 0;
156 *mask = -1;
157 return;
158 }
159
160 key_mask = MASK64(6 - t->item_size_lg);
161 *offs = (key & key_mask) << t->item_size_lg;
162
163 *mask = MASK64_SAFE(BIT32(t->item_size_lg)) << *offs;
164 }
165
166 bool
167 trie_set(struct trie *t, uint64_t key, uint64_t val)
168 {
169 uint64_t *data = trie_get_node(t, key, true);
170 if (!data)
171 return false;
172
173 uint64_t pos, mask, offs;
174 trie_data_block_calc_pos(t, key, &pos, &mask, &offs);
175
176 data[pos] &= ~mask;
177 data[pos] |= (val << offs) & mask;
178
179 return true;
180 }
181
182 static uint64_t
183 trie_data_block_get(struct trie *t, uint64_t *data, uint64_t key)
184 {
185 if (!data)
186 return t->empty_value;
187
188 uint64_t pos, mask, offs;
189 trie_data_block_calc_pos(t, key, &pos, &mask, &offs);
190
191 return (data[pos] & mask) >> offs;
192 }
193
/**
 * Returns the value stored under key. Keys that were never set, or that do
 * not fit into the trie's key size, read as the trie's empty value.
 */
uint64_t
trie_get(struct trie *b, uint64_t key)
{
	uint64_t *block = trie_get_node(b, key, false);

	return trie_data_block_get(b, block, key);
}
199
200 static uint64_t
201 trie_iterate_keys_node(struct trie *t,
202 trie_iterate_fn fn, void *fn_data,
203 void *node, uint64_t start, uint64_t end,
204 uint8_t depth)
205 {
206 if (start > end || !node)
207 return 0;
208
209 if (t->key_size < 64) {
210 uint64_t key_max = MASK64(t->key_size);
211 if (end > key_max)
212 end = key_max;
213 }
214
215 if (depth == t->max_depth) {
216 for (uint64_t i = start; i <= end; i++)
217 fn(fn_data, i, trie_data_block_get(t,
218 (uint64_t *) node, i));
219
220 return end - start + 1;
221 }
222
223 uint8_t parent_node_bit_off = depth == 0 ?
224 t->key_size :
225 trie_get_node_bit_offs(t, depth - 1);
226
227 uint64_t first_key_in_node = start & ~MASK64_SAFE(parent_node_bit_off);
228
229 uint8_t node_bit_off = trie_get_node_bit_offs(t, depth);
230 uint8_t node_key_bits = parent_node_bit_off - node_bit_off;
231 uint64_t mask = MASK64_SAFE(node_key_bits);
232 uint64_t start_index = (start >> node_bit_off) & mask;
233 uint64_t end_index = (end >> node_bit_off) & mask;
234 uint64_t child_key_count = BIT64(node_bit_off);
235
236 uint64_t count = 0;
237
238 for (uint64_t i = start_index; i <= end_index; i++) {
239 uint64_t child_start = first_key_in_node + i * child_key_count;
240 uint64_t child_end = first_key_in_node +
241 (i + 1) * child_key_count - 1;
242
243 if (child_start < start)
244 child_start = start;
245 if (child_end > end)
246 child_end = end;
247
248 count += trie_iterate_keys_node(t, fn, fn_data,
249 ((void **) node)[i], child_start, child_end,
250 depth + 1);
251 }
252
253 return count;
254 }
255
256 uint64_t trie_iterate_keys(struct trie *t, uint64_t start, uint64_t end,
257 trie_iterate_fn fn, void *fn_data)
258 {
259 return trie_iterate_keys_node(t, fn, fn_data, t->data,
260 start, end, 0);
261 }
262
263 static void
264 trie_free_node(struct trie *t, void *node, uint8_t depth)
265 {
266 if (!node)
267 return;
268
269 if (depth >= t->max_depth)
270 goto free_node;
271
272 size_t sz = BIT64(trie_get_node_size(t, depth) - ptr_sz_lg);
273 for (size_t i = 0; i < sz; i++)
274 trie_free_node(t, ((void **) node)[i], depth + 1);
275
276 free_node:
277 free(node);
278 }
279
280 void
281 trie_free(struct trie *t)
282 {
283 trie_free_node(t, t->data, 0);
284 free(t);
285 }