kern

annotate src/rbtree.c @ 73:b4b7198986a6

fixed a potential null dereference when deleting a bug in the redblack tree
author John Tsiombikas <nuclear@member.fsf.org>
date Sat, 15 Oct 2011 08:06:10 +0300
parents b45e2d5f0ae1
children
rev   line source
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "rbtree.h"
#include "panic.h"
nuclear@68 6
/* Round-trip conversions between int keys and the void* key slot.
 * Going through intptr_t keeps the int<->pointer conversion well-formed
 * on targets where pointers are wider than int (a direct (int)ptr cast
 * draws pointer-to-int-cast diagnostics there). */
#define INT2PTR(x) ((void*)(intptr_t)(x))
#define PTR2INT(x) ((int)(intptr_t)(x))

/* key comparison functions selectable through rb_init() */
static int cmpaddr(void *ap, void *bp);
static int cmpint(void *ap, void *bp);

/* recursive tree primitives; each takes/returns a subtree root */
static int count_nodes(struct rbnode *node);
static void del_tree(struct rbnode *node, void (*delfunc)(struct rbnode*, void*), void *cls);
static struct rbnode *insert(struct rbtree *rb, struct rbnode *tree, void *key, void *data);
static struct rbnode *delete(struct rbtree *rb, struct rbnode *tree, void *key);
static void traverse(struct rbnode *node, void (*func)(struct rbnode*, void*), void *cls);
nuclear@68 18
nuclear@68 19 struct rbtree *rb_create(rb_cmp_func_t cmp_func)
nuclear@68 20 {
nuclear@68 21 struct rbtree *rb;
nuclear@68 22
nuclear@68 23 if(!(rb = malloc(sizeof *rb))) {
nuclear@68 24 return 0;
nuclear@68 25 }
nuclear@68 26 if(rb_init(rb, cmp_func) == -1) {
nuclear@68 27 free(rb);
nuclear@68 28 return 0;
nuclear@68 29 }
nuclear@68 30 return rb;
nuclear@68 31 }
nuclear@68 32
/* Release every node of a tree created by rb_create() and then the tree
 * object itself. Like free(), passing a null pointer is a no-op. */
