1 /*
2  * Copyright (C) 2010 The Guava Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.common.collect.testing;
18 
19 import java.io.Serializable;
20 import java.util.AbstractSet;
21 import java.util.Collection;
22 import java.util.Comparator;
23 import java.util.Iterator;
24 import java.util.Map;
25 import java.util.NavigableMap;
26 import java.util.NavigableSet;
27 import java.util.Set;
28 import java.util.SortedMap;
29 import java.util.TreeMap;
30 
/**
 * A {@link NavigableMap} backed by a {@link TreeMap} (or a view of one) that validates every key
 * it is handed by comparing the key with itself before forwarding the call. A key the ordering
 * cannot handle therefore fails with {@code ClassCastException} (or {@code NullPointerException}
 * for a null key under natural ordering) before the operation reaches the backing map. Sub-map
 * and descending-map views are wrapped the same way. This implementation passes the navigable
 * map test suites.
 *
 * @author Louis Wasserman
 */
public final class SafeTreeMap<K, V>
    implements Serializable, NavigableMap<K, V> {

  /** Natural-ordering comparator reported by {@link #comparator()} when the backing map has none. */
  @SuppressWarnings("unchecked")
  private static final Comparator<Object> NATURAL_ORDER = new Comparator<Object>() {
    @Override public int compare(Object left, Object right) {
      Comparable<Object> comparableLeft = (Comparable<Object>) left;
      return comparableLeft.compareTo(right);
    }
  };

  private static final long serialVersionUID = 0L;

  /** The map (or map view) that every operation is forwarded to. */
  private final NavigableMap<K, V> delegate;

  // ---------------------------------------------------------------------------------------------
  // Construction
  // ---------------------------------------------------------------------------------------------

  /** Creates an empty map that uses the natural ordering of its keys. */
  public SafeTreeMap() {
    this(new TreeMap<K, V>());
  }

  /** Creates an empty map ordered by {@code comparator}. */
  public SafeTreeMap(Comparator<? super K> comparator) {
    this(new TreeMap<K, V>(comparator));
  }

  /** Creates a map containing the entries of {@code map}, copied via {@code new TreeMap(map)}. */
  public SafeTreeMap(Map<? extends K, ? extends V> map) {
    this(new TreeMap<K, V>(map));
  }

  /** Creates a map containing the entries of {@code map}, copied via {@code new TreeMap(map)}. */
  public SafeTreeMap(SortedMap<K, ? extends V> map) {
    this(new TreeMap<K, V>(map));
  }

  /**
   * Wraps {@code delegate} as-is, after validating every key it already contains.
   *
   * @throws NullPointerException if {@code delegate} is null
   */
  private SafeTreeMap(NavigableMap<K, V> delegate) {
    if (delegate == null) {
      throw new NullPointerException();
    }
    this.delegate = delegate;
    for (K existingKey : delegate.navigableKeySet()) {
      checkValid(existingKey);
    }
  }

  // ---------------------------------------------------------------------------------------------
  // Ordering and size
  // ---------------------------------------------------------------------------------------------

  /**
   * Returns the backing map's comparator or, when it has none, a natural-ordering comparator.
   * Unlike {@link TreeMap#comparator()}, this never returns null.
   */
  @SuppressWarnings("unchecked")
  @Override public Comparator<? super K> comparator() {
    Comparator<? super K> explicit = delegate.comparator();
    if (explicit != null) {
      return explicit;
    }
    return (Comparator<? super K>) NATURAL_ORDER;
  }

  @Override public int size() {
    return delegate.size();
  }

  @Override public boolean isEmpty() {
    return delegate.isEmpty();
  }

  // ---------------------------------------------------------------------------------------------
  // Lookup
  // ---------------------------------------------------------------------------------------------

  /**
   * Reports whether {@code key} is present. A key that makes validation or the backing lookup
   * throw {@code ClassCastException} or {@code NullPointerException} is reported as absent.
   */
  @Override public boolean containsKey(Object key) {
    try {
      return delegate.containsKey(checkValid(key));
    } catch (ClassCastException ignored) {
      return false;
    } catch (NullPointerException ignored) {
      return false;
    }
  }

  @Override public boolean containsValue(Object value) {
    return delegate.containsValue(value);
  }

  /** Validates {@code key}, then looks it up; an invalid key propagates the comparator's exception. */
  @Override public V get(Object key) {
    return delegate.get(checkValid(key));
  }

  // ---------------------------------------------------------------------------------------------
  // Navigation: probe keys are validated first, results come straight from the backing map.
  // ---------------------------------------------------------------------------------------------

  @Override public Entry<K, V> ceilingEntry(K key) {
    return delegate.ceilingEntry(checkValid(key));
  }

  @Override public K ceilingKey(K key) {
    return delegate.ceilingKey(checkValid(key));
  }

  @Override public Entry<K, V> floorEntry(K key) {
    return delegate.floorEntry(checkValid(key));
  }

  @Override public K floorKey(K key) {
    return delegate.floorKey(checkValid(key));
  }

  @Override public Entry<K, V> higherEntry(K key) {
    return delegate.higherEntry(checkValid(key));
  }

  @Override public K higherKey(K key) {
    return delegate.higherKey(checkValid(key));
  }

  @Override public Entry<K, V> lowerEntry(K key) {
    return delegate.lowerEntry(checkValid(key));
  }

  @Override public K lowerKey(K key) {
    return delegate.lowerKey(checkValid(key));
  }

  @Override public Entry<K, V> firstEntry() {
    return delegate.firstEntry();
  }

  @Override public K firstKey() {
    return delegate.firstKey();
  }

  @Override public Entry<K, V> lastEntry() {
    return delegate.lastEntry();
  }

  @Override public K lastKey() {
    return delegate.lastKey();
  }

  @Override public Entry<K, V> pollFirstEntry() {
    return delegate.pollFirstEntry();
  }

  @Override public Entry<K, V> pollLastEntry() {
    return delegate.pollLastEntry();
  }

  // ---------------------------------------------------------------------------------------------
  // Mutation
  // ---------------------------------------------------------------------------------------------

  @Override public V put(K key, V value) {
    return delegate.put(checkValid(key), value);
  }

  /** Validates every key of {@code map} before handing the whole map to the backing map. */
  @Override public void putAll(Map<? extends K, ? extends V> map) {
    for (K incomingKey : map.keySet()) {
      checkValid(incomingKey);
    }
    delegate.putAll(map);
  }

  /** Validates {@code key}, then removes it; an invalid key propagates the comparator's exception. */
  @Override public V remove(Object key) {
    return delegate.remove(checkValid(key));
  }

  @Override public void clear() {
    delegate.clear();
  }

  // ---------------------------------------------------------------------------------------------
  // Views
  // ---------------------------------------------------------------------------------------------

  /** Returns the backing map's navigable key set (not wrapped). */
  @Override public NavigableSet<K> keySet() {
    return delegate.navigableKeySet();
  }

  /** Returns the backing map's navigable key set (not wrapped). */
  @Override public NavigableSet<K> navigableKeySet() {
    return delegate.navigableKeySet();
  }

  /** Returns the backing map's descending key set (not wrapped). */
  @Override public NavigableSet<K> descendingKeySet() {
    return delegate.descendingKeySet();
  }

  /** Returns the backing map's values collection (not wrapped). */
  @Override public Collection<V> values() {
    return delegate.values();
  }

  /**
   * Returns a view of the entries. {@code size}, {@code iterator}, {@code remove} and
   * {@code clear} are forwarded to the backing map's entry set; {@code contains} is forwarded too,
   * but reports {@code false} instead of throwing {@code ClassCastException} or
   * {@code NullPointerException}. Remaining {@code Set} operations come from {@link AbstractSet}.
   */
  @Override public Set<Entry<K, V>> entrySet() {
    return new AbstractSet<Entry<K, V>>() {
      @Override public int size() {
        return backing().size();
      }

      @Override public Iterator<Entry<K, V>> iterator() {
        return backing().iterator();
      }

      @Override public boolean contains(Object candidate) {
        try {
          return backing().contains(candidate);
        } catch (ClassCastException ignored) {
          return false;
        } catch (NullPointerException ignored) {
          return false;
        }
      }

      @Override public boolean remove(Object candidate) {
        return backing().remove(candidate);
      }

      @Override public void clear() {
        backing().clear();
      }

      /** Fetches the backing map's entry set afresh on every call. */
      private Set<Entry<K, V>> backing() {
        return delegate.entrySet();
      }
    };
  }

  /** Wraps the backing map's descending view, re-validating the keys it contains. */
  @Override public NavigableMap<K, V> descendingMap() {
    return new SafeTreeMap<K, V>(delegate.descendingMap());
  }

  /** Equivalent to {@code headMap(toKey, false)}. */
  @Override public SortedMap<K, V> headMap(K toKey) {
    return headMap(toKey, false);
  }

  /** Validates {@code toKey}, then wraps the backing map's head view. */
  @Override public NavigableMap<K, V> headMap(K toKey, boolean inclusive) {
    NavigableMap<K, V> head = delegate.headMap(checkValid(toKey), inclusive);
    return new SafeTreeMap<K, V>(head);
  }

  /** Equivalent to {@code tailMap(fromKey, true)}. */
  @Override public SortedMap<K, V> tailMap(K fromKey) {
    return tailMap(fromKey, true);
  }

  /** Validates {@code fromKey}, then wraps the backing map's tail view. */
  @Override public NavigableMap<K, V> tailMap(K fromKey, boolean inclusive) {
    NavigableMap<K, V> tail = delegate.tailMap(checkValid(fromKey), inclusive);
    return new SafeTreeMap<K, V>(tail);
  }

  /** Equivalent to {@code subMap(fromKey, true, toKey, false)}. */
  @Override public SortedMap<K, V> subMap(K fromKey, K toKey) {
    return subMap(fromKey, true, toKey, false);
  }

  /** Validates both bounds (lower bound first), then wraps the backing map's sub-map view. */
  @Override public NavigableMap<K, V> subMap(
      K fromKey, boolean fromInclusive, K toKey, boolean toInclusive) {
    K lower = checkValid(fromKey);
    K upper = checkValid(toKey);
    return new SafeTreeMap<K, V>(delegate.subMap(lower, fromInclusive, upper, toInclusive));
  }

  // ---------------------------------------------------------------------------------------------
  // Object methods: equality, hash code and string form are exactly those of the backing map.
  // ---------------------------------------------------------------------------------------------

  @Override public boolean equals(Object obj) {
    return delegate.equals(obj);
  }

  @Override public int hashCode() {
    return delegate.hashCode();
  }

  @Override public String toString() {
    return delegate.toString();
  }

  // ---------------------------------------------------------------------------------------------
  // Validation
  // ---------------------------------------------------------------------------------------------

  /**
   * Compares {@code candidate} with itself using {@link #comparator()} and returns it unchanged.
   * A value the ordering cannot handle makes that comparison throw (for example
   * {@code ClassCastException} for a non-{@code Comparable} under natural ordering, or
   * {@code NullPointerException} for null under natural ordering); surfacing that failure is the
   * whole point of this class.
   */
  private <T> T checkValid(T candidate) {
    @SuppressWarnings("unchecked")
    K asKey = (K) candidate;
    comparator().compare(asKey, asKey);
    return candidate;
  }
}
293