1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/remote_fused_graph_ops.cc. 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" 20 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" 21 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 22 #include "tensorflow/core/lib/core/status.h" 23 #include "tensorflow/core/platform/macros.h" 24 #include "tensorflow/core/platform/types.h" 25 26 namespace tensorflow { 27 class RemoteFusedGraphExecuteOp : public OpKernel { 28 public: RemoteFusedGraphExecuteOp(OpKernelConstruction * const ctx)29 explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx) 30 : OpKernel(ctx), execute_info_() { 31 string serialized_proto; 32 OP_REQUIRES_OK( 33 ctx, ctx->GetAttr(RemoteFusedGraphExecuteUtils:: 34 ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, 35 &serialized_proto)); 36 OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_)); 37 OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_)); 38 execute_info_.ParseFromString(serialized_proto); 39 if (!execute_info_.executor_name().empty()) { 40 const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func = 41 RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc( 42 execute_info_.executor_name()); 43 if (build_func != nullptr) { 44 TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_)); 45 CHECK(remote_fused_graph_executor_->IsEnabled()); 46 } else { 47 LOG(ERROR) << "Executor not found for " 48 << execute_info_.executor_name(); 49 } 50 } 51 52 if (remote_fused_graph_executor_) { 53 // 1. Initialize remote processor 54 remote_fused_graph_executor_->Init(execute_info_); 55 // Explicitly clear serialized executor parameter after initialization 56 // to release unnecessary memory. 57 execute_info_.clear_serialized_executor_parameters(); 58 59 // 2. Setup graph in remote processor 60 remote_fused_graph_executor_->SetupGraph(); 61 } 62 } 63 ~RemoteFusedGraphExecuteOp()64 ~RemoteFusedGraphExecuteOp() final { 65 if (remote_fused_graph_executor_) { 66 // 6. Teardown graph in remote processor 67 remote_fused_graph_executor_->TeardownGraph(); 68 69 // 7. Finalize remote processor 70 remote_fused_graph_executor_->Finalize(); 71 } 72 } 73 Compute(OpKernelContext * const ctx)74 void Compute(OpKernelContext* const ctx) final { 75 CHECK(ctx != nullptr); 76 const int input_count = ctx->num_inputs(); 77 const int graph_input_count = execute_info_.graph_input_node_name_size(); 78 CHECK(input_count == graph_input_count && 79 input_count == input_types_.size()) 80 << "input_count = " << input_count 81 << ", gt input count = " << execute_info_.graph_input_node_name_size() 82 << ", type count = " << input_types_.size(); 83 84 // 3. Send first data type inputs into remote processor 85 for (int i = 0; i < graph_input_count; ++i) { 86 const Tensor& input_tensor = ctx->input(i); 87 const string& input_node_name = execute_info_.graph_input_node_name(i); 88 if (remote_fused_graph_executor_) { 89 remote_fused_graph_executor_->FillInputNode(input_node_name, 90 input_tensor); 91 } 92 } 93 94 // 4. Execute graph in remote processor 95 if (remote_fused_graph_executor_) { 96 remote_fused_graph_executor_->ExecuteGraph(); 97 } 98 99 // 5. Load outputs from remote processor 100 const int output_count = ctx->num_outputs(); 101 CHECK(output_count == execute_info_.graph_output_node_name_size() && 102 output_count == output_types_.size()); 103 for (int i = 0; i < output_count; ++i) { 104 Tensor* output = nullptr; 105 const string& output_node_name = execute_info_.graph_output_node_name(i); 106 if (remote_fused_graph_executor_) { 107 remote_fused_graph_executor_->ReadOutputNode( 108 output_node_name, 109 [i, &ctx, &output](const TensorShape& shape) -> Tensor* { 110 TF_CHECK_OK(ctx->allocate_output(i, shape, &output)); 111 return output; 112 }); 113 } else { 114 // For compatibility purpose, returns an empty tensor with specified 115 // data type as output if no executor is used. 116 Tensor* output = nullptr; 117 TensorShape ts({}); 118 TF_CHECK_OK(ctx->allocate_output(i, ts, &output)); 119 } 120 } 121 } 122 IsExpensive()123 bool IsExpensive() final { return true; } 124 125 private: 126 RemoteFusedGraphExecuteInfo execute_info_; 127 std::unique_ptr<IRemoteFusedGraphExecutor> remote_fused_graph_executor_; 128 DataTypeVector input_types_; 129 DataTypeVector output_types_; 130 131 TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp); 132 }; 133 134 REGISTER_KERNEL_BUILDER(Name("RemoteFusedGraphExecute").Device(DEVICE_CPU), 135 RemoteFusedGraphExecuteOp); 136 137 } // namespace tensorflow 138