1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 16 #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 17 18 #include "tensorflow/c/eager/c_api.h" 19 20 #include <algorithm> 21 #include <cstddef> 22 #include <memory> 23 #include <string> 24 #include <thread> 25 #include <vector> 26 27 #include "tensorflow/c/c_api.h" 28 #include "tensorflow/c/c_api_internal.h" 29 #include "tensorflow/c/eager/runtime.h" 30 #include "tensorflow/core/common_runtime/device_factory.h" 31 #include "tensorflow/core/common_runtime/function.h" 32 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 33 #include "tensorflow/core/framework/rendezvous.h" 34 #include "tensorflow/core/lib/gtl/map_util.h" 35 #include "tensorflow/core/lib/gtl/stl_util.h" 36 #include "tensorflow/core/platform/mutex.h" 37 #include "tensorflow/core/platform/thread_annotations.h" 38 #include "tensorflow/core/public/version.h" 39 40 struct TFE_ContextOptions { 41 TF_SessionOptions session_options; 42 TFE_ContextDevicePlacementPolicy policy{ 43 TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; 44 }; 45 46 struct TFE_Context { TFE_ContextTFE_Context47 explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) 48 : policy(opts.policy), 49 session(s), 50 rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), 51 pflr(new tensorflow::ProcessFunctionLibraryRuntime( 52 session->device_mgr, opts.session_options.options.env, 53 TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} 54 55 const TFE_ContextDevicePlacementPolicy policy; 56 57 // Note: we cannot use C++11 thread_local here as there is no concept of a 58 // thread-local-object-local variable in C++11. 59 tensorflow::mutex policy_map_mu; 60 std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy> 61 thread_local_policies GUARDED_BY(policy_map_mu); 62 63 // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. 64 TF_Session* const session; 65 tensorflow::Rendezvous* const rendezvous; 66 67 tensorflow::mutex functions_mu; GUARDED_BYTFE_Context68 tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ 69 tensorflow::OpRegistry::Global(), {}}; 70 71 // One FunctionLibraryRuntime per device. 72 // func_libs[i] is the FunctionLibraryRuntime corresponding to 73 // session->devices[i]. 74 const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr; 75 76 tensorflow::mutex cache_mu; 77 std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*, 78 tensorflow::Fprint128Hasher> 79 kernel_cache GUARDED_BY(cache_mu); 80 func_libTFE_Context81 tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { 82 return pflr->GetFLR(d->name()); 83 } 84 devicesTFE_Context85 const std::vector<tensorflow::Device*>& devices() { return session->devices; } 86 87 // Whether we should compute RunMetadata. 88 std::atomic<bool> should_store_metadata{false}; 89 tensorflow::mutex metadata_mu; 90 tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); 91 }; 92 93 struct TFE_TensorHandle { TFE_TensorHandleTFE_TensorHandle94 TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) 95 : t(t), d(d) {} 96 97 tensorflow::Tensor t; 98 // TODO(ashankar): d == nullptr iff local CPU 99 // This was expedient, but perhaps worth revisiting ('d' should always be a 100 // valid pointer?) 101 // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are 102 // provided with the appropriate TFE_Context. 103 // 104 // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a 105 // TFE_TensorHandle does not outlive the TFE_Context from which it came? 106 tensorflow::Device* d; 107 }; 108 109 struct TFE_Op { 110 // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a 111 // primitive operation. TFE_OpTFE_Op112 TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) 113 : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} 114 is_functionTFE_Op115 bool const is_function() const { return attr_types == nullptr; } 116 117 TFE_Context* ctx; // Must outlive the TFE_Op. 118 const tensorflow::string name; 119 tensorflow::AttrBuilder attrs; 120 const tensorflow::AttrTypeMap* attr_types; 121 std::vector<tensorflow::Tensor> inputs; 122 std::vector<tensorflow::Device*> input_devices; 123 tensorflow::Device* device; 124 bool use_xla = false; 125 }; 126 127 #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ 128