1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
17 #define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
18 
19 #include "tensorflow/core/kernels/conditional_accumulator_base.h"
20 
21 namespace tensorflow {
22 
23 /*
24  * TypedConditionalAccumulatorBase is a templated companion of
25  * ConditionalAccumulatorBase which allows for subclasses to use different
26  * types for the input gradients. (See ConditionalAccumulator and
27  * SparseConditionalAccumulator.)
28  *
29  * TypedConditionalAccumulatorBase defines virtual methods and implements
30  * methods which depend on the gradient type. These are mainly methods that are
31  * used for adding a new gradient to the accumulator.
32  */
33 template <typename GradientTensorType>
34 class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
35  public:
TypedConditionalAccumulatorBase(const DataType & dtype,const PartialTensorShape & shape,const string & name,const string & reduction_type)36   TypedConditionalAccumulatorBase(const DataType& dtype,
37                                   const PartialTensorShape& shape,
38                                   const string& name,
39                                   const string& reduction_type)
40       : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
41 
42   /**
43    * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
44    * successful (i.e., has its gradient applied) if its local_step >=
45    * current_global_step_ at the time the attempt is processed. Otherwise, if
46    * local_step < current_global_step_, the stale gradient is silently dropped.
47    *
48    * local_step: Time-step at which the gradient was computed.
49    * grad:       Gradient tensor to be added to the accumulator.
50    * ctx:        Context in which the op is executed.
51    */
TryApplyGrad(int64 local_step,OpKernelContext * ctx)52   void TryApplyGrad(int64 local_step, OpKernelContext* ctx) override {
53     {
54       mutex_lock l(mu_);
55       if (local_step >= current_global_step_) {
56         GradientTensorType* grad = nullptr;
57         bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad);
58         if (is_valid) {
59           if (counter_ > 0) {
60             AddToAccumGradFunction(ctx, grad);
61           } else {
62             AllocateAndAssignToAccumGradFunction(ctx, grad);
63           }
64           counter_++;
65         }
66         CleanUpGradTensor(grad);
67       }
68     }
69     FlushUnlocked();
70   }
71 
72  protected:
73   // Virtual methods to be implemented by sub-classes for different datatypes.
74   // Implements arithmetic operations specific to datatype.
75   virtual void AllocateAndAssignToAccumGradFunction(
76       OpKernelContext* ctx, GradientTensorType* grad) = 0;
77 
78   virtual void AddToAccumGradFunction(OpKernelContext* ctx,
79                                       GradientTensorType* grad) = 0;
80 
81   // Method for extracting and validating input provided in an OpKernelContext.
82   // Returns true if input was successfully retrieved and is valid.
83   // Gradient is returned via the GradientTensorType** tensor.
84   virtual bool GetAndValidateTensorInputForApplyGrad(
85       OpKernelContext* ctx, GradientTensorType** tensor)
86       EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
87 
88   // Method for cleaning up any memory allocated in
89   // GetAndValidateTensorInputForApplyGrad
90   virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0;
91 };
92 
93 }  // namespace tensorflow
94 
95 #endif  // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
96