summaryrefslogtreecommitdiff
path: root/tree.sml
blob: 678ebe8e9af5c1693a2e031d9206ae47153cbe0e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
(* Persistent, unbalanced binary search tree from keys 'k to values 'v.
 * Operations that order keys take an explicit comparison function
 * ('k cmp); a given tree must always be used with the same one, or lookups
 * will miss.  Nothing is mutated: every update returns a new tree that
 * shares untouched subtrees with the old one. *)
structure Tree: TREE = struct
  (* Node (key, value, left, right): every key in left compares LESS than
   * key and every key in right compares GREATER, under the tree's cmp. *)
  datatype ('k, 'v) t = Node of 'k * 'v * ('k, 'v) t * ('k, 'v) t | Empty

  (* Three-way comparison of two keys. *)
  type 'k cmp = 'k -> 'k -> order

  (* The tree with no bindings. *)
  val empty = Empty

  (* insert cmp t k v binds k to v.
   * Returns (SOME old, t') if k was already bound to old -- the stored key
   * is also replaced by k -- or (NONE, t') if k was not present. *)
  fun insert _ Empty k v = (NONE, Node (k, v, Empty, Empty))
    | insert cmp (Node (k', v', left, right)) k v =
        case cmp k k' of
          LESS =>
          let
            val (res, left) = insert cmp left k v
          in
            (res, Node (k', v', left, right))
          end
        | EQUAL => (SOME v', Node (k, v, left, right))
        | GREATER =>
          let
            val (res, right) = insert cmp right k v
          in
            (res, Node (k', v', left, right))
          end

  (* delete cmp t k removes the binding for k.
   * Returns (SOME v, t') where v was the removed value, or (NONE, t') with
   * t' holding the same bindings as t if k was absent.  A node with two
   * children is replaced by the rightmost (greatest) node of its left
   * subtree, which preserves the ordering invariant. *)
  fun delete _ Empty _ = (NONE, Empty)
    | delete cmp (Node (k', v', left, right)) k =
        case cmp k k' of
          LESS =>
          let
            val (res, left) = delete cmp left k
          in
            (res, Node (k', v', left, right))
          end
        | GREATER =>
          let
            val (res, right) = delete cmp right k
          in
            (res, Node (k', v', left, right))
          end
        | EQUAL => (
            case (left, right) of
              (* At most one child: that child (possibly Empty) takes over. *)
              (Empty, _) => (SOME v', right)
            | (_, Empty) => (SOME v', left)
            | (Node _, Node _) =>
              let
                (* Removes the rightmost node of a non-empty tree, returning
                 * its binding and what remains.  Only ever called on a Node,
                 * and recursion only follows a non-Empty right child, so the
                 * Empty clause cannot fire. *)
                fun deleteRightmost Empty = raise Unreachable
                  | deleteRightmost (Node (k, v, left, Empty)) =
                    ((k, v), left)
                  | deleteRightmost (Node (k, v, left, right)) =
                  let
                    val (p, right) = deleteRightmost right
                  in
                    (p, Node (k, v, left, right))
                  end

                val ((k, v), left) = deleteRightmost left
              in
                (SOME v', Node (k, v, left, right))
              end
        )

  (* lookup cmp t k returns SOME v if k is bound to v in t, else NONE. *)
  fun lookup _ Empty _ = NONE
    | lookup cmp (Node (k', v', left, right)) k =
        case cmp k k' of
          LESS => lookup cmp left k
        | EQUAL => SOME v'
        | GREATER => lookup cmp right k

  (* One step of a root-to-hole path, used to rebuild a tree bottom-up.
   * Left (k, v, right): the hole is the left child of node (k, v), whose
   * right child is right.  Right (k, v, left) is the mirror image. *)
  datatype ('k, 'v) arc =
    Left of 'k * 'v * ('k, 'v) t |
    Right of 'k * 'v * ('k, 'v) t

  (* assemble buf n plugs n into the hole described by buf and rebuilds up
   * to the root.  buf lists arcs innermost first: its head is the parent of
   * the hole and its last element is the root. *)
  fun assemble buf n =
    List.foldl
      (fn (Left (k, v, right), tree) => Node (k, v, tree, right)
        | (Right (k, v, left), tree) => Node (k, v, left, tree))
      n buf

  (* Worker for lookup2.  buf is the path walked so far (innermost arc
   * first); see lookup2 for the meaning of f and of the result. *)
  fun lookup' buf _ Empty k f =
  let
    val (res, newV) = f NONE
  in
    (res, assemble buf
        (case newV of
             NONE => Empty
           | SOME v => Node (k, v, Empty, Empty)))
  end
    | lookup' buf cmp (node as Node (k', v', left, right)) k f =
      case cmp k k' of
        LESS => lookup' (Left (k', v', right) :: buf) cmp left k f
      | GREATER => lookup' (Right (k', v', left) :: buf) cmp right k f
      | EQUAL =>
      let
        val (res, newV) = f $ SOME v'
      in
        case newV of
           (* Leave the binding alone.  The path must still be rebuilt so
            * the caller receives the whole tree, not just the subtree
            * rooted at this node. *)
           NONE => (res, assemble buf node)
         | SOME v => (res, assemble buf (Node (k', v, left, right)))
      end

  (* lookup2 cmp t k f finds the binding of k and updates it in one pass.
   * f receives SOME v if k is bound to v, or NONE if k is absent, and
   * returns (res, newV):
   *   newV = SOME w -- k is (re)bound to w, inserted if it was absent;
   *   newV = NONE   -- the tree is left unchanged.
   * Returns (res, t') where t' is the resulting whole tree. *)
  fun lookup2 cmp t k f = lookup' [] cmp t k f

  (* print t key2str value2str prints t in pre-order, one node per line:
   * "(key, value" followed by the left and right subtrees and a closing
   * ")"; an empty subtree prints as "()".  off is the nesting depth passed
   * to printf -- presumably used for indentation (TODO confirm). *)
  fun print t key2str value2str =
  let
    fun Pkey z = bindWith2str key2str z
    fun Pvalue z = bindWith2str value2str z

    fun print' off Empty = printf R off `"()\n" %
      | print' off (Node (k, v, left, right)) = (
        printf R off `"(" Pkey k `", " Pvalue v `"\n";
        print' (off + 1) left;
        print' (off + 1) right;
        printf R off `")\n" %
      )
  in
    print' 0 t
  end

  (* Number of bindings in the tree.  Not tail-recursive: stack depth grows
   * with the height of the tree. *)
  fun size Empty = 0
    | size (Node(_, _, l, r)) = 1 + size l + size r
end