(root)/
strace-6.5/
src/
trie.c
       1  /*
       2   * Simple trie implementation for key-value mapping storage
       3   *
       4   * Copyright (c) 2020-2021 Ákos Uzonyi <uzonyi.akos@gmail.com>
       5   * All rights reserved.
       6   *
       7   * SPDX-License-Identifier: LGPL-2.1-or-later
       8   */
       9  
      10  #ifdef HAVE_CONFIG_H
      11  # include "config.h"
      12  #endif
      13  
      14  #include <stdlib.h>
      15  #include <stdio.h>
      16  
      17  #include "trie.h"
      18  #include "macros.h"
      19  #include "xmalloc.h"
      20  
      21  static const uint8_t ptr_sz_lg = (sizeof(void *) == 8 ? 6 : 5);
      22  
      23  /**
      24   * Returns lg2 of node size in bits for the specific level of the trie.
      25   */
      26  static uint8_t
      27  trie_get_node_size(struct trie *t, uint8_t depth)
      28  {
      29  	/* Last level contains data and we allow it having a different size */
      30  	if (depth == t->max_depth)
      31  		return t->data_block_key_bits + t->item_size_lg;
      32  	/* Last level of the tree can be smaller */
      33  	if (depth == t->max_depth - 1)
      34  		return (t->key_size - t->data_block_key_bits - 1) %
      35  		t->node_key_bits + 1 + ptr_sz_lg;
      36  
      37  	return t->node_key_bits + ptr_sz_lg;
      38  }
      39  
      40  /**
      41   * Provides starting offset of bits in key corresponding to the node index
      42   * at the specific level.
      43   */
      44  static uint8_t
      45  trie_get_node_bit_offs(struct trie *t, uint8_t depth)
      46  {
      47  	uint8_t offs;
      48  
      49  	if (depth == t->max_depth)
      50  		return 0;
      51  
      52  	offs = t->data_block_key_bits;
      53  
      54  	if (depth == t->max_depth - 1)
      55  		return offs;
      56  
      57  	/* data_block_size + remainder */
      58  	offs += trie_get_node_size(t, t->max_depth - 1) - ptr_sz_lg;
      59  	offs += (t->max_depth - depth - 2) * t->node_key_bits;
      60  
      61  	return offs;
      62  }
      63  
      64  struct trie *
      65  trie_create(uint8_t key_size, uint8_t item_size_lg, uint8_t node_key_bits,
      66              uint8_t data_block_key_bits, uint64_t empty_value)
      67  {
      68  	if (item_size_lg > 6)
      69  		return NULL;
      70  	if (key_size > 64)
      71  		return NULL;
      72  	if (node_key_bits < 1)
      73  		return NULL;
      74  	if (data_block_key_bits < 1 || data_block_key_bits > key_size)
      75  		return NULL;
      76  
      77  	struct trie *t = malloc(sizeof(*t));
      78  	if (!t)
      79  		return NULL;
      80  
      81  	t->fill_value = t->empty_value =
      82  			empty_value & MASK64_SAFE(BIT32(item_size_lg));
      83  	for (size_t i = 0; i < 6U - item_size_lg; i++)
      84  		t->fill_value |= t->fill_value << BIT32(item_size_lg + i);
      85  
      86  	t->data = NULL;
      87  	t->item_size_lg = item_size_lg;
      88  	t->node_key_bits = node_key_bits;
      89  	t->data_block_key_bits = data_block_key_bits;
      90  	t->key_size = key_size;
      91  	t->max_depth = (key_size - data_block_key_bits + node_key_bits - 1)
      92  		/ t->node_key_bits;
      93  
      94  	return t;
      95  }
      96  
      97  static void *
      98  trie_create_data_block(struct trie *t)
      99  {
     100  	uint8_t sz = t->data_block_key_bits + t->item_size_lg;
     101  	if (sz < 6)
     102  		sz = 6;
     103  
     104  	size_t count = BIT32(sz - 6);
     105  	uint64_t *data_block = xcalloc(count, 8);
     106  
     107  	for (size_t i = 0; i < count; i++)
     108  		data_block[i] = t->fill_value;
     109  
     110  	return data_block;
     111  }
     112  
     113  static uint64_t *
     114  trie_get_node(struct trie *t, uint64_t key, bool auto_create)
     115  {
     116  	void **cur_node = &(t->data);
     117  
     118  	if (t->key_size < 64 && key > MASK64(t->key_size))
     119  		return NULL;
     120  
     121  	for (uint8_t cur_depth = 0; cur_depth <= t->max_depth; cur_depth++) {
     122  		uint8_t offs = trie_get_node_bit_offs(t, cur_depth);
     123  		uint8_t sz = trie_get_node_size(t, cur_depth);
     124  
     125  		if (!*cur_node) {
     126  			if (!auto_create)
     127  				return NULL;
     128  
     129  			if (cur_depth == t->max_depth)
     130  				*cur_node = trie_create_data_block(t);
     131  			else
     132  				*cur_node = xcalloc(BIT64(sz), 1);
     133  		}
     134  
     135  		if (cur_depth == t->max_depth)
     136  			break;
     137  
     138  		size_t pos = (key >> offs) & MASK64(sz - ptr_sz_lg);
     139  		cur_node = (((void **) (*cur_node)) + pos);
     140  	}
     141  
     142  	return (uint64_t *) (*cur_node);
     143  }
     144  
     145  static void
     146  trie_data_block_calc_pos(struct trie *t, uint64_t key,
     147                           uint64_t *pos, uint64_t *mask, uint64_t *offs)
     148  {
     149  	uint64_t key_mask;
     150  
     151  	key_mask = MASK64(t->data_block_key_bits);
     152  	*pos = (key & key_mask) >> (6 - t->item_size_lg);
     153  
     154  	if (t->item_size_lg == 6) {
     155  		*offs = 0;
     156  		*mask = -1;
     157  		return;
     158  	}
     159  
     160  	key_mask = MASK64(6 - t->item_size_lg);
     161  	*offs = (key & key_mask) << t->item_size_lg;
     162  
     163  	*mask = MASK64_SAFE(BIT32(t->item_size_lg)) << *offs;
     164  }
     165  
     166  bool
     167  trie_set(struct trie *t, uint64_t key, uint64_t val)
     168  {
     169  	uint64_t *data = trie_get_node(t, key, true);
     170  	if (!data)
     171  		return false;
     172  
     173  	uint64_t pos, mask, offs;
     174  	trie_data_block_calc_pos(t, key, &pos, &mask, &offs);
     175  
     176  	data[pos] &= ~mask;
     177  	data[pos] |= (val << offs) & mask;
     178  
     179  	return true;
     180  }
     181  
     182  static uint64_t
     183  trie_data_block_get(struct trie *t, uint64_t *data, uint64_t key)
     184  {
     185  	if (!data)
     186  		return t->empty_value;
     187  
     188  	uint64_t pos, mask, offs;
     189  	trie_data_block_calc_pos(t, key, &pos, &mask, &offs);
     190  
     191  	return (data[pos] & mask) >> offs;
     192  }
     193  
     194  uint64_t
     195  trie_get(struct trie *b, uint64_t key)
     196  {
     197  	return trie_data_block_get(b, trie_get_node(b, key, false), key);
     198  }
     199  
     200  static uint64_t
     201  trie_iterate_keys_node(struct trie *t,
     202                         trie_iterate_fn fn, void *fn_data,
     203                         void *node, uint64_t start, uint64_t end,
     204                         uint8_t depth)
     205  {
     206  	if (start > end || !node)
     207  		return 0;
     208  
     209  	if (t->key_size < 64) {
     210  		uint64_t key_max = MASK64(t->key_size);
     211  		if (end > key_max)
     212  			end = key_max;
     213  	}
     214  
     215  	if (depth == t->max_depth) {
     216  		for (uint64_t i = start; i <= end; i++)
     217  			fn(fn_data, i, trie_data_block_get(t,
     218  				(uint64_t *) node, i));
     219  
     220  		return end - start + 1;
     221  	}
     222  
     223  	uint8_t parent_node_bit_off = depth == 0 ?
     224  		t->key_size :
     225  		trie_get_node_bit_offs(t, depth - 1);
     226  
     227  	uint64_t first_key_in_node = start & ~MASK64_SAFE(parent_node_bit_off);
     228  
     229  	uint8_t node_bit_off = trie_get_node_bit_offs(t, depth);
     230  	uint8_t node_key_bits = parent_node_bit_off - node_bit_off;
     231  	uint64_t mask = MASK64_SAFE(node_key_bits);
     232  	uint64_t start_index = (start >> node_bit_off) & mask;
     233  	uint64_t end_index = (end >> node_bit_off) & mask;
     234  	uint64_t child_key_count = BIT64(node_bit_off);
     235  
     236  	uint64_t count = 0;
     237  
     238  	for (uint64_t i = start_index; i <= end_index; i++) {
     239  		uint64_t child_start = first_key_in_node + i * child_key_count;
     240  		uint64_t child_end = first_key_in_node +
     241  			(i + 1) * child_key_count - 1;
     242  
     243  		if (child_start < start)
     244  			child_start = start;
     245  		if (child_end > end)
     246  			child_end = end;
     247  
     248  		count += trie_iterate_keys_node(t, fn, fn_data,
     249  			((void **) node)[i], child_start, child_end,
     250  			depth + 1);
     251  	}
     252  
     253  	return count;
     254  }
     255  
     256  uint64_t trie_iterate_keys(struct trie *t, uint64_t start, uint64_t end,
     257                             trie_iterate_fn fn, void *fn_data)
     258  {
     259  	return trie_iterate_keys_node(t, fn, fn_data, t->data,
     260  		start, end, 0);
     261  }
     262  
     263  static void
     264  trie_free_node(struct trie *t, void *node, uint8_t depth)
     265  {
     266  	if (!node)
     267  		return;
     268  
     269  	if (depth >= t->max_depth)
     270  		goto free_node;
     271  
     272  	size_t sz = BIT64(trie_get_node_size(t, depth) - ptr_sz_lg);
     273  	for (size_t i = 0; i < sz; i++)
     274  		trie_free_node(t, ((void **) node)[i], depth + 1);
     275  
     276  free_node:
     277  	free(node);
     278  }
     279  
     280  void
     281  trie_free(struct trie *t)
     282  {
     283  	trie_free_node(t, t->data, 0);
     284  	free(t);
     285  }