/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/conditional_accumulator_base.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

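/**
 * Constructs a conditional accumulator for the given dtype, (possibly
 * partially specified) gradient shape, and reduction type. The gradient
 * counter and the local copy of the global step both start at zero.
 */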
ConditionalAccumulatorBase::ConditionalAccumulatorBase(
    const DataType& dtype, const PartialTensorShape& shape, const string& name,
    const string& reduction_type)
    : dtype_(dtype),
      shape_(shape),
      name_(name),
      reduction_type_(reduction_type) {
  counter_ = 0;
  current_global_step_ = 0;
}

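/**
 * Checks whether the given NodeDef is compatible with this accumulator.
 * Currently a stub that always returns OK; see the TODO below.
 */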
Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) {
  // TODO(xinghao@): implement the checks for the node definition
  return Status::OK();
}

/**
 * Sets the time step of the accumulator to match the global time step.
 * Logs a warning if the accumulator's time step is already larger than the
 * provided time step.
 */
Status ConditionalAccumulatorBase::SetGlobalStep(int64 new_global_step) {
  mutex_lock lock(mu_);
  if (new_global_step < current_global_step_) {
    LOG(WARNING) << "Attempt to set current_global_step_ to smaller value: "
                 << "current_global_step_ = " << current_global_step_
                 << " > " << new_global_step << " = new_global_step.";
  }
  current_global_step_ = new_global_step;
  return Status::OK();
}

/**
 * Logs an attempt to extract the average gradient, and tries to flush all
 * TakeGrad attempts.
 * A TakeGrad attempt is blocked until counter_ >= num_required, i.e., until
 * sufficient gradients have been accumulated.
 *
 * num_required: Number of gradients that need to be accumulated before the
 *               attempt is unblocked.
 * ctx:          Context in which the op is executed.
 * callback:     A callback to be executed after the attempt has been completed.
 */
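// Typical call site, as a sketch only (the surrounding async kernel code is
// illustrative, not part of this file):
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     // num_required would be read from the op's input tensor.
//     accumulator->TryTakeGrad(num_required, ctx, done);
//   }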
void ConditionalAccumulatorBase::TryTakeGrad(int num_required,
                                             OpKernelContext* ctx,
                                             DoneCallback callback) {
  if (num_required <= 0) {
    ctx->CtxFailureWithWarning(errors::InvalidArgument(
        "Argument num_required must be positive, but was ", num_required));
    callback();
  } else {
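    // Register a cancellation callback so that a blocked TakeGrad attempt can
    // be aborted if the op's step is cancelled.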
    CancellationManager* cm = ctx->cancellation_manager();
    CancellationToken token = cm->get_cancellation_token();
    bool already_cancelled;
    {
      mutex_lock l(mu_);
      already_cancelled = !cm->RegisterCallback(
          token, [this, cm, token]() { Cancel(cm, token); });
      if (!already_cancelled) {
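        // Queue the attempt. Its run callback is invoked under mu_ by
        // TryAttemptLocked and returns kComplete once enough gradients have
        // accumulated, or kNoProgress to stay blocked.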
        takegrad_attempts_.emplace_back(
            num_required, callback, ctx, cm, token,
            [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              if (counter_ >= attempt->elements_requested) {
                bool successful_take_grad = TakeGradLockedHelper(
                    attempt->context, attempt->done_callback);
                if (successful_take_grad) {
                  return kComplete;
                } else {
                  // Try again
                  return kNoProgress;
                }
              } else {
                return kNoProgress;
              }
            });
      }
    }
    if (!already_cancelled) {
      FlushUnlocked();
    } else {
      ctx->SetStatus(errors::Cancelled("TakeGrad operation was cancelled"));
      callback();
    }
  }
}

/**
 * Cancellation callback: marks the matching TakeGrad attempt as cancelled
 * and invokes its done callback with a Cancelled status.
 */
void ConditionalAccumulatorBase::Cancel(
    CancellationManager* cancellation_manager, CancellationToken token) {
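  // Move the attempt's done callback out while holding mu_, but invoke it
  // only after the lock is released, since the callback may re-enter this
  // class (e.g., via FlushUnlocked).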
  DoneCallback callback = nullptr;
  {
    mutex_lock lock(mu_);

    for (Attempt& attempt : takegrad_attempts_) {
      if (attempt.cancellation_manager == cancellation_manager &&
          attempt.cancellation_token == token) {
        if (!attempt.is_cancelled) {
          attempt.is_cancelled = true;
          attempt.context->SetStatus(
              errors::Cancelled("TakeGrad operation was cancelled"));
          std::swap(callback, attempt.done_callback);
        }
        break;
      }
    }
  }
  if (callback) {
    callback();
    FlushUnlocked();
  }
}

/**
 * Processes blocked TakeGrad attempts from the front of the queue while mu_
 * is held. Cancelled attempts are discarded; completed attempts are moved to
 * clean_up for finalization outside the lock. Returns true if any attempt
 * completed.
 */
bool ConditionalAccumulatorBase::TryAttemptLocked(
    std::vector<CleanUp>* clean_up) {
  bool progress = false;
  bool done = false;
  while (!done && !takegrad_attempts_.empty()) {
    if (takegrad_attempts_.front().is_cancelled) {
      VLOG(1) << "Skipping cancelled TakeGrad attempt";
      takegrad_attempts_.pop_front();
    } else {
      Attempt* cur_attempt = &takegrad_attempts_.front();
      switch (cur_attempt->run_callback(cur_attempt)) {
        case kNoProgress:
          done = true;
          break;
        case kComplete:
          progress = true;
          clean_up->emplace_back(std::move(cur_attempt->done_callback),
                                 cur_attempt->cancellation_token,
                                 cur_attempt->context->cancellation_manager());
          takegrad_attempts_.pop_front();
          break;
      }
    }
  }
  return progress;
}

/**
 * Tries to flush blocked TakeGrad attempts until no further progress can be
 * made, then runs the resulting cleanup actions outside the lock.
 */
void ConditionalAccumulatorBase::FlushUnlocked() {
  std::vector<CleanUp> clean_up;
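  // Hold a reference on this accumulator so that it outlives the callbacks
  // triggered below.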
  Ref();
  {
    mutex_lock lock(mu_);
    bool changed;
    do {
      changed = TryAttemptLocked(&clean_up);
    } while (changed);
  }
  Unref();
  for (const auto& to_clean : clean_up) {
    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
      // NOTE(mrry): We can safely ignore the return value of
      // DeregisterCallback because the mutex mu_ ensures that the
      // cleanup action only executes once.
      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
    }
    to_clean.finished();
  }
}

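/**
 * Completes a TakeGrad attempt: advances the global step, averages the
 * accumulated gradient if the reduction type is "MEAN", writes the result to
 * the op's output, and resets the gradient counter on success. Must be called
 * with mu_ held. Returns whether the output was set successfully.
 */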
bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
                                                      DoneCallback callback) {
  // At this point, the take condition (counter_ >= elements_requested) has
  // been met.

  // Implicitly increment global_step
  current_global_step_++;

  // Average the accumulated gradient
  if (reduction_type_ == "MEAN") {
    DivideAccumGradByCounter(ctx);
  }

  // Set output for accumulated gradient tensor
  bool successful_set_output = SetOutput(ctx);

  // Reset counter
  if (successful_set_output) counter_ = 0;

  return successful_set_output;
}

}  // namespace tensorflow