1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
17 #define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
18 
19 #include <deque>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/numeric_op.h"
23 
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 
27 namespace tensorflow {
28 
29 /**
30  * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation
31  * object for adding gradients.
32  * The two main methods of this class are TryApplyGrad and TryTakeGrad.
33  *
34  * TryApplyGrad tries add a gradient to the accumulator. The attempt is
35  * successful if local_step >= global_step, i.e., if the gradient is not stale,
36  * having been computed using up-to-date information. Otherwise, the gradient is
37  * silently dropped.
38  *
39  * TryTakeGrad logs an attempt to read the average gradient. The attempt is
40  * blocked until the number of gradients accumulated (via TryApplyGrad) is equal
41  * or exceeds the number requested by TryTakeGrad.
42  * Once this condition is satisfied, the following actions are taken:
43  * (1) the value of the average gradient is returned
44  * (2) the count of accumulated gradients is reset to 0
45  * (3) the internal global_step value (current_global_step_) is incremented by 1
46  */
47 class ConditionalAccumulatorBase : public ResourceBase {
48  public:
49   // Args:
50   //   dtype: The datatype of the gradients to be accumulated.
51   //   shape: The shape of the accumulated gradients.
52   //   name:  A name to use for the ConditionalAccumulator.
53   ConditionalAccumulatorBase(const DataType& dtype,
54                              const PartialTensorShape& shape,
55                              const string& name, const string& reduction_type);
56 
57   typedef AsyncOpKernel::DoneCallback DoneCallback;
58 
59   virtual void TryApplyGrad(int64 local_step, OpKernelContext* ctx) = 0;
60   void TryTakeGrad(int num_required, OpKernelContext* ctx,
61                    DoneCallback callback);
62 
63   // Accessor methods
num_accumulated()64   uint32 num_accumulated() {
65     mutex_lock lock(mu_);
66     return counter_;
67   }
68 
dtype()69   const DataType& dtype() const { return dtype_; }
70 
DebugString()71   string DebugString() const override { return "A conditional accumulator"; }
72 
73   // SetGlobalStep is a modifier method for current_global_step.
74   // It returns an InvalidArgument error if the new_global_step is less than
75   // current_global_step.
76   Status SetGlobalStep(int64 new_global_step);
77 
78   Status MatchesNodeDef(const NodeDef& node_def);
79 
80  protected:
81   // Virtual methods to be implemented by sub-classes for different datatypes.
82   // Implements arithmetic operations specific to datatype.
83   virtual void DivideAccumGradByCounter(OpKernelContext* ctx)
84       EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
85   virtual bool SetOutput(OpKernelContext* ctx) = 0;
86 
87   enum RunResult { kNoProgress, kComplete };
88 
89   // Helper struct holding information about a TakeGrad attempt
90   struct Attempt;
91   typedef std::function<RunResult(Attempt*)> RunCallback;
92   struct Attempt {
93     int elements_requested;
94     DoneCallback done_callback;  // must be run outside mu_
95     OpKernelContext* context;
96     CancellationManager* cancellation_manager;  // not owned
97     CancellationToken cancellation_token;
98     RunCallback run_callback;  // must be run while holding mu_
99     bool is_cancelled;
100 
AttemptAttempt101     Attempt(int elements_requested, DoneCallback done_callback,
102             OpKernelContext* context, CancellationManager* cancellation_manager,
103             CancellationToken cancellation_token, RunCallback run_callback)
104         : elements_requested(elements_requested),
105           done_callback(std::move(done_callback)),
106           context(context),
107           cancellation_manager(cancellation_manager),
108           cancellation_token(cancellation_token),
109           run_callback(std::move(run_callback)),
110           is_cancelled(false) {}
111   };
112 
113   // Helper struct for deregistration of a cancellation token and executing a
114   // DoneCallback after a TakeGrad attempt is complete.
115   struct CleanUp {
CleanUpCleanUp116     CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
117         : finished(f), to_deregister(ct), cm(cm) {}
118     DoneCallback finished;
119     CancellationToken to_deregister;
120     CancellationManager* cm;
121   };
122 
123   // Fields
124 
125   const DataType dtype_;
126   const PartialTensorShape shape_;
127   const string name_;
128   const string reduction_type_;
129   mutex mu_;
130   int counter_ GUARDED_BY(mu_);
131   int64 current_global_step_ GUARDED_BY(mu_);
132 
133   std::deque<Attempt> takegrad_attempts_ GUARDED_BY(mu_);
134 
135   // Methods
136 
137   // Helper function for creating cancellation callback
138   void Cancel(CancellationManager* cancellation_manager,
139               CancellationToken token);
140 
141   // Helper functions to process TakeGrad attempts.
142   // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad
143   // calls to try to clear the TakeGrad attempts. This in turn calls
144   // TryAttemptLocked, which then executes the RunCallback of the logged
145   // attempts.
146   // Both functions are modeled after core/kernels/queue_base.
147   // Note: ApplyGrad attempts never block -- unlike in a queue with limited
148   //       capacity, we can always add the newest gradient to our accumulator
149   //       (if it is not stale) or drop it silently (if it is stale).
150   void FlushUnlocked();
151   bool TryAttemptLocked(std::vector<CleanUp>* clean_up)
152       EXCLUSIVE_LOCKS_REQUIRED(mu_);
153 
154   // Helper methods
155   //  void DeepCopy(Tensor* dst);
156   bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback)
157       EXCLUSIVE_LOCKS_REQUIRED(mu_);
158 };
159 
/*
 * Boolean-returning variants of the OP_REQUIRES / OP_REQUIRES_OK convenience
 * macros from core/framework/op_kernel.h. When the check fails, they record
 * the failure on CTX and then execute `return false;` in the enclosing
 * function. A function returning bool can therefore use them, and its caller
 * can tell that a failure occurred.
 *
 * OP_REQUIRES_BOOLEAN evaluates EXP exactly once; STATUS is evaluated only
 * when EXP is false.
 */
#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS)            \
  do {                                                  \
    /* Fast path: condition holds, leave the do/while. */ \
    if (TF_PREDICT_TRUE(EXP)) break;                    \
    (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
    return false;                                       \
  } while (0)

/*
 * OP_REQUIRES_OK_BOOLEAN evaluates STATUS once. If the resulting Status is not
 * OK, it reports it via CtxFailureWithWarning and returns false.
 */
#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS)                   \
  do {                                                       \
    ::tensorflow::Status _s(STATUS);                         \
    if (TF_PREDICT_TRUE(_s.ok())) break;                     \
    (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);    \
    return false;                                            \
  } while (0)
181 
182 /*
183  * Convenience classes for helping to convert between numeric types.
184  * The specialization for Eigen::half here simplifies specialization of
185  * ConditionalAccumulator classes later.
186  */
187 template <typename T, typename U>
188 class TypeConverter {
189  public:
ConvertUToT(U c)190   static T ConvertUToT(U c) { return c; /* implicit conversion */ }
191 };
192 
193 template <typename U>
194 class TypeConverter<Eigen::half, U> {
195  public:
ConvertUToT(U c)196   static Eigen::half ConvertUToT(U c) {
197     return Eigen::half_impl::float_to_half_rtne(c);
198   }
199 };
200 
201 }  // namespace tensorflow
202 
203 #endif  // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
204