void rb_free(struct rbtree *rb)
{
	if(!rb) {
		return;
	}
	rb_destroy(rb);
	free(rb);
}
nuclear@68 38
nuclear@68 39
nuclear@68 40 int rb_init(struct rbtree *rb, rb_cmp_func_t cmp_func)
nuclear@68 41 {
nuclear@68 42 memset(rb, 0, sizeof *rb);
nuclear@68 43
nuclear@68 44 if(cmp_func == RB_KEY_INT) {
nuclear@68 45 rb->cmp = cmpint;
nuclear@68 46 } else if(cmp_func == RB_KEY_STRING) {
nuclear@68 47 rb->cmp = (rb_cmp_func_t)strcmp;
nuclear@68 48 } else {
nuclear@68 49 rb->cmp = cmpaddr;
nuclear@68 50 }
nuclear@68 51
nuclear@68 52 rb->alloc = malloc;
nuclear@68 53 rb->free = free;
nuclear@68 54 return 0;
nuclear@68 55 }
nuclear@68 56
nuclear@68 57 void rb_destroy(struct rbtree *rb)
nuclear@68 58 {
nuclear@68 59 del_tree(rb->root, rb->del, rb->del_cls);
nuclear@68 60 }
nuclear@68 61
nuclear@69 62 void rb_clear(struct rbtree *rb)
nuclear@69 63 {
nuclear@69 64 del_tree(rb->root, rb->del, rb->del_cls);
nuclear@69 65 rb->root = 0;
nuclear@69 66 }
nuclear@69 67
nuclear@69 68 int rb_copy(struct rbtree *dest, struct rbtree *src)
nuclear@69 69 {
nuclear@69 70 struct rbnode *node;
nuclear@69 71
nuclear@69 72 rb_clear(dest);
nuclear@69 73
nuclear@69 74 rb_begin(src);
nuclear@69 75 while((node = rb_next(src))) {
nuclear@69 76 if(rb_insert(dest, node->key, node->data) == -1) {
nuclear@69 77 return -1;
nuclear@69 78 }
nuclear@69 79 }
nuclear@69 80 return 0;
nuclear@69 81 }
nuclear@69 82
nuclear@68 83 void rb_set_allocator(struct rbtree *rb, rb_alloc_func_t alloc, rb_free_func_t free)
nuclear@68 84 {
nuclear@68 85 rb->alloc = alloc;
nuclear@68 86 rb->free = free;
nuclear@68 87 }
nuclear@68 88
nuclear@68 89
nuclear@68 90 void rb_set_compare_func(struct rbtree *rb, rb_cmp_func_t func)
nuclear@68 91 {
nuclear@68 92 rb->cmp = func;
nuclear@68 93 }
nuclear@68 94
nuclear@68 95 void rb_set_delete_func(struct rbtree *rb, rb_del_func_t func, void *cls)
nuclear@68 96 {
nuclear@68 97 rb->del = func;
nuclear@68 98 rb->del_cls = cls;
nuclear@68 99 }
nuclear@68 100
nuclear@68 101 int rb_size(struct rbtree *rb)
nuclear@68 102 {
nuclear@68 103 return count_nodes(rb->root);
nuclear@68 104 }
nuclear@68 105
nuclear@68 106 int rb_insert(struct rbtree *rb, void *key, void *data)
nuclear@68 107 {
nuclear@68 108 rb->root = insert(rb, rb->root, key, data);
nuclear@68 109 rb->root->red = 0;
nuclear@68 110 return 0;
nuclear@68 111 }
nuclear@68 112
nuclear@68 113 int rb_inserti(struct rbtree *rb, int key, void *data)
nuclear@68 114 {
nuclear@68 115 rb->root = insert(rb, rb->root, INT2PTR(key), data);
nuclear@68 116 rb->root->red = 0;
nuclear@68 117 return 0;
nuclear@68 118 }
nuclear@68 119
nuclear@68 120
nuclear@68 121 int rb_delete(struct rbtree *rb, void *key)
nuclear@68 122 {
nuclear@68 123 rb->root = delete(rb, rb->root, key);
nuclear@68 124 rb->root->red = 0;
nuclear@68 125 return 0;
nuclear@68 126 }
nuclear@68 127
nuclear@68 128 int rb_deletei(struct rbtree *rb, int key)
nuclear@68 129 {
nuclear@68 130 rb->root = delete(rb, rb->root, INT2PTR(key));
nuclear@68 131 rb->root->red = 0;
nuclear@68 132 return 0;
nuclear@68 133 }
nuclear@68 134
nuclear@68 135
nuclear@68 136 void *rb_find(struct rbtree *rb, void *key)
nuclear@68 137 {
nuclear@68 138 struct rbnode *node = rb->root;
nuclear@68 139
nuclear@68 140 while(node) {
nuclear@68 141 int cmp = rb->cmp(key, node->key);
nuclear@68 142 if(cmp == 0) {
nuclear@68 143 return node;
nuclear@68 144 }
nuclear@68 145 node = cmp < 0 ? node->left : node->right;
nuclear@68 146 }
nuclear@68 147 return 0;
nuclear@68 148 }
nuclear@68 149
nuclear@68 150 void *rb_findi(struct rbtree *rb, int key)
nuclear@68 151 {
nuclear@68 152 return rb_find(rb, INT2PTR(key));
nuclear@68 153 }
nuclear@68 154
nuclear@68 155
nuclear@68 156 void rb_foreach(struct rbtree *rb, void (*func)(struct rbnode*, void*), void *cls)
nuclear@68 157 {
nuclear@68 158 traverse(rb->root, func, cls);
nuclear@68 159 }
nuclear@68 160
nuclear@68 161
nuclear@68 162 struct rbnode *rb_root(struct rbtree *rb)
nuclear@68 163 {
nuclear@68 164 return rb->root;
nuclear@68 165 }
nuclear@68 166
nuclear@68 167 void rb_begin(struct rbtree *rb)
nuclear@68 168 {
nuclear@68 169 rb->rstack = 0;
nuclear@68 170 rb->iter = rb->root;
nuclear@68 171 }
nuclear@68 172
nuclear@68 173 #define push(sp, x) ((x)->next = (sp), (sp) = (x))
nuclear@68 174 #define pop(sp) ((sp) = (sp)->next)
nuclear@68 175 #define top(sp) (sp)
nuclear@68 176
nuclear@68 177 struct rbnode *rb_next(struct rbtree *rb)
nuclear@68 178 {
nuclear@68 179 struct rbnode *res = 0;
nuclear@68 180
nuclear@68 181 while(rb->rstack || rb->iter) {
nuclear@68 182 if(rb->iter) {
nuclear@68 183 push(rb->rstack, rb->iter);
nuclear@68 184 rb->iter = rb->iter->left;
nuclear@68 185 } else {
nuclear@68 186 rb->iter = top(rb->rstack);
nuclear@68 187 pop(rb->rstack);
nuclear@68 188 res = rb->iter;
nuclear@68 189 rb->iter = rb->iter->right;
nuclear@68 190 break;
nuclear@68 191 }
nuclear@68 192 }
nuclear@68 193 return res;
nuclear@68 194 }
nuclear@68 195
nuclear@68 196 void *rb_node_key(struct rbnode *node)
nuclear@68 197 {
nuclear@68 198 return node ? node->key : 0;
nuclear@68 199 }
nuclear@68 200
nuclear@68 201 int rb_node_keyi(struct rbnode *node)
nuclear@68 202 {
nuclear@68 203 return node ? PTR2INT(node->key) : 0;
nuclear@68 204 }
nuclear@68 205
nuclear@68 206 void *rb_node_data(struct rbnode *node)
nuclear@68 207 {
nuclear@68 208 return node ? node->data : 0;
nuclear@68 209 }
nuclear@68 210
/* Order keys by address. Relational comparison of pointers into
 * unrelated objects is undefined in C, so compare them as uintptr_t. */
