1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/remote_fused_graph_ops.cc.
17 
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 class RemoteFusedGraphExecuteOp : public OpKernel {
28  public:
RemoteFusedGraphExecuteOp(OpKernelConstruction * const ctx)29   explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
30       : OpKernel(ctx), execute_info_() {
31     string serialized_proto;
32     OP_REQUIRES_OK(
33         ctx, ctx->GetAttr(RemoteFusedGraphExecuteUtils::
34                               ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
35                           &serialized_proto));
36     OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_));
37     OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_));
38     execute_info_.ParseFromString(serialized_proto);
39     if (!execute_info_.executor_name().empty()) {
40       const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func =
41           RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
42               execute_info_.executor_name());
43       if (build_func != nullptr) {
44         TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
45         CHECK(remote_fused_graph_executor_->IsEnabled());
46       } else {
47         LOG(ERROR) << "Executor not found for "
48                    << execute_info_.executor_name();
49       }
50     }
51 
52     if (remote_fused_graph_executor_) {
53       // 1. Initialize remote processor
54       remote_fused_graph_executor_->Init(execute_info_);
55       // Explicitly clear serialized executor parameter after initialization
56       // to release unnecessary memory.
57       execute_info_.clear_serialized_executor_parameters();
58 
59       // 2. Setup graph in remote processor
60       remote_fused_graph_executor_->SetupGraph();
61     }
62   }
63 
~RemoteFusedGraphExecuteOp()64   ~RemoteFusedGraphExecuteOp() final {
65     if (remote_fused_graph_executor_) {
66       // 6. Teardown graph in remote processor
67       remote_fused_graph_executor_->TeardownGraph();
68 
69       // 7. Finalize remote processor
70       remote_fused_graph_executor_->Finalize();
71     }
72   }
73 
Compute(OpKernelContext * const ctx)74   void Compute(OpKernelContext* const ctx) final {
75     CHECK(ctx != nullptr);
76     const int input_count = ctx->num_inputs();
77     const int graph_input_count = execute_info_.graph_input_node_name_size();
78     CHECK(input_count == graph_input_count &&
79           input_count == input_types_.size())
80         << "input_count = " << input_count
81         << ", gt input count = " << execute_info_.graph_input_node_name_size()
82         << ", type count = " << input_types_.size();
83 
84     // 3. Send first data type inputs into remote processor
85     for (int i = 0; i < graph_input_count; ++i) {
86       const Tensor& input_tensor = ctx->input(i);
87       const string& input_node_name = execute_info_.graph_input_node_name(i);
88       if (remote_fused_graph_executor_) {
89         remote_fused_graph_executor_->FillInputNode(input_node_name,
90                                                     input_tensor);
91       }
92     }
93 
94     // 4. Execute graph in remote processor
95     if (remote_fused_graph_executor_) {
96       remote_fused_graph_executor_->ExecuteGraph();
97     }
98 
99     // 5. Load outputs from remote processor
100     const int output_count = ctx->num_outputs();
101     CHECK(output_count == execute_info_.graph_output_node_name_size() &&
102           output_count == output_types_.size());
103     for (int i = 0; i < output_count; ++i) {
104       Tensor* output = nullptr;
105       const string& output_node_name = execute_info_.graph_output_node_name(i);
106       if (remote_fused_graph_executor_) {
107         remote_fused_graph_executor_->ReadOutputNode(
108             output_node_name,
109             [i, &ctx, &output](const TensorShape& shape) -> Tensor* {
110               TF_CHECK_OK(ctx->allocate_output(i, shape, &output));
111               return output;
112             });
113       } else {
114         // For compatibility purpose, returns an empty tensor with specified
115         // data type as output if no executor is used.
116         Tensor* output = nullptr;
117         TensorShape ts({});
118         TF_CHECK_OK(ctx->allocate_output(i, ts, &output));
119       }
120     }
121   }
122 
IsExpensive()123   bool IsExpensive() final { return true; }
124 
125  private:
126   RemoteFusedGraphExecuteInfo execute_info_;
127   std::unique_ptr<IRemoteFusedGraphExecutor> remote_fused_graph_executor_;
128   DataTypeVector input_types_;
129   DataTypeVector output_types_;
130 
131   TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp);
132 };
133 
134 REGISTER_KERNEL_BUILDER(Name("RemoteFusedGraphExecute").Device(DEVICE_CPU),
135                         RemoteFusedGraphExecuteOp);
136 
137 }  // namespace tensorflow
138