summaryrefslogtreecommitdiff
path: root/tree.sml
blob: ee82485941756a384e7b598b7b64b3f3c697a364 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
(* Binary search tree mapping keys to values. Ordering is supplied by the
 * caller as a comparison function on every operation. None of the
 * operations below rebalance the tree.
 *)
structure Tree: TREE = struct
  (* Node carries key, value, left subtree, right subtree. *)
  datatype ('k, 'v) t = Node of 'k * 'v * ('k, 'v) t * ('k, 'v) t | Empty

  (* Curried three-way comparison. Throughout this structure it is called
   * as cmp searchedKey nodeKey.
   *)
  type 'k cmp = 'k -> 'k -> order

  (* The tree with no entries. *)
  val empty = Empty

  (* insert cmp tree k v returns (old, tree').
   * old is SOME of the value previously stored under a key comparing
   * EQUAL to k, or NONE if there was none. On EQUAL both the stored key
   * and the stored value are replaced by k and v.
   *)
  fun insert _ Empty k v = (NONE, Node (k, v, Empty, Empty))
    | insert cmp (Node (k', v', left, right)) k v =
        case cmp k k' of
          LESS =>
          let
            val (res, left) = insert cmp left k v
          in
            (res, Node (k', v', left, right))
          end
        | EQUAL => (SOME v', Node (k, v, left, right))
        | GREATER =>
          let
            val (res, right) = insert cmp right k v
          in
            (res, Node (k', v', left, right))
          end

  (* delete cmp tree k returns (removed, tree').
   * removed is SOME of the value that was stored under k, or NONE when k
   * is absent (tree' then has the same contents as tree).
   *)
  fun delete _ Empty _ = (NONE, Empty)
    | delete cmp (Node (k', v', left, right)) k =
        case cmp k k' of
          LESS =>
          let
            val (res, left) = delete cmp left k
          in
            (res, Node (k', v', left, right))
          end
        | GREATER =>
          let
            val (res, right) = delete cmp right k
          in
            (res, Node (k', v', left, right))
          end
        | EQUAL => (
            case (left, right) of
              (Empty, Node _) => (SOME v', right)
            | (Node _, Empty) => (SOME v', left)
            | (Empty, Empty) => (SOME v', Empty)
            | (Node _, Node _) =>
              let
                (* Detaches the rightmost entry of a non-empty tree and
                 * returns it together with the remaining tree. The Empty
                 * clause cannot be hit: the only caller passes a Node and
                 * the recursion stops before descending into Empty.
                 *)
                fun deleteRightmost Empty = raise Unreachable
                  | deleteRightmost (Node (k, v, left, Empty)) =
                    ((k, v), left)
                  | deleteRightmost (Node (k, v, left, right)) =
                  let
                    val (p, right) = deleteRightmost right
                  in
                    (p, Node (k, v, left, right))
                  end

                (* Two children: the rightmost entry of the left subtree
                 * takes the place of the deleted node.
                 *)
                val ((k, v), left) = deleteRightmost left
              in
                (SOME v', Node (k, v, left, right))
              end
        )

  (* lookup cmp tree k returns SOME of the value stored under k,
   * or NONE when k is absent.
   *)
  fun lookup _ Empty _ = NONE
    | lookup cmp (Node (k', v', left, right)) k =
        case cmp k k' of
          LESS => lookup cmp left k
        | EQUAL => SOME v'
        | GREATER => lookup cmp right k

  (* One step of a root-to-node path, used to rebuild the tree afterwards.
   * Left (k, v, right): the search went into the left child of node
   * (k, v); right is the sibling subtree that was not visited.
   * Right (k, v, left): the mirror case.
   *)
  datatype ('k, 'v) arc =
    Left of 'k * 'v * ('k, 'v) t |
    Right of 'k * 'v * ('k, 'v) t

  (* assemble buf n rebuilds a whole tree from the path buf and the
   * subtree n that belongs at the end of that path. buf is ordered with
   * the deepest arc first, so each arc wraps the result of the previous
   * one; an empty buf yields n itself.
   *)
  fun assemble buf n =
  let
    fun assemble' (Left (k, v, right) :: tail) tree =
      assemble' tail (Node (k, v, tree, right))
      | assemble' (Right (k, v, left) :: tail) tree =
      assemble' tail (Node (k, v, left, tree))
      | assemble' [] tree = tree
  in
    assemble' buf n
  end

  (* Worker behind lookup2: searches for k while recording the path taken
   * in buf (deepest arc first), then lets f decide what happens to the
   * entry.
   *
   * f receives the current value (NONE when k is absent) and returns
   * (res, newV). res is handed back to the caller unchanged. When newV is
   * SOME v, v is stored: as a fresh leaf under k if k was absent,
   * otherwise in place of the old value (the node keeps its existing
   * key). When newV is NONE the contents of the tree are left as they
   * were.
   *
   * The tree in the result is always the complete tree, reassembled from
   * buf.
   *)
  fun lookup' buf _ Empty k f =
  let
    val (res, newV) = f NONE
  in
    (res, assemble buf
        (case newV of
             NONE => Empty
           | SOME v => Node (k, v, Empty, Empty)))
  end
    | lookup' buf cmp (T as Node (k', v', left, right)) k f =
      case cmp k k' of
        LESS => lookup' (Left (k', v', right) :: buf) cmp left k f
      | GREATER => lookup' (Right (k', v', left) :: buf) cmp right k f
      | EQUAL =>
      let
        val (res, newV) = f $ SOME v'
      in
        case newV of
           (* T is only the subtree rooted at the match. Reattach the
            * ancestors recorded in buf so the caller gets the whole tree
            * back, as in every other branch.
            *)
           NONE => (res, assemble buf T)
         | SOME v => (res, assemble buf (Node (k', v, left, right)))
      end

  (* lookup2 cmp t k f: combined lookup and update in a single descent.
   * See lookup' for the contract of f; the search starts with an empty
   * path.
   *)
  fun lookup2 cmp t k f = lookup' [] cmp t k f

  (* print t key2str value2str dumps the tree through printf. An empty
   * tree is written as (); a node is written as an opening line with its
   * key and value, then its left subtree, then its right subtree, then a
   * closing parenthesis line. Children are printed with off + 1.
   * NOTE(review): R off presumably indents by the current depth; confirm
   * against the definition of printf, which is not in this file.
   *)
  fun print t key2str value2str =
  let
    fun Pkey z = bindWith2str key2str z
    fun Pvalue z = bindWith2str value2str z

    fun print' off Empty = printf R off `"()\n" %
      | print' off (Node (k, v, left, right)) = (
        (* Closed with % like the other two printf chains here. *)
        printf R off `"(" Pkey k `", " Pvalue v `"\n" %;
        print' (off + 1) left;
        print' (off + 1) right;
        printf R off `")\n" %
      )
  in
    print' 0 t
  end

  (* Number of entries stored in the tree. *)
  fun size Empty = 0
    | size (Node(_, _, l, r)) = 1 + size l + size r
end