/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
#define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_

#include "tensorflow/core/kernels/conditional_accumulator_base.h"

namespace tensorflow {

/*
 * TypedConditionalAccumulatorBase is a templated companion of
 * ConditionalAccumulatorBase which allows for subclasses to use different
 * types for the input gradients. (See ConditionalAccumulator and
 * SparseConditionalAccumulator.)
 *
 * TypedConditionalAccumulatorBase defines virtual methods and implements
 * methods which depend on the gradient type. These are mainly methods that are
 * used for adding a new gradient to the accumulator.
 */
template <typename GradientTensorType>
class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
 public:
  // Forwards all configuration (dtype, expected shape, shared name, and
  // reduction type) straight to the untyped base class; this class adds no
  // state of its own.
  TypedConditionalAccumulatorBase(const DataType& dtype,
                                  const PartialTensorShape& shape,
                                  const string& name,
                                  const string& reduction_type)
      : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}

  /**
   * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
   * successful (i.e., has its gradient applied) if its local_step >=
   * current_global_step_ at the time the attempt is processed. Otherwise, if
   * local_step < current_global_step_, the stale gradient is silently dropped
   * (the gradient input is not even read from ctx in that case).
   *
   * local_step: Time-step at which the gradient was computed.
   * ctx: Context in which the op is executed; the gradient tensor itself is
   *      extracted from ctx by GetAndValidateTensorInputForApplyGrad.
   */
  void TryApplyGrad(int64 local_step, OpKernelContext* ctx) override {
    {
      // Extra brace scope: mutex_lock must be released *before* calling
      // FlushUnlocked() below, which (per its name) expects mu_ unlocked.
      mutex_lock l(mu_);
      if (local_step >= current_global_step_) {
        GradientTensorType* grad = nullptr;
        bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad);
        if (is_valid) {
          // counter_ > 0 means a gradient has already been accumulated this
          // round, so we add to it; otherwise this is the first gradient and
          // the accumulation buffer must be allocated/initialized.
          if (counter_ > 0) {
            AddToAccumGradFunction(ctx, grad);
          } else {
            AllocateAndAssignToAccumGradFunction(ctx, grad);
          }
          counter_++;
        }
        // Always clean up, even when validation failed; subclasses must
        // tolerate grad == nullptr here (it stays null if
        // GetAndValidateTensorInputForApplyGrad never assigned it).
        CleanUpGradTensor(grad);
      }
    }
    // Runs outside the lock scope above; presumably wakes pending takers —
    // see ConditionalAccumulatorBase for the exact contract.
    FlushUnlocked();
  }

 protected:
  // Virtual methods to be implemented by sub-classes for different datatypes.
  // Implements arithmetic operations specific to datatype.

  // Called for the first gradient of an accumulation round: allocate the
  // accumulation buffer and copy/assign `grad` into it.
  virtual void AllocateAndAssignToAccumGradFunction(
      OpKernelContext* ctx, GradientTensorType* grad) = 0;

  // Called for subsequent gradients: combine `grad` into the existing
  // accumulation buffer (per the configured reduction type).
  virtual void AddToAccumGradFunction(OpKernelContext* ctx,
                                      GradientTensorType* grad) = 0;

  // Method for extracting and validating input provided in an OpKernelContext.
  // Returns true if input was successfully retrieved and is valid.
  // Gradient is returned via the GradientTensorType** tensor.
  // Note: callers hold mu_ (see EXCLUSIVE_LOCKS_REQUIRED annotation).
  virtual bool GetAndValidateTensorInputForApplyGrad(
      OpKernelContext* ctx, GradientTensorType** tensor)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;

  // Method for cleaning up any memory allocated in
  // GetAndValidateTensorInputForApplyGrad. Must accept nullptr (the pointer
  // is null when validation failed before assignment).
  virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_