1 /*
2  * Copyright (C) 2010 The Guava Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
5  * in compliance with the License. You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software distributed under the License
10  * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11  * or implied. See the License for the specific language governing permissions and limitations under
12  * the License.
13  */
14 
15 package com.google.common.collect;
16 
17 import static com.google.common.base.Preconditions.checkNotNull;
18 import static com.google.common.base.Preconditions.checkState;
19 
20 import com.google.common.base.Equivalence;
21 import com.google.common.base.Function;
22 import com.google.common.collect.MapMaker.RemovalCause;
23 import com.google.common.collect.MapMaker.RemovalListener;
24 
25 import java.io.IOException;
26 import java.io.ObjectInputStream;
27 import java.io.ObjectOutputStream;
28 import java.lang.ref.ReferenceQueue;
29 import java.util.concurrent.ConcurrentMap;
30 import java.util.concurrent.ExecutionException;
31 import java.util.concurrent.atomic.AtomicReferenceArray;
32 
33 import javax.annotation.Nullable;
34 import javax.annotation.concurrent.GuardedBy;
35 
36 /**
37  * Adds computing functionality to {@link MapMakerInternalMap}.
38  *
39  * @author Bob Lee
40  * @author Charles Fry
41  */
42 class ComputingConcurrentHashMap<K, V> extends MapMakerInternalMap<K, V> {
43   final Function<? super K, ? extends V> computingFunction;
44 
45   /**
46    * Creates a new, empty map with the specified strategy, initial capacity, load factor and
47    * concurrency level.
48    */
ComputingConcurrentHashMap(MapMaker builder, Function<? super K, ? extends V> computingFunction)49   ComputingConcurrentHashMap(MapMaker builder,
50       Function<? super K, ? extends V> computingFunction) {
51     super(builder);
52     this.computingFunction = checkNotNull(computingFunction);
53   }
54 
55   @Override
createSegment(int initialCapacity, int maxSegmentSize)56   Segment<K, V> createSegment(int initialCapacity, int maxSegmentSize) {
57     return new ComputingSegment<K, V>(this, initialCapacity, maxSegmentSize);
58   }
59 
60   @Override
segmentFor(int hash)61   ComputingSegment<K, V> segmentFor(int hash) {
62     return (ComputingSegment<K, V>) super.segmentFor(hash);
63   }
64 
getOrCompute(K key)65   V getOrCompute(K key) throws ExecutionException {
66     int hash = hash(checkNotNull(key));
67     return segmentFor(hash).getOrCompute(key, hash, computingFunction);
68   }
69 
70   @SuppressWarnings("serial") // This class is never serialized.
71   static final class ComputingSegment<K, V> extends Segment<K, V> {
ComputingSegment(MapMakerInternalMap<K, V> map, int initialCapacity, int maxSegmentSize)72     ComputingSegment(MapMakerInternalMap<K, V> map, int initialCapacity, int maxSegmentSize) {
73       super(map, initialCapacity, maxSegmentSize);
74     }
75 
getOrCompute(K key, int hash, Function<? super K, ? extends V> computingFunction)76     V getOrCompute(K key, int hash, Function<? super K, ? extends V> computingFunction)
77         throws ExecutionException {
78       try {
79         outer: while (true) {
80           // don't call getLiveEntry, which would ignore computing values
81           ReferenceEntry<K, V> e = getEntry(key, hash);
82           if (e != null) {
83             V value = getLiveValue(e);
84             if (value != null) {
85               recordRead(e);
86               return value;
87             }
88           }
89 
90           // at this point e is either null, computing, or expired;
91           // avoid locking if it's already computing
92           if (e == null || !e.getValueReference().isComputingReference()) {
93             boolean createNewEntry = true;
94             ComputingValueReference<K, V> computingValueReference = null;
95             lock();
96             try {
97               preWriteCleanup();
98 
99               int newCount = this.count - 1;
100               AtomicReferenceArray<ReferenceEntry<K, V>> table = this.table;
101               int index = hash & (table.length() - 1);
102               ReferenceEntry<K, V> first = table.get(index);
103 
104               for (e = first; e != null; e = e.getNext()) {
105                 K entryKey = e.getKey();
106                 if (e.getHash() == hash && entryKey != null
107                     && map.keyEquivalence.equivalent(key, entryKey)) {
108                   ValueReference<K, V> valueReference = e.getValueReference();
109                   if (valueReference.isComputingReference()) {
110                     createNewEntry = false;
111                   } else {
112                     V value = e.getValueReference().get();
113                     if (value == null) {
114                       enqueueNotification(entryKey, hash, value, RemovalCause.COLLECTED);
115                     } else if (map.expires() && map.isExpired(e)) {
116                       // This is a duplicate check, as preWriteCleanup already purged expired
117                       // entries, but let's accomodate an incorrect expiration queue.
118                       enqueueNotification(entryKey, hash, value, RemovalCause.EXPIRED);
119                     } else {
120                       recordLockedRead(e);
121                       return value;
122                     }
123 
124                     // immediately reuse invalid entries
125                     evictionQueue.remove(e);
126                     expirationQueue.remove(e);
127                     this.count = newCount; // write-volatile
128                   }
129                   break;
130                 }
131               }
132 
133               if (createNewEntry) {
134                 computingValueReference = new ComputingValueReference<K, V>(computingFunction);
135 
136                 if (e == null) {
137                   e = newEntry(key, hash, first);
138                   e.setValueReference(computingValueReference);
139                   table.set(index, e);
140                 } else {
141                   e.setValueReference(computingValueReference);
142                 }
143               }
144             } finally {
145               unlock();
146               postWriteCleanup();
147             }
148 
149             if (createNewEntry) {
150               // This thread solely created the entry.
151               return compute(key, hash, e, computingValueReference);
152             }
153           }
154 
155           // The entry already exists. Wait for the computation.
156           checkState(!Thread.holdsLock(e), "Recursive computation");
157           // don't consider expiration as we're concurrent with computation
158           V value = e.getValueReference().waitForValue();
159           if (value != null) {
160             recordRead(e);
161             return value;
162           }
163           // else computing thread will clearValue
164           continue outer;
165         }
166       } finally {
167         postReadCleanup();
168       }
169     }
170 
compute(K key, int hash, ReferenceEntry<K, V> e, ComputingValueReference<K, V> computingValueReference)171     V compute(K key, int hash, ReferenceEntry<K, V> e,
172         ComputingValueReference<K, V> computingValueReference)
173         throws ExecutionException {
174       V value = null;
175       long start = System.nanoTime();
176       long end = 0;
177       try {
178         // Synchronizes on the entry to allow failing fast when a recursive computation is
179         // detected. This is not fool-proof since the entry may be copied when the segment
180         // is written to.
181         synchronized (e) {
182           value = computingValueReference.compute(key, hash);
183           end = System.nanoTime();
184         }
185         if (value != null) {
186           // putIfAbsent
187           V oldValue = put(key, hash, value, true);
188           if (oldValue != null) {
189             // the computed value was already clobbered
190             enqueueNotification(key, hash, value, RemovalCause.REPLACED);
191           }
192         }
193         return value;
194       } finally {
195         if (end == 0) {
196           end = System.nanoTime();
197         }
198         if (value == null) {
199           clearValue(key, hash, computingValueReference);
200         }
201       }
202     }
203   }
204 
205   /**
206    * Used to provide computation exceptions to other threads.
207    */
208   private static final class ComputationExceptionReference<K, V> implements ValueReference<K, V> {
209     final Throwable t;
210 
ComputationExceptionReference(Throwable t)211     ComputationExceptionReference(Throwable t) {
212       this.t = t;
213     }
214 
215     @Override
get()216     public V get() {
217       return null;
218     }
219 
220     @Override
getEntry()221     public ReferenceEntry<K, V> getEntry() {
222       return null;
223     }
224 
225     @Override
copyFor( ReferenceQueue<V> queue, V value, ReferenceEntry<K, V> entry)226     public ValueReference<K, V> copyFor(
227         ReferenceQueue<V> queue, V value, ReferenceEntry<K, V> entry) {
228       return this;
229     }
230 
231     @Override
isComputingReference()232     public boolean isComputingReference() {
233       return false;
234     }
235 
236     @Override
waitForValue()237     public V waitForValue() throws ExecutionException {
238       throw new ExecutionException(t);
239     }
240 
241     @Override
clear(ValueReference<K, V> newValue)242     public void clear(ValueReference<K, V> newValue) {}
243   }
244 
245   /**
246    * Used to provide computation result to other threads.
247    */
248   private static final class ComputedReference<K, V> implements ValueReference<K, V> {
249     final V value;
250 
ComputedReference(@ullable V value)251     ComputedReference(@Nullable V value) {
252       this.value = value;
253     }
254 
255     @Override
get()256     public V get() {
257       return value;
258     }
259 
260     @Override
getEntry()261     public ReferenceEntry<K, V> getEntry() {
262       return null;
263     }
264 
265     @Override
copyFor( ReferenceQueue<V> queue, V value, ReferenceEntry<K, V> entry)266     public ValueReference<K, V> copyFor(
267         ReferenceQueue<V> queue, V value, ReferenceEntry<K, V> entry) {
268       return this;
269     }
270 
271     @Override
isComputingReference()272     public boolean isComputingReference() {
273       return false;
274     }
275 
276     @Override
waitForValue()277     public V waitForValue() {
278       return get();
279     }
280 
281     @Override
clear(ValueReference<K, V> newValue)282     public void clear(ValueReference<K, V> newValue) {}
283   }
284 
285   private static final class ComputingValueReference<K, V> implements ValueReference<K, V> {
286     final Function<? super K, ? extends V> computingFunction;
287 
288     @GuardedBy("ComputingValueReference.this") // writes
289     volatile ValueReference<K, V> computedReference = unset();
290 
ComputingValueReference(Function<? super K, ? extends V> computingFunction)291     public ComputingValueReference(Function<? super K, ? extends V> computingFunction) {
292       this.computingFunction = computingFunction;
293     }
294 
295     @Override
get()296     public V get() {
297       // All computation lookups go through waitForValue. This method thus is
298       // only used by put, to whom we always want to appear absent.
299       return null;
300     }
301 
302     @Override
getEntry()303     public ReferenceEntry<K, V> getEntry() {
304       return null;
305     }
306 
307     @Override
copyFor( ReferenceQueue<V> queue, @Nullable V value, ReferenceEntry<K, V> entry)308     public ValueReference<K, V> copyFor(
309         ReferenceQueue<V> queue, @Nullable V value, ReferenceEntry<K, V> entry) {
310       return this;
311     }
312 
313     @Override
isComputingReference()314     public boolean isComputingReference() {
315       return true;
316     }
317 
318     /**
319      * Waits for a computation to complete. Returns the result of the computation.
320      */
321     @Override
waitForValue()322     public V waitForValue() throws ExecutionException {
323       if (computedReference == UNSET) {
324         boolean interrupted = false;
325         try {
326           synchronized (this) {
327             while (computedReference == UNSET) {
328               try {
329                 wait();
330               } catch (InterruptedException ie) {
331                 interrupted = true;
332               }
333             }
334           }
335         } finally {
336           if (interrupted) {
337             Thread.currentThread().interrupt();
338           }
339         }
340       }
341       return computedReference.waitForValue();
342     }
343 
344     @Override
clear(ValueReference<K, V> newValue)345     public void clear(ValueReference<K, V> newValue) {
346       // The pending computation was clobbered by a manual write. Unblock all
347       // pending gets, and have them return the new value.
348       setValueReference(newValue);
349 
350       // TODO(fry): could also cancel computation if we had a thread handle
351     }
352 
compute(K key, int hash)353     V compute(K key, int hash) throws ExecutionException {
354       V value;
355       try {
356         value = computingFunction.apply(key);
357       } catch (Throwable t) {
358         setValueReference(new ComputationExceptionReference<K, V>(t));
359         throw new ExecutionException(t);
360       }
361 
362       setValueReference(new ComputedReference<K, V>(value));
363       return value;
364     }
365 
setValueReference(ValueReference<K, V> valueReference)366     void setValueReference(ValueReference<K, V> valueReference) {
367       synchronized (this) {
368         if (computedReference == UNSET) {
369           computedReference = valueReference;
370           notifyAll();
371         }
372       }
373     }
374   }
375 
376   // Serialization Support
377 
378   private static final long serialVersionUID = 4;
379 
380   @Override
writeReplace()381   Object writeReplace() {
382     return new ComputingSerializationProxy<K, V>(keyStrength, valueStrength, keyEquivalence,
383         valueEquivalence, expireAfterWriteNanos, expireAfterAccessNanos, maximumSize,
384         concurrencyLevel, removalListener, this, computingFunction);
385   }
386 
387   static final class ComputingSerializationProxy<K, V> extends AbstractSerializationProxy<K, V> {
388 
389     final Function<? super K, ? extends V> computingFunction;
390 
ComputingSerializationProxy(Strength keyStrength, Strength valueStrength, Equivalence<Object> keyEquivalence, Equivalence<Object> valueEquivalence, long expireAfterWriteNanos, long expireAfterAccessNanos, int maximumSize, int concurrencyLevel, RemovalListener<? super K, ? super V> removalListener, ConcurrentMap<K, V> delegate, Function<? super K, ? extends V> computingFunction)391     ComputingSerializationProxy(Strength keyStrength, Strength valueStrength,
392         Equivalence<Object> keyEquivalence, Equivalence<Object> valueEquivalence,
393         long expireAfterWriteNanos, long expireAfterAccessNanos, int maximumSize,
394         int concurrencyLevel, RemovalListener<? super K, ? super V> removalListener,
395         ConcurrentMap<K, V> delegate, Function<? super K, ? extends V> computingFunction) {
396       super(keyStrength, valueStrength, keyEquivalence, valueEquivalence, expireAfterWriteNanos,
397           expireAfterAccessNanos, maximumSize, concurrencyLevel, removalListener, delegate);
398       this.computingFunction = computingFunction;
399     }
400 
writeObject(ObjectOutputStream out)401     private void writeObject(ObjectOutputStream out) throws IOException {
402       out.defaultWriteObject();
403       writeMapTo(out);
404     }
405 
406     @SuppressWarnings("deprecation") // self-use
readObject(ObjectInputStream in)407     private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
408       in.defaultReadObject();
409       MapMaker mapMaker = readMapMaker(in);
410       delegate = mapMaker.makeComputingMap(computingFunction);
411       readEntries(in);
412     }
413 
readResolve()414     Object readResolve() {
415       return delegate;
416     }
417 
418     private static final long serialVersionUID = 4;
419   }
420 }
421