/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_

#include <deque>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"

namespace tensorflow {

/**
 * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation
 * object for adding gradients.
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
 * successful if local_step >= global_step, i.e., if the gradient is not stale,
 * having been computed using up-to-date information. Otherwise, the gradient
 * is silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned
 * (2) the count of accumulated gradients is reset to 0
 * (3) the internal global_step value (current_global_step_) is incremented by 1
 */
class ConditionalAccumulatorBase : public ResourceBase {
 public:
  // Args:
  //   dtype: The datatype of the gradients to be accumulated.
  //   shape: The shape of the accumulated gradients.
  //   name: A name to use for the ConditionalAccumulator.
  //   reduction_type: The reduction method to apply to accumulated gradients.
  ConditionalAccumulatorBase(const DataType& dtype,
                             const PartialTensorShape& shape,
                             const string& name, const string& reduction_type);

  typedef AsyncOpKernel::DoneCallback DoneCallback;

  virtual void TryApplyGrad(int64 local_step, OpKernelContext* ctx) = 0;
  void TryTakeGrad(int num_required, OpKernelContext* ctx,
                   DoneCallback callback);

  // Accessor methods
  uint32 num_accumulated() {
    mutex_lock lock(mu_);
    return counter_;
  }

  const DataType& dtype() const { return dtype_; }

  string DebugString() const override { return "A conditional accumulator"; }

  // SetGlobalStep is a modifier method for current_global_step.
  // It returns an InvalidArgument error if the new_global_step is less than
  // current_global_step.
  Status SetGlobalStep(int64 new_global_step);

  Status MatchesNodeDef(const NodeDef& node_def);

 protected:
  // Virtual methods to be implemented by sub-classes for different datatypes.
  // They implement the arithmetic operations specific to the datatype.
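  //
  // As a rough illustration only (the concrete typed subclasses live in
  // conditional_accumulator.h and sparse_conditional_accumulator.h and may
  // differ), a subclass templated on a hypothetical type parameter T that
  // keeps its running sum in a hypothetical Tensor member `accum_grad_`
  // could override DivideAccumGradByCounter along these lines:
  //
  //   void DivideAccumGradByCounter(OpKernelContext* ctx) override
  //       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  //     const T divisor = TypeConverter<T, int>::ConvertUToT(counter_);
  //     auto flat = accum_grad_.flat<T>();
  //     for (int64 i = 0; i < flat.size(); ++i) flat(i) = flat(i) / divisor;
  //   }
  //
  // SetOutput would then move or copy the averaged tensor into the op's
  // output before the TakeGrad DoneCallback runs.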
  virtual void DivideAccumGradByCounter(OpKernelContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
  virtual bool SetOutput(OpKernelContext* ctx) = 0;

  enum RunResult { kNoProgress, kComplete };

  // Helper struct holding information about a TakeGrad attempt
  struct Attempt;
  typedef std::function<RunResult(Attempt*)> RunCallback;
  struct Attempt {
    int elements_requested;
    DoneCallback done_callback;  // must be run outside mu_
    OpKernelContext* context;
    CancellationManager* cancellation_manager;  // not owned
    CancellationToken cancellation_token;
    RunCallback run_callback;  // must be run while holding mu_
    bool is_cancelled;

    Attempt(int elements_requested, DoneCallback done_callback,
            OpKernelContext* context,
            CancellationManager* cancellation_manager,
            CancellationToken cancellation_token, RunCallback run_callback)
        : elements_requested(elements_requested),
          done_callback(std::move(done_callback)),
          context(context),
          cancellation_manager(cancellation_manager),
          cancellation_token(cancellation_token),
          run_callback(std::move(run_callback)),
          is_cancelled(false) {}
  };

  // Helper struct for deregistration of a cancellation token and executing a
  // DoneCallback after a TakeGrad attempt is complete.
  struct CleanUp {
    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
        : finished(f), to_deregister(ct), cm(cm) {}
    DoneCallback finished;
    CancellationToken to_deregister;
    CancellationManager* cm;
  };

  // Fields

  const DataType dtype_;
  const PartialTensorShape shape_;
  const string name_;
  const string reduction_type_;
  mutex mu_;
  int counter_ GUARDED_BY(mu_);
  int64 current_global_step_ GUARDED_BY(mu_);

  std::deque<Attempt> takegrad_attempts_ GUARDED_BY(mu_);

  // Methods

  // Helper function for creating a cancellation callback
  void Cancel(CancellationManager* cancellation_manager,
              CancellationToken token);

  // Helper functions to process TakeGrad attempts.
  // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad
  // call to try to clear the pending TakeGrad attempts. This in turn calls
  // TryAttemptLocked, which then executes the RunCallback of the logged
  // attempts.
  // Both functions are modeled after core/kernels/queue_base.
  // Note: ApplyGrad attempts never block -- unlike in a queue with limited
  // capacity, we can always add the newest gradient to our accumulator
  // (if it is not stale) or drop it silently (if it is stale).
  void FlushUnlocked();
  bool TryAttemptLocked(std::vector<CleanUp>* clean_up)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper methods
  // void DeepCopy(Tensor* dst);
  bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
};

/*
 * Modifications to convenience macros defined in core/framework/op_kernel.h.
 * The macros below return a boolean if the test fails, so that the calling
 * function can get an indication that a failure has occurred.
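 *
 * For example (an illustrative sketch only, not code from this file), a
 * bool-returning helper invoked from an op kernel could use the macro as:
 *
 *   bool CheckNonEmpty(OpKernelContext* ctx, const Tensor& grad) {
 *     OP_REQUIRES_BOOLEAN(ctx, grad.NumElements() > 0,
 *                         errors::InvalidArgument("empty gradient tensor"));
 *     return true;
 *   }
 *
 * On failure, the error is recorded on the OpKernelContext and the enclosing
 * function returns false; OP_REQUIRES_OK_BOOLEAN does the same for a non-OK
 * Status.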
 */
#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS)          \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      return false;                                    \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS)                 \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return false;                                         \
    }                                                       \
  } while (0)

/*
 * Convenience classes for helping to convert between numeric types.
 * The specialization for Eigen::half here simplifies specialization of
 * ConditionalAccumulator classes later.
 */
template <typename T, typename U>
class TypeConverter {
 public:
  static T ConvertUToT(U c) { return c; /* implicit conversion */ }
};

template <typename U>
class TypeConverter<Eigen::half, U> {
 public:
  static Eigen::half ConvertUToT(U c) {
    return Eigen::half_impl::float_to_half_rtne(c);
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_