1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.base.Preconditions.checkState;
20 
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.channel.DefaultChannelPromise;
24 import io.netty.util.concurrent.EventExecutor;
25 import java.util.ArrayList;
26 import java.util.List;
27 
28 /**
29  * Promise used when flushing the {@code pendingUnprotectedWrites} queue. It manages the many-to
30  * many relationship between pending unprotected messages and the individual writes. Each protected
31  * frame will be written using the same instance of this promise and it will accumulate the results.
32  * Once all frames have been successfully written (or any failed), all of the promises for the
33  * pending unprotected writes are notified.
34  *
35  * <p>NOTE: this code is based on code in Netty's {@code Http2CodecUtil}.
36  */
37 final class ProtectedPromise extends DefaultChannelPromise {
38   private final List<ChannelPromise> unprotectedPromises;
39   private int expectedCount;
40   private int successfulCount;
41   private int failureCount;
42   private boolean doneAllocating;
43 
ProtectedPromise(Channel channel, EventExecutor executor, int numUnprotectedPromises)44   ProtectedPromise(Channel channel, EventExecutor executor, int numUnprotectedPromises) {
45     super(channel, executor);
46     unprotectedPromises = new ArrayList<>(numUnprotectedPromises);
47   }
48 
49   /**
50    * Adds a promise for a pending unprotected write. This will be notified after all of the writes
51    * complete.
52    */
addUnprotectedPromise(ChannelPromise promise)53   void addUnprotectedPromise(ChannelPromise promise) {
54     unprotectedPromises.add(promise);
55   }
56 
57   /**
58    * Allocate a new promise for the write of a protected frame. This will be used to aggregate the
59    * overall success of the unprotected promises.
60    *
61    * @return {@code this} promise.
62    */
newPromise()63   ChannelPromise newPromise() {
64     checkState(!doneAllocating, "Done allocating. No more promises can be allocated.");
65     expectedCount++;
66     return this;
67   }
68 
69   /**
70    * Signify that no more {@link #newPromise()} allocations will be made. The aggregation can not be
71    * successful until this method is called.
72    *
73    * @return {@code this} promise.
74    */
doneAllocatingPromises()75   ChannelPromise doneAllocatingPromises() {
76     if (!doneAllocating) {
77       doneAllocating = true;
78       if (successfulCount == expectedCount) {
79         trySuccessInternal(null);
80         return super.setSuccess(null);
81       }
82     }
83     return this;
84   }
85 
86   @Override
tryFailure(Throwable cause)87   public boolean tryFailure(Throwable cause) {
88     if (awaitingPromises()) {
89       ++failureCount;
90       if (failureCount == 1) {
91         tryFailureInternal(cause);
92         return super.tryFailure(cause);
93       }
94       // TODO: We break the interface a bit here.
95       // Multiple failure events can be processed without issue because this is an aggregation.
96       return true;
97     }
98     return false;
99   }
100 
101   /**
102    * Fail this object if it has not already been failed.
103    *
104    * <p>This method will NOT throw an {@link IllegalStateException} if called multiple times because
105    * that may be expected.
106    */
107   @Override
setFailure(Throwable cause)108   public ChannelPromise setFailure(Throwable cause) {
109     tryFailure(cause);
110     return this;
111   }
112 
awaitingPromises()113   private boolean awaitingPromises() {
114     return successfulCount + failureCount < expectedCount;
115   }
116 
117   @Override
setSuccess(Void result)118   public ChannelPromise setSuccess(Void result) {
119     trySuccess(result);
120     return this;
121   }
122 
123   @Override
trySuccess(Void result)124   public boolean trySuccess(Void result) {
125     if (awaitingPromises()) {
126       ++successfulCount;
127       if (successfulCount == expectedCount && doneAllocating) {
128         trySuccessInternal(result);
129         return super.trySuccess(result);
130       }
131       // TODO: We break the interface a bit here.
132       // Multiple success events can be processed without issue because this is an aggregation.
133       return true;
134     }
135     return false;
136   }
137 
trySuccessInternal(Void result)138   private void trySuccessInternal(Void result) {
139     for (int i = 0; i < unprotectedPromises.size(); ++i) {
140       unprotectedPromises.get(i).trySuccess(result);
141     }
142   }
143 
tryFailureInternal(Throwable cause)144   private void tryFailureInternal(Throwable cause) {
145     for (int i = 0; i < unprotectedPromises.size(); ++i) {
146       unprotectedPromises.get(i).tryFailure(cause);
147     }
148   }
149 }
150