static int cmpaddr(void *ap, void *bp)
{
	uintptr_t a = (uintptr_t)ap;
	uintptr_t b = (uintptr_t)bp;

	return a < b ? -1 : (a > b ? 1 : 0);
}

/* Order keys as ints packed into the pointer slot. Compares instead of
 * subtracting: `a - b` overflows (undefined behaviour, and a wrong sign)
 * for operands such as INT_MAX and -1. Returns -1, 0 or 1. */
static int cmpint(void *ap, void *bp)
{
	int a = (int)(intptr_t)ap;
	int b = (int)(intptr_t)bp;

	return a < b ? -1 : (a > b ? 1 : 0);
}
nuclear@68 220
nuclear@68 221
nuclear@68 222 /* ---- left-leaning 2-3 red-black implementation ---- */
nuclear@68 223
/* helper prototypes */

/* colour queries and local restructuring */
static int is_red(struct rbnode *tree);
static void color_flip(struct rbnode *tree);
static struct rbnode *rot_left(struct rbnode *a);
static struct rbnode *rot_right(struct rbnode *a);
static struct rbnode *fix_up(struct rbnode *tree);

/* deletion support */
static struct rbnode *find_min(struct rbnode *tree);
static struct rbnode *del_min(struct rbtree *rb, struct rbnode *tree);
static struct rbnode *move_red_left(struct rbnode *tree);
/*static struct rbnode *move_red_right(struct rbnode *tree);*/
nuclear@68 234
nuclear@68 235 static int count_nodes(struct rbnode *node)
nuclear@68 236 {
nuclear@68 237 if(!node)
nuclear@68 238 return 0;
nuclear@68 239
nuclear@68 240 return 1 + count_nodes(node->left) + count_nodes(node->right);
nuclear@68 241 }
nuclear@68 242
nuclear@68 243 static void del_tree(struct rbnode *node, rb_del_func_t delfunc, void *cls)
nuclear@68 244 {
nuclear@68 245 if(!node)
nuclear@68 246 return;
nuclear@68 247
nuclear@68 248 del_tree(node->left, delfunc, cls);
nuclear@68 249 del_tree(node->right, delfunc, cls);
nuclear@68 250
nuclear@73 251 if(delfunc) {
nuclear@73 252 delfunc(node, cls);
nuclear@73 253 }
nuclear@68 254 free(node);
nuclear@68 255 }
nuclear@68 256
nuclear@68 257 static struct rbnode *insert(struct rbtree *rb, struct rbnode *tree, void *key, void *data)
nuclear@68 258 {
nuclear@68 259 int cmp;
nuclear@68 260
nuclear@68 261 if(!tree) {
nuclear@68 262 struct rbnode *node = rb->alloc(sizeof *node);
nuclear@69 263 if(!node) {
nuclear@69 264 panic("failed to allocate tree node\n");
nuclear@69 265 }
nuclear@68 266 node->red = 1;
nuclear@68 267 node->key = key;
nuclear@68 268 node->data = data;
nuclear@68 269 node->left = node->right = 0;
nuclear@68 270 return node;
nuclear@68 271 }
nuclear@68 272
nuclear@68 273 cmp = rb->cmp(key, tree->key);
nuclear@68 274
nuclear@68 275 if(cmp < 0) {
nuclear@68 276 tree->left = insert(rb, tree->left, key, data);
nuclear@68 277 } else if(cmp > 0) {
nuclear@68 278 tree->right = insert(rb, tree->right, key, data);
nuclear@68 279 } else {
nuclear@68 280 tree->data = data;
nuclear@68 281 }
nuclear@68 282
nuclear@68 283 /* fix right-leaning reds */
nuclear@68 284 if(is_red(tree->right)) {
nuclear@68 285 tree = rot_left(tree);
nuclear@68 286 }
nuclear@68 287 /* fix two reds in a row */
nuclear@68 288 if(is_red(tree->left) && is_red(tree->left->left)) {
nuclear@68 289 tree = rot_right(tree);
nuclear@68 290 }
nuclear@68 291
nuclear@68 292 /* if 4-node, split it by color inversion */
nuclear@68 293 if(is_red(tree->left) && is_red(tree->right)) {
nuclear@68 294 color_flip(tree);
nuclear@68 295 }
nuclear@68 296
nuclear@68 297 return tree;
nuclear@68 298 }
nuclear@68 299
nuclear@68 300 static struct rbnode *delete(struct rbtree *rb, struct rbnode *tree, void *key)
nuclear@68 301 {
nuclear@68 302 int cmp;
nuclear@68 303
nuclear@68 304 if(!tree) {
nuclear@68 305 return 0;
nuclear@68 306 }
nuclear@68 307
nuclear@68 308 cmp = rb->cmp(key, tree->key);
nuclear@68 309
nuclear@68 310 if(cmp < 0) {
nuclear@68 311 if(!is_red(tree->left) && !is_red(tree->left->left)) {
nuclear@68 312 tree = move_red_left(tree);
nuclear@68 313 }
nuclear@68 314 tree->left = delete(rb, tree->left, key);
nuclear@68 315 } else {
nuclear@68 316 /* need reds on the right */
nuclear@68 317 if(is_red(tree->left)) {
nuclear@68 318 tree = rot_right(tree);
nuclear@68 319 }
nuclear@68 320
nuclear@68 321 /* found it at the bottom (XXX what certifies left is null?) */
nuclear@68 322 if(cmp == 0 && !tree->right) {
nuclear@68 323 if(rb->del) {
nuclear@68 324 rb->del(tree, rb->del_cls);
nuclear@68 325 }
nuclear@68 326 rb->free(tree);
nuclear@68 327 return 0;
nuclear@68 328 }
nuclear@68 329
nuclear@68 330 if(!is_red(tree->right) && !is_red(tree->right->left)) {
nuclear@68 331 tree = move_red_left(tree);
nuclear@68 332 }
nuclear@68 333
nuclear@68 334 if(key == tree->key) {
nuclear@68 335 struct rbnode *rmin = find_min(tree->right);
nuclear@68 336 tree->key = rmin->key;
nuclear@68 337 tree->data = rmin->data;
nuclear@68 338 tree->right = del_min(rb, tree->right);
nuclear@68 339 } else {
nuclear@68 340 tree->right = delete(rb, tree->right, key);
nuclear@68 341 }
nuclear@68 342 }
nuclear@68 343
nuclear@68 344 return fix_up(tree);
nuclear@68 345 }
nuclear@68 346
nuclear@68 347 /*static struct rbnode *find(struct rbtree *rb, struct rbnode *node, void *key)
nuclear@68 348 {
nuclear@68 349 int cmp;
nuclear@68 350
nuclear@68 351 if(!node)
nuclear@68 352 return 0;
nuclear@68 353
nuclear@68 354 if((cmp = rb->cmp(key, node->key)) == 0) {
nuclear@68 355 return node;
nuclear@68 356 }
nuclear@68 357 return find(rb, cmp < 0 ? node->left : node->right, key);
nuclear@68 358 }*/
nuclear@68 359
nuclear@68 360 static void traverse(struct rbnode *node, void (*func)(struct rbnode*, void*), void *cls)
nuclear@68 361 {
nuclear@68 362 if(!node)
nuclear@68 363 return;
nuclear@68 364
nuclear@68 365 traverse(node->left, func, cls);
nuclear@68 366 func(node, cls);
nuclear@68 367 traverse(node->right, func, cls);
nuclear@68 368 }
nuclear@68 369
nuclear@68 370 /* helpers */
nuclear@68 371
nuclear@68 372 static int is_red(struct rbnode *tree)
nuclear@68 373 {
nuclear@68 374 return tree && tree->red;
nuclear@68 375 }
nuclear@68 376
nuclear@68 377 static void color_flip(struct rbnode *tree)
nuclear@68 378 {
nuclear@68 379 tree->red = !tree->red;
nuclear@68 380 tree->left->red = !tree->left->red;
nuclear@68 381 tree->right->red = !tree->right->red;
nuclear@68 382 }
nuclear@68 383
nuclear@68 384 static struct rbnode *rot_left(struct rbnode *a)
nuclear@68 385 {
nuclear@68 386 struct rbnode *b = a->right;
nuclear@68 387 a->right = b->left;
nuclear@68 388 b->left = a;
nuclear@68 389 b->red = a->red;
nuclear@68 390 a->red = 1;
nuclear@68 391 return b;
nuclear@68 392 }
nuclear@68 393
nuclear@68 394 static struct rbnode *rot_right(struct rbnode *a)
nuclear@68 395 {
nuclear@68 396 struct rbnode *b = a->left;
nuclear@68 397 a->left = b->right;
nuclear@68 398 b->right = a;
nuclear@68 399 b->red = a->red;
nuclear@68 400 a->red = 1;
nuclear@68 401 return b;
nuclear@68 402 }
nuclear@68 403
nuclear@68 404 static struct rbnode *find_min(struct rbnode *tree)
nuclear@68 405 {
nuclear@68 406 struct rbnode *node;
nuclear@68 407
nuclear@68 408 if(!tree)
nuclear@68 409 return 0;
nuclear@68 410
nuclear@68 411 while(node->left) {
nuclear@68 412 node = node->left;
nuclear@68 413 }
nuclear@68 414 return node;
nuclear@68 415 }
nuclear@68 416
nuclear@68 417 static struct rbnode *del_min(struct rbtree *rb, struct rbnode *tree)
nuclear@68 418 {
nuclear@68 419 if(!tree->left) {
nuclear@68 420 if(rb->del) {
nuclear@68 421 rb->del(tree->left, rb->del_cls);
nuclear@68 422 }
nuclear@68 423 rb->free(tree->left);
nuclear@68 424 return 0;
nuclear@68 425 }
nuclear@68 426
nuclear@68 427 /* make sure we've got red (3/4-nodes) at the left side so we can delete at the bottom */
nuclear@68 428 if(!is_red(tree->left) && !is_red(tree->left->left)) {
nuclear@68 429 tree = move_red_left(tree);
nuclear@68 430 }
nuclear@68 431 tree->left = del_min(rb, tree->left);
nuclear@68 432
nuclear@68 433 /* fix right-reds, red-reds, and split 4-nodes on the way up */
nuclear@68 434 return fix_up(tree);
nuclear@68 435 }
nuclear@68 436
#if 0
/* Disabled (its prototype is commented out as well).
 * Push a red link on this node down to the right: flip so both children
 * are red, and if that leaves a red-red pair on the left, rotate it over
 * to the right and flip again. */
static struct rbnode *move_red_right(struct rbnode *tree)
{
	color_flip(tree);

	if(!is_red(tree->left->left)) {
		return tree;
	}
	tree = rot_right(tree);
	color_flip(tree);
	return tree;
}
#endif
nuclear@68 452
nuclear@68 453 /* push a red link on this node to the left */
nuclear@68 454 static struct rbnode *move_red_left(struct rbnode *tree)
nuclear@68 455 {
nuclear@68 456 /* flipping it makes both children go red, so we have a red to the left */
nuclear@68 457 color_flip(tree);
nuclear@68 458
nuclear@68 459 /* if after the flip we've got a red-red on the right-left, fix it */
nuclear@68 460 if(is_red(tree->right->left)) {
nuclear@68 461 tree->right = rot_right(tree->right);
nuclear@68 462 tree = rot_left(tree);
nuclear@68 463 color_flip(tree);
nuclear@68 464 }
nuclear@68 465 return tree;
nuclear@68 466 }
nuclear@68 467
nuclear@68 468 static struct rbnode *fix_up(struct rbnode *tree)
nuclear@68 469 {
nuclear@68 470 /* fix right-leaning */
nuclear@68 471 if(is_red(tree->right)) {
nuclear@68 472 tree = rot_left(tree);
nuclear@68 473 }
nuclear@68 474 /* change invalid red-red pairs into a proper 4-node */
nuclear@68 475 if(is_red(tree->left) && is_red(tree->left->left)) {
nuclear@68 476 tree = rot_right(tree);
nuclear@68 477 }
nuclear@68 478 /* split 4-nodes */
nuclear@68 479 if(is_red(tree->left) && is_red(tree->right)) {
nuclear@68 480 color_flip(tree);
nuclear@68 481 }
nuclear@68 482 return tree;
nuclear@68 483 }
nuclear@69 484
/* Debug helper: print every key (as an int) in key order on one line.
 * Uses the tree's own iterator, so any iteration in progress is reset. */
void rb_dbg_print_tree(struct rbtree *tree)
{
	struct rbnode *n;

	for(rb_begin(tree); (n = rb_next(tree)) != 0; ) {
		printf("%d ", rb_node_keyi(n));
	}
	printf("\n");
}