1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
16 #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
17 
#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
39 
40 struct TFE_ContextOptions {
41   TF_SessionOptions session_options;
42   TFE_ContextDevicePlacementPolicy policy{
43       TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
44 };
45 
46 struct TFE_Context {
TFE_ContextTFE_Context47   explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
48       : policy(opts.policy),
49         session(s),
50         rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
51         pflr(new tensorflow::ProcessFunctionLibraryRuntime(
52             session->device_mgr, opts.session_options.options.env,
53             TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}
54 
55   const TFE_ContextDevicePlacementPolicy policy;
56 
57   // Note: we cannot use C++11 thread_local here as there is no concept of a
58   // thread-local-object-local variable in C++11.
59   tensorflow::mutex policy_map_mu;
60   std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
61       thread_local_policies GUARDED_BY(policy_map_mu);
62 
63   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
64   TF_Session* const session;
65   tensorflow::Rendezvous* const rendezvous;
66 
67   tensorflow::mutex functions_mu;
GUARDED_BYTFE_Context68   tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
69       tensorflow::OpRegistry::Global(), {}};
70 
71   // One FunctionLibraryRuntime per device.
72   // func_libs[i] is the FunctionLibraryRuntime corresponding to
73   // session->devices[i].
74   const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
75 
76   tensorflow::mutex cache_mu;
77   std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
78                      tensorflow::Fprint128Hasher>
79       kernel_cache GUARDED_BY(cache_mu);
80 
func_libTFE_Context81   tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
82     return pflr->GetFLR(d->name());
83   }
84 
devicesTFE_Context85   const std::vector<tensorflow::Device*>& devices() { return session->devices; }
86 
87   // Whether we should compute RunMetadata.
88   std::atomic<bool> should_store_metadata{false};
89   tensorflow::mutex metadata_mu;
90   tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
91 };
92 
93 struct TFE_TensorHandle {
TFE_TensorHandleTFE_TensorHandle94   TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
95       : t(t), d(d) {}
96 
97   tensorflow::Tensor t;
98   // TODO(ashankar): d == nullptr iff local CPU
99   // This was expedient, but perhaps worth revisiting ('d' should always be a
100   // valid pointer?)
101   // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
102   // provided with the appropriate TFE_Context.
103   //
104   // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
105   // TFE_TensorHandle does not outlive the TFE_Context from which it came?
106   tensorflow::Device* d;
107 };
108 
109 struct TFE_Op {
110   // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
111   // primitive operation.
TFE_OpTFE_Op112   TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
113       : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
114 
is_functionTFE_Op115   bool const is_function() const { return attr_types == nullptr; }
116 
117   TFE_Context* ctx;  // Must outlive the TFE_Op.
118   const tensorflow::string name;
119   tensorflow::AttrBuilder attrs;
120   const tensorflow::AttrTypeMap* attr_types;
121   std::vector<tensorflow::Tensor> inputs;
122   std::vector<tensorflow::Device*> input_devices;
123   tensorflow::Device* device;
124   bool use_xla = false;
125 };
126 
127 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
128