1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/eager/gradients_util.h"
16
17 #include <memory>
18
19 #include "absl/types/span.h"
20 #include "tensorflow/c/eager/abstract_tensor_handle.h"
21 #include "tensorflow/c/eager/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api_unified_experimental.h"
23 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
24 #include "tensorflow/c/eager/gradients.h"
25 #include "tensorflow/c/eager/gradients_internal.h"
26 #include "tensorflow/c/experimental/ops/array_ops.h"
27 #include "tensorflow/c/experimental/ops/math_ops.h"
28 #include "tensorflow/c/experimental/ops/nn_ops.h"
29 #include "tensorflow/c/tf_status_helper.h"
30 #include "tensorflow/c/tf_tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
33 #include "tensorflow/core/platform/errors.h"
34
35 namespace tensorflow {
36 namespace gradients {
37
38 using namespace std;
39
ScalarTensorHandleHelper(TFE_Context * ctx,float value,TFE_TensorHandle ** result)40 Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
41 TFE_TensorHandle** result) {
42 float data[] = {value};
43 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
44 TF_NewStatus(), TF_DeleteStatus);
45 TF_Tensor* t =
46 TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
47 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
48 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
49 *result = th;
50 TF_DeleteTensor(t);
51 return StatusFromTF_Status(status.get());
52 }
53
TensorHandleWithDimsFloatHelper(TFE_Context * ctx,float data[],int64_t dims[],int num_dims,TFE_TensorHandle ** result)54 Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
55 int64_t dims[], int num_dims,
56 TFE_TensorHandle** result) {
57 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
58 TF_NewStatus(), TF_DeleteStatus);
59 TF_Tensor* t =
60 TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
61 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
62 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
63 *result = th;
64 TF_DeleteTensor(t);
65 return StatusFromTF_Status(status.get());
66 }
67
TensorHandleWithDimsIntHelper(TFE_Context * ctx,int data[],int64_t dims[],int num_dims,TFE_TensorHandle ** result)68 Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
69 int64_t dims[], int num_dims,
70 TFE_TensorHandle** result) {
71 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
72 TF_NewStatus(), TF_DeleteStatus);
73 TF_Tensor* t =
74 TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
75 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
76 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
77 *result = th;
78 TF_DeleteTensor(t);
79 return StatusFromTF_Status(status.get());
80 }
81
82 // Get a scalar TensorHandle with given value
ScalarTensorHandle(AbstractContext * ctx,float value,AbstractTensorHandle ** tensor)83 Status ScalarTensorHandle(AbstractContext* ctx, float value,
84 AbstractTensorHandle** tensor) {
85 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
86 TF_NewStatus(), TF_DeleteStatus);
87 TFE_Context* eager_ctx =
88 TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
89 TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
90 TFE_TensorHandle* input_eager;
91 TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
92 *tensor =
93 unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
94 return StatusFromTF_Status(status.get());
95 }
96
97 // Get a TensorHandle with given float values and dimensions
TensorHandleWithDimsFloat(AbstractContext * ctx,float data[],int64_t dims[],int num_dims,AbstractTensorHandle ** tensor)98 Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
99 int64_t dims[], int num_dims,
100 AbstractTensorHandle** tensor) {
101 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
102 TF_NewStatus(), TF_DeleteStatus);
103 TFE_Context* eager_ctx =
104 TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
105 TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
106 TFE_TensorHandle* input_eager;
107 TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
108 num_dims, &input_eager));
109 *tensor =
110 unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
111 return StatusFromTF_Status(status.get());
112 }
113
114 // Get a TensorHandle with given int values and dimensions
TensorHandleWithDimsInt(AbstractContext * ctx,int data[],int64_t dims[],int num_dims,AbstractTensorHandle ** tensor)115 Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
116 int num_dims, AbstractTensorHandle** tensor) {
117 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
118 TF_NewStatus(), TF_DeleteStatus);
119 TFE_Context* eager_ctx =
120 TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
121 TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
122 TFE_TensorHandle* input_eager;
123 TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
124 num_dims, &input_eager));
125 *tensor =
126 unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
127 return StatusFromTF_Status(status.get());
128 }
129
GetValue(AbstractTensorHandle * t,TF_Tensor ** result_tensor)130 Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
131 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
132 TF_NewStatus(), TF_DeleteStatus);
133 TFE_TensorHandle* result_t =
134 TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
135 TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
136 *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
137 return StatusFromTF_Status(status.get());
138 }
139
GetTensorHandleUtilFloat(AbstractContext * ctx,float vals[],int64_t dims[],int num_dims)140 AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
141 float vals[], int64_t dims[],
142 int num_dims) {
143 AbstractTensorHandlePtr A;
144 AbstractTensorHandle* a_raw = nullptr;
145 Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
146 if (s.ok()) {
147 A.reset(a_raw);
148 }
149 return A;
150 }
151
GetTensorHandleUtilInt(AbstractContext * ctx,int vals[],int64_t dims[],int num_dims)152 AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
153 int64_t dims[], int num_dims) {
154 AbstractTensorHandlePtr A;
155 AbstractTensorHandle* a_raw = nullptr;
156 Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
157 if (s.ok()) {
158 A.reset(a_raw);
159 }
160 return A;
161 }
162
GetScalarTensorHandleUtil(AbstractContext * ctx,float val)163 AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
164 float val) {
165 AbstractTensorHandlePtr y;
166 AbstractTensorHandle* y_raw = nullptr;
167 Status s = ScalarTensorHandle(ctx, val, &y_raw);
168 if (s.ok()) {
169 y.reset(y_raw);
170 }
171 return y;
172 }
173
// Applies one gradient-descent step to every weight, in place:
//   w[i] -= lr * grads[i]
// Implemented by building ops (Neg, Mul, Add) in `ctx`, so it works both
// eagerly and while tracing a function. `weights` entries are replaced with
// the updated handles.
// NOTE(review): the intermediate handles (negated lr, dW, and the old weight
// handles) are not Unref'd here — presumably callers/the context manage their
// lifetime; confirm before reusing this outside tests.
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
                     vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate) {
  /* Update weights one by one using gradient update rule:
   *
   * w -= lr*grad[w]
   *
   * NOTE: assuming learning rate is positive
   */

  int num_grads = grads.size();
  // Single-slot output buffer reused by every op built below.
  vector<AbstractTensorHandle*> temp_outputs(1);
  std::string update_str;

  // Negate learning rate for gradient descent
  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
                              absl::MakeSpan(temp_outputs),
                              "neg_lr"));  // Compute -lr
  // Rebind the local parameter to -lr so the loop multiplies by the
  // negated rate; the caller's handle is not modified.
  learning_rate = temp_outputs[0];

  for (int i = 0; i < num_grads; i++) {
    // Compute dW = -lr * grad(w[i])
    // Op names get a per-iteration suffix so traced graphs stay unique.
    update_str = "update_mul_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    AbstractTensorHandle* dW = temp_outputs[0];

    // Compute temp = weights[i] + dW
    update_str = "update_add_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    // Update the weights
    weights[i] = temp_outputs[0];
  }

  return Status::OK();
}
215
// Creates a fresh tracing (graph/function-building) context named `fn_name`
// and returns it unwrapped as an AbstractContext the caller owns.
// NOTE(review): the TF_Status is created but never inspected — if
// TF_CreateFunction fails, this returns the unwrapped result (possibly null)
// with no error report; the signature offers no way to surface the Status.
AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
  return unwrap(graph_ctx);
}
222
CreateParamsForInputs(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,vector<AbstractTensorHandle * > * params)223 Status CreateParamsForInputs(AbstractContext* ctx,
224 absl::Span<AbstractTensorHandle* const> inputs,
225 vector<AbstractTensorHandle*>* params) {
226 tracing::TracingTensorHandle* handle = nullptr;
227 for (auto input : inputs) {
228 PartialTensorShape shape;
229 TF_RETURN_IF_ERROR(input->Shape(&shape));
230 TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
231 input->DataType(), shape, &handle));
232 params->emplace_back(handle);
233 }
234 return Status::OK();
235 }
236
// Runs `model` on `inputs`, writing results into `outputs`. When
// `use_function` is false, the model runs directly in `ctx`. When true, the
// model is first traced into a FunctionDef, registered on `ctx`, executed as
// a single function op, and finally removed again.
Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    // Returning null tensors from a tf.function is not supported, so we keep
    // track of the indices in the model's outputs that are nullptr in this
    // set. The FunctionDef only outputs the non-null tensors. We later pad
    // the function op outputs to have nullptrs at the `null_indices`.
    absl::flat_hash_set<int> null_indices;
    {
      // Trace the model into a fresh function-building context.
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      // Placeholder parameters stand in for the real inputs during tracing.
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      vector<AbstractTensorHandle*> model_outputs;
      model_outputs.resize(outputs.size());
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(model_outputs), registry));
      // The traced graph holds its own references to the parameters.
      for (auto func_input : func_inputs) {
        func_input->Unref();
      }
      AbstractFunction* func = nullptr;
      OutputList output_list;
      output_list.expected_num_outputs = 0;
      output_list.outputs.reserve(outputs.size());
      // Keep only non-null outputs in the FunctionDef; remember where the
      // nulls were so we can pad the results later.
      for (int i = 0; i < model_outputs.size(); i++) {
        if (model_outputs[i]) {
          output_list.outputs.emplace_back(model_outputs[i]);
          output_list.expected_num_outputs += 1;
        } else {
          null_indices.insert(i);
        }
      }
      // Finalize the trace into an AbstractFunction.
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      for (auto output : output_list.outputs) {
        output->Unref();
      }
      // Make the traced function callable from the outer context.
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    // Execute the registered function as a single op on the real inputs.
    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    // The function only produces the non-null outputs.
    int retvals = outputs.size() - null_indices.size();
    vector<AbstractTensorHandle*> fn_outputs(retvals);
    TF_RETURN_IF_ERROR(fn_op->Execute(
        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
        &retvals));
    // Scatter the function outputs back, re-inserting nullptr at the
    // positions recorded in `null_indices`.
    int skipped_indices = 0;
    for (int i = 0; i < outputs.size(); i++) {
      if (!null_indices.contains(i)) {
        outputs[i] = fn_outputs[i - skipped_indices];
      } else {
        skipped_indices += 1;
      }
    }
    // Clean up: the function was only registered for this one call.
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}
307
BuildImmediateExecutionContext(bool use_tfrt,AbstractContext ** ctx)308 Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
309 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
310 TF_NewStatus(), TF_DeleteStatus);
311 TFE_ContextOptions* opts = TFE_NewContextOptions();
312 TFE_ContextOptionsSetTfrt(opts, use_tfrt);
313 *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
314 TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
315 TFE_DeleteContextOptions(opts);
316 return Status::OK();
317 }
318
319 } // namespace gradients
320 } // namespace tensorflow
321