1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/eager/gradients_util.h"
16 
17 #include <memory>
18 
19 #include "absl/types/span.h"
20 #include "tensorflow/c/eager/abstract_tensor_handle.h"
21 #include "tensorflow/c/eager/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api_unified_experimental.h"
23 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
24 #include "tensorflow/c/eager/gradients.h"
25 #include "tensorflow/c/eager/gradients_internal.h"
26 #include "tensorflow/c/experimental/ops/array_ops.h"
27 #include "tensorflow/c/experimental/ops/math_ops.h"
28 #include "tensorflow/c/experimental/ops/nn_ops.h"
29 #include "tensorflow/c/tf_status_helper.h"
30 #include "tensorflow/c/tf_tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
33 #include "tensorflow/core/platform/errors.h"
34 
35 namespace tensorflow {
36 namespace gradients {
37 
38 using namespace std;
39 
ScalarTensorHandleHelper(TFE_Context * ctx,float value,TFE_TensorHandle ** result)40 Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
41                                 TFE_TensorHandle** result) {
42   float data[] = {value};
43   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
44       TF_NewStatus(), TF_DeleteStatus);
45   TF_Tensor* t =
46       TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
47   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
48   TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
49   *result = th;
50   TF_DeleteTensor(t);
51   return StatusFromTF_Status(status.get());
52 }
53 
TensorHandleWithDimsFloatHelper(TFE_Context * ctx,float data[],int64_t dims[],int num_dims,TFE_TensorHandle ** result)54 Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
55                                        int64_t dims[], int num_dims,
56                                        TFE_TensorHandle** result) {
57   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
58       TF_NewStatus(), TF_DeleteStatus);
59   TF_Tensor* t =
60       TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
61   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
62   TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
63   *result = th;
64   TF_DeleteTensor(t);
65   return StatusFromTF_Status(status.get());
66 }
67 
TensorHandleWithDimsIntHelper(TFE_Context * ctx,int data[],int64_t dims[],int num_dims,TFE_TensorHandle ** result)68 Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
69                                      int64_t dims[], int num_dims,
70                                      TFE_TensorHandle** result) {
71   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
72       TF_NewStatus(), TF_DeleteStatus);
73   TF_Tensor* t =
74       TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
75   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
76   TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
77   *result = th;
78   TF_DeleteTensor(t);
79   return StatusFromTF_Status(status.get());
80 }
81 
82 // Get a scalar TensorHandle with given value
ScalarTensorHandle(AbstractContext * ctx,float value,AbstractTensorHandle ** tensor)83 Status ScalarTensorHandle(AbstractContext* ctx, float value,
84                           AbstractTensorHandle** tensor) {
85   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
86       TF_NewStatus(), TF_DeleteStatus);
87   TFE_Context* eager_ctx =
88       TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
89   TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
90   TFE_TensorHandle* input_eager;
91   TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
92   *tensor =
93       unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
94   return StatusFromTF_Status(status.get());
95 }
96 
97 // Get a TensorHandle with given float values and dimensions
TensorHandleWithDimsFloat(AbstractContext * ctx,float data[],int64_t dims[],int num_dims,AbstractTensorHandle ** tensor)98 Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
99                                  int64_t dims[], int num_dims,
100                                  AbstractTensorHandle** tensor) {
101   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
102       TF_NewStatus(), TF_DeleteStatus);
103   TFE_Context* eager_ctx =
104       TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
105   TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
106   TFE_TensorHandle* input_eager;
107   TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
108                                                      num_dims, &input_eager));
109   *tensor =
110       unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
111   return StatusFromTF_Status(status.get());
112 }
113 
114 // Get a TensorHandle with given int values and dimensions
TensorHandleWithDimsInt(AbstractContext * ctx,int data[],int64_t dims[],int num_dims,AbstractTensorHandle ** tensor)115 Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
116                                int num_dims, AbstractTensorHandle** tensor) {
117   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
118       TF_NewStatus(), TF_DeleteStatus);
119   TFE_Context* eager_ctx =
120       TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
121   TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
122   TFE_TensorHandle* input_eager;
123   TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
124                                                    num_dims, &input_eager));
125   *tensor =
126       unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
127   return StatusFromTF_Status(status.get());
128 }
129 
GetValue(AbstractTensorHandle * t,TF_Tensor ** result_tensor)130 Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
131   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
132       TF_NewStatus(), TF_DeleteStatus);
133   TFE_TensorHandle* result_t =
134       TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
135   TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
136   *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
137   return StatusFromTF_Status(status.get());
138 }
139 
GetTensorHandleUtilFloat(AbstractContext * ctx,float vals[],int64_t dims[],int num_dims)140 AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
141                                                  float vals[], int64_t dims[],
142                                                  int num_dims) {
143   AbstractTensorHandlePtr A;
144   AbstractTensorHandle* a_raw = nullptr;
145   Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
146   if (s.ok()) {
147     A.reset(a_raw);
148   }
149   return A;
150 }
151 
GetTensorHandleUtilInt(AbstractContext * ctx,int vals[],int64_t dims[],int num_dims)152 AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
153                                                int64_t dims[], int num_dims) {
154   AbstractTensorHandlePtr A;
155   AbstractTensorHandle* a_raw = nullptr;
156   Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
157   if (s.ok()) {
158     A.reset(a_raw);
159   }
160   return A;
161 }
162 
GetScalarTensorHandleUtil(AbstractContext * ctx,float val)163 AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
164                                                   float val) {
165   AbstractTensorHandlePtr y;
166   AbstractTensorHandle* y_raw = nullptr;
167   Status s = ScalarTensorHandle(ctx, val, &y_raw);
168   if (s.ok()) {
169     y.reset(y_raw);
170   }
171   return y;
172 }
173 
// Applies one step of vanilla gradient descent to `weights` in place:
// each weights[i] is replaced by a new handle computed as
// weights[i] + (-learning_rate) * grads[i]. Returns the status of the
// first failing op, or OK.
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
                     vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate) {
  /* Update weights one by one using gradient update rule:
   *
   *    w -= lr*grad[w]
   *
   *  NOTE: assuming learning rate is positive
   */

  // NOTE(review): the intermediate handles created here (the negated learning
  // rate, each dW, and the old weight handles that get overwritten) are not
  // Unref'd in this function — presumably ownership/cleanup is handled by the
  // caller or the surrounding runtime; confirm before reusing elsewhere.
  int num_grads = grads.size();
  // Single-slot scratch output reused by every op call below.
  vector<AbstractTensorHandle*> temp_outputs(1);
  std::string update_str;

  // Negate learning rate for gradient descent
  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
                              absl::MakeSpan(temp_outputs),
                              "neg_lr"));  // Compute -lr
  // Rebind the local parameter to the negated rate; callers are unaffected.
  learning_rate = temp_outputs[0];

  for (int i = 0; i < num_grads; i++) {
    // Compute dW = -lr * grad(w[i])
    update_str = "update_mul_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    AbstractTensorHandle* dW = temp_outputs[0];

    // Compute temp = weights[i] + dW
    update_str = "update_add_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    // Update the weights
    weights[i] = temp_outputs[0];
  }

  return Status::OK();
}
215 
BuildFunction(const char * fn_name)216 AbstractContext* BuildFunction(const char* fn_name) {
217   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
218       TF_NewStatus(), TF_DeleteStatus);
219   TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
220   return unwrap(graph_ctx);
221 }
222 
CreateParamsForInputs(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,vector<AbstractTensorHandle * > * params)223 Status CreateParamsForInputs(AbstractContext* ctx,
224                              absl::Span<AbstractTensorHandle* const> inputs,
225                              vector<AbstractTensorHandle*>* params) {
226   tracing::TracingTensorHandle* handle = nullptr;
227   for (auto input : inputs) {
228     PartialTensorShape shape;
229     TF_RETURN_IF_ERROR(input->Shape(&shape));
230     TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
231         input->DataType(), shape, &handle));
232     params->emplace_back(handle);
233   }
234   return Status::OK();
235 }
236 
// Runs `model` on `inputs`, writing results into `outputs`. When
// `use_function` is false the model runs directly in `ctx`. When true, the
// model is first traced into a FunctionDef ("test_fn"), registered on `ctx`,
// executed as a function op, and finally removed from `ctx`.
Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    // Returning null tensors from a tf.function is not supported, so we keep
    // track of indices in the model's outputs are nullptr in this set.
    // The FunctionDef only outputs the non-null tensors. We later pad the
    // function op outputs to have nullptrs at the `null_indices`.
    absl::flat_hash_set<int> null_indices;
    {
      // Scope bounds the lifetime of the tracing context: it is torn down as
      // soon as the function has been finalized and registered.
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      // Create one graph placeholder per caller input.
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      vector<AbstractTensorHandle*> model_outputs;
      model_outputs.resize(outputs.size());
      // Trace the model's ops into the function body.
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(model_outputs), registry));
      for (auto func_input : func_inputs) {
        func_input->Unref();
      }
      AbstractFunction* func = nullptr;
      OutputList output_list;
      output_list.expected_num_outputs = 0;
      output_list.outputs.reserve(outputs.size());
      // Keep only non-null traced outputs; remember where the nulls were so
      // the op outputs can be re-padded below.
      for (int i = 0; i < model_outputs.size(); i++) {
        if (model_outputs[i]) {
          output_list.outputs.emplace_back(model_outputs[i]);
          output_list.expected_num_outputs += 1;
        } else {
          null_indices.insert(i);
        }
      }
      // NOTE(review): dyn_cast result is dereferenced unchecked here — safe
      // only because func_ctx comes from BuildFunction (a tracing context).
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      for (auto output : output_list.outputs) {
        output->Unref();
      }
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    // Execute the just-registered function as a single op on `ctx`.
    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    // The op only produces the non-null outputs.
    int retvals = outputs.size() - null_indices.size();
    vector<AbstractTensorHandle*> fn_outputs(retvals);
    TF_RETURN_IF_ERROR(fn_op->Execute(
        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
        &retvals));
    // Re-pad: copy op outputs into their original positions, leaving the
    // caller's entries untouched (nullptr slots) at the null indices.
    int skipped_indices = 0;
    for (int i = 0; i < outputs.size(); i++) {
      if (!null_indices.contains(i)) {
        outputs[i] = fn_outputs[i - skipped_indices];
      } else {
        skipped_indices += 1;
      }
    }
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}
307 
BuildImmediateExecutionContext(bool use_tfrt,AbstractContext ** ctx)308 Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
309   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
310       TF_NewStatus(), TF_DeleteStatus);
311   TFE_ContextOptions* opts = TFE_NewContextOptions();
312   TFE_ContextOptionsSetTfrt(opts, use_tfrt);
313   *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
314   TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
315   TFE_DeleteContextOptions(opts);
316   return Status::OK();
317 }
318 
319 }  // namespace gradients
320 }  // namespace tensorflow
321