/* Java implementation of red-black trees as described in
**   Alternatives to Two Classic Data Structures
**   Chris Okasaki
**   SIGCSE 2005
**
** Implements a set of integers.  Can easily be adapted to other
** key types, or to other abstractions such as bags/dictionaries/etc.
*/

/**
 * A mutable set of {@code int} keys backed by a red-black tree.
 *
 * <p>Two operations are offered: {@link #insert(int)} adds a key (adding a
 * key that is already present leaves the tree untouched) and
 * {@link #member(int)} tests whether a key is present.
 */
public class RedBlack {

   /***** Node representation *****/

   // Colors are plain booleans; an absent (null) subtree is never red.
   private static final boolean RED = true;
   private static final boolean BLACK = false;

   /** One tree node. A null reference represents an empty subtree. */
   private static class Node {
      int key;
      boolean color;
      Node left;
      Node right;

      /** Creates a red leaf holding {@code key}. */
      Node(int key) {
         this.key = key;
         this.color = RED;
      }
   }

   /** Topmost node of the tree, or null while the set is empty. */
   private Node root;

   /***** Public interface *****/

   /** Creates an empty set. */
   public RedBlack() {
      root = null;
   }

   /**
    * Adds {@code key} to the set. Has no effect on the tree's contents if
    * the key is already present.
    *
    * @param key the value to add
    */
   public void insert(int key) {
      Node updated = insertBelow(root, key);
      updated.color = BLACK;  // the root is recolored black after every insert
      root = updated;
   }

   /**
    * Reports whether {@code key} is in the set.
    *
    * @param key the value to look for
    * @return true if the key was previously inserted, false otherwise
    */
   public boolean member(int key) {
      Node cursor = root;
      while (cursor != null && cursor.key != key) {
         cursor = (key < cursor.key) ? cursor.left : cursor.right;
      }
      return cursor != null;
   }

   /***** Implementation details *****/

   /**
    * Inserts {@code key} into the subtree rooted at {@code node} and returns
    * the (possibly different) root of the resulting subtree.
    */
   private static Node insertBelow(Node node, int key) {
      if (node == null) {
         // Empty spot found: a fresh red leaf has no children to repair.
         return new Node(key);
      }
      if (key == node.key) {
         return node;  // key is already in tree
      }

      if (key < node.key) {
         node.left = insertBelow(node.left, key);
      }
      else {
         node.right = insertBelow(node.right, key);
      }
      return repair(node);
   }

   /**
    * Looks for a red child that itself has a red child beneath {@code top}
    * and, if one exists, restructures the three nodes involved. The four
    * possible shapes are tried in a fixed order (left-left, left-right,
    * right-left, right-right) and only the first match is applied.
    *
    * @return the root of the subtree after any restructuring
    */
   private static Node repair(Node top) {
      Node l = top.left;
      Node r = top.right;

      if (isRed(l)) {
         if (isRed(l.left)) {
            //       z           y
            //      / \         / \
            //     y   D  ==>  /   \
            //    / \         x     z
            //   x   C       / \   / \
            //  / \         A   B C   D
            // A   B
            return rebuild(l.left, l, top, l.left.right, l.right);
         }
         if (isRed(l.right)) {
            //       z           y
            //      / \         / \
            //     x   D  ==>  /   \
            //    / \         x     z
            //   A   y       / \   / \
            //      / \     A   B C   D
            //     B   C
            return rebuild(l, l.right, top, l.right.left, l.right.right);
         }
      }

      if (isRed(r)) {
         if (isRed(r.left)) {
            //     x             y
            //    / \           / \
            //   A   z    ==>  /   \
            //      / \       x     z
            //     y   D     / \   / \
            //    / \       A   B C   D
            //   B   C
            return rebuild(top, r.left, r, r.left.left, r.left.right);
         }
         if (isRed(r.right)) {
            //   x               y
            //  / \             / \
            // A   y      ==>  /   \
            //    / \         x     z
            //   B   z       / \   / \
            //      / \     A   B C   D
            //     C   D
            return rebuild(top, r, r.right, r.left, r.right.left);
         }
      }

      return top;
   }

   /**
    * Links the three nodes into the shape
    *
    *       y      <== red
    *      / \
    *     /   \
    *    x     z   <== both black
    *   / \   / \
    *  A   B C   D
    *
    * and returns y. Subtrees A and D are not parameters: x.left and z.right
    * are left exactly as they were.
    */
   private static Node rebuild(Node x, Node y, Node z, Node B, Node C) {
      y.color = RED;
      y.left = x;
      y.right = z;

      x.color = BLACK;
      x.right = B;

      z.color = BLACK;
      z.left = C;

      return y;
   }

   /** True only for an existing node colored red; null counts as not red. */
   private static boolean isRed(Node node) {
      if (node == null) {
         return false;
      }
      return node.color == RED;
   }
}
