+#include <asm/bug.h>
+#include <linux/rbtree_augmented.h>
#include "drbd_interval.h"
/**
}
/**
- * update_interval_end - recompute end of @node
+ * compute_subtree_last - compute end of @node
*
* The end of an interval is the highest (start + (size >> 9)) value of this
* node and of its children. Called for @node and its parents whenever the end
* may have changed.
*/
-static void
-update_interval_end(struct rb_node *node, void *__unused)
+static inline sector_t
+compute_subtree_last(struct drbd_interval *node)
{
- struct drbd_interval *this = rb_entry(node, struct drbd_interval, rb);
- sector_t end;
+ sector_t max = node->sector + (node->size >> 9);
- end = this->sector + (this->size >> 9);
- if (node->rb_left) {
- sector_t left = interval_end(node->rb_left);
- if (left > end)
- end = left;
+ if (node->rb.rb_left) {
+ sector_t left = interval_end(node->rb.rb_left);
+ if (left > max)
+ max = left;
+ }
+ if (node->rb.rb_right) {
+ sector_t right = interval_end(node->rb.rb_right);
+ if (right > max)
+ max = right;
}
- if (node->rb_right) {
- sector_t right = interval_end(node->rb_right);
- if (right > end)
- end = right;
+ return max;
+}
+
+static void augment_propagate(struct rb_node *rb, struct rb_node *stop)
+{
+ while (rb != stop) {
+ struct drbd_interval *node = rb_entry(rb, struct drbd_interval, rb);
+ sector_t subtree_last = compute_subtree_last(node);
+ if (node->end == subtree_last)
+ break;
+ node->end = subtree_last;
+ rb = rb_parent(&node->rb);
}
- this->end = end;
}
+static void augment_copy(struct rb_node *rb_old, struct rb_node *rb_new)
+{
+ struct drbd_interval *old = rb_entry(rb_old, struct drbd_interval, rb);
+ struct drbd_interval *new = rb_entry(rb_new, struct drbd_interval, rb);
+
+ new->end = old->end;
+}
+
+static void augment_rotate(struct rb_node *rb_old, struct rb_node *rb_new)
+{
+ struct drbd_interval *old = rb_entry(rb_old, struct drbd_interval, rb);
+ struct drbd_interval *new = rb_entry(rb_new, struct drbd_interval, rb);
+
+ new->end = old->end;
+ old->end = compute_subtree_last(old);
+}
+
+static const struct rb_augment_callbacks augment_callbacks = {
+ augment_propagate,
+ augment_copy,
+ augment_rotate,
+};
+
/**
* drbd_insert_interval - insert a new interval into a tree
*/
}
rb_link_node(&this->rb, parent, new);
- rb_insert_color(&this->rb, root);
- rb_augment_insert(&this->rb, update_interval_end, NULL);
+ rb_insert_augmented(&this->rb, root, &augment_callbacks);
return true;
}
else if (interval > here)
node = node->rb_right;
else
- return interval->sector == sector;
+ return true;
}
return false;
}
void
drbd_remove_interval(struct rb_root *root, struct drbd_interval *this)
{
- struct rb_node *deepest;
-
- deepest = rb_augment_erase_begin(&this->rb);
- rb_erase(&this->rb, root);
- rb_augment_erase_end(deepest, update_interval_end, NULL);
+ rb_erase_augmented(&this->rb, root, &augment_callbacks);
}
/**
* @sector: start sector
* @size: size, aligned to 512 bytes
*
- * Returns the interval overlapping with [sector, sector + size), or NULL.
- * When there is more than one overlapping interval in the tree, the interval
- * with the lowest start sector is returned.
+ * Returns an interval overlapping with [sector, sector + size), or NULL if
+ * there is none. When there is more than one overlapping interval in the
+ * tree, the interval with the lowest start sector is returned, and all other
+ * overlapping intervals will be on the right side of the tree, reachable with
+ * rb_next().
*/
struct drbd_interval *
drbd_find_overlap(struct rb_root *root, sector_t sector, unsigned int size)
}
return overlap;
}
+
+struct drbd_interval *
+drbd_next_overlap(struct drbd_interval *i, sector_t sector, unsigned int size)
+{
+ sector_t end = sector + (size >> 9);
+ struct rb_node *node;
+
+ for (;;) {
+ node = rb_next(&i->rb);
+ if (!node)
+ return NULL;
+ i = rb_entry(node, struct drbd_interval, rb);
+ if (i->sector >= end)
+ return NULL;
+ if (sector < i->sector + (i->size >> 9))
+ return i;
+ }
+}