1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_C_API_INTERNAL_H_
17 #define TENSORFLOW_C_C_API_INTERNAL_H_
18 
19 #include "tensorflow/c/c_api.h"
20 
21 #include <list>
22 #include <set>
23 #include <string>
24 #include <unordered_map>
25 #include <vector>
26 
27 #ifndef __ANDROID__
28 #include "tensorflow/core/framework/op_gen_lib.h"
29 #endif
30 #include "tensorflow/core/common_runtime/shape_refiner.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/graph/graph.h"
34 #include "tensorflow/core/graph/graph_constructor.h"
35 #include "tensorflow/core/graph/node_builder.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/core/public/session.h"
40 
// Forward declarations for types that TF_Session (below) refers to only
// through pointers, so their headers need not be included here.
namespace tensorflow {
class Device;
class DeviceMgr;
}  // namespace tensorflow
45 
46 // Internal structures used by the C API. These are likely to change and should
47 // not be depended on.
48 
// Internal definition of the C API's TF_Status: a thin wrapper around a
// tensorflow::Status.
struct TF_Status {
  tensorflow::Status status;
};
52 
// Internal definition of the C API's TF_Tensor: element type, shape, and a
// pointer to the backing TensorBuffer.
//
// NOTE(review): member order is part of the contract if callers build this
// with aggregate initialization ({dtype, shape, buffer}); do not reorder.
struct TF_Tensor {
  // Defined out of line. NOTE(review): presumably releases `buffer` — confirm
  // in the implementation. Because this struct declares a destructor but no
  // copy operations, the implicitly generated copy constructor copies the
  // `buffer` pointer shallowly; if the destructor releases it, copying a
  // TF_Tensor by value would release the same buffer twice. Avoid copies.
  ~TF_Tensor();

  // Element type of the tensor.
  TF_DataType dtype;
  // Dimensions of the tensor.
  tensorflow::TensorShape shape;
  // Storage holding the tensor's elements (raw pointer).
  tensorflow::TensorBuffer* buffer;
};
60 
// Internal definition of the C API's TF_SessionOptions: wraps a
// tensorflow::SessionOptions.
struct TF_SessionOptions {
  tensorflow::SessionOptions options;
};
64 
// Internal definition of the C API's TF_DeprecatedSession: holds a raw pointer
// to a tensorflow::Session. NOTE(review): ownership of `session` is not shown
// here — presumably released when the TF_DeprecatedSession is deleted; confirm.
struct TF_DeprecatedSession {
  tensorflow::Session* session;
};
68 
// Internal definition of the C API's TF_Library.
struct TF_Library {
  // Opaque handle to the loaded library. NOTE(review): presumably the handle
  // returned by the platform's dynamic loader — confirm.
  void* lib_handle;
  // Buffer held by value. NOTE(review): presumably the serialized list of ops
  // registered by the library — confirm against the loader code.
  TF_Buffer op_list;
};
73 
// Internal definition of the C API's TF_Graph: a tensorflow::Graph together
// with the shape refiner, name index, and session bookkeeping the C API keeps
// for it. Fields annotated GUARDED_BY(mu) must only be accessed while holding
// `mu`.
struct TF_Graph {
  // Defined out of line.
  TF_Graph();

  // Protects the GUARDED_BY(mu) fields below.
  tensorflow::mutex mu;
  tensorflow::Graph graph GUARDED_BY(mu);

  // Runs shape inference.
  // NOTE(review): keep this declared after `graph`; the out-of-line
  // constructor may initialize it from `graph`, which relies on declaration
  // order — confirm.
  tensorflow::ShapeRefiner refiner GUARDED_BY(mu);

  // Maps from name of an operation to the Node* in 'graph'.
  std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
      GUARDED_BY(mu);

  // The keys of this map are all the active sessions using this graph.
  // Each value is the current "runnability" status of the corresponding
  // session. Under normal conditions all statuses are Status::OK(), but
  // if some operation is mutated after it was run by a session (this
  // is detected in the RecordMutation function), that session is no longer
  // safe to run. Its status will contain the error that will be returned
  // to the user if they try to run this session.
  //
  // Sessions are added to this map in TF_NewSession, and removed in
  // TF_DeleteSession.
  // A TF_Graph must be deleted when, and only when, both
  //   sessions.size() == 0 && delete_requested == true
  tensorflow::gtl::FlatMap<TF_Session*, tensorflow::Status> sessions
      GUARDED_BY(mu);
  bool delete_requested GUARDED_BY(mu);  // set true by TF_DeleteGraph

  // Used to link graphs contained in TF_WhileParams to the parent graph that
  // will eventually contain the full while loop.
  // NOTE(review): these are not GUARDED_BY(mu) and have no in-class
  // initializer; the out-of-line constructor is expected to set them —
  // confirm. `parent_inputs` presumably points at the loop inputs in `parent`.
  TF_Graph* parent;
  TF_Output* parent_inputs;
};
108 
109 struct TF_OperationDescription {
TF_OperationDescriptionTF_OperationDescription110   TF_OperationDescription(TF_Graph* g, const char* op_type,
111                           const char* node_name)
112       : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
113 
114   tensorflow::NodeBuilder node_builder;
115   TF_Graph* graph;
116   std::set<tensorflow::string> colocation_constraints;
117 };
118 
// Internal definition of the C API's TF_Operation: wraps a tensorflow::Node
// by value.
// NOTE(review): this looks like it is meant to be layout-identical to Node
// (e.g. TF_Operation* obtained by casting a Node*); if so, do not add members
// — confirm against the C API implementation.
struct TF_Operation {
  tensorflow::Node node;
};
122 
// Internal definition of the C API's TF_Session: a tensorflow::Session paired
// with the TF_Graph it runs.
struct TF_Session {
  // Defined out of line.
  TF_Session(tensorflow::Session* s, TF_Graph* g);

  // Underlying session (raw pointer).
  tensorflow::Session* session;
  // Graph this session runs. TF_Graph::sessions tracks this session from
  // TF_NewSession until TF_DeleteSession.
  TF_Graph* graph;

  tensorflow::mutex mu;
  // NOTE(review): not annotated GUARDED_BY(mu), but presumably protected by
  // `mu`; it appears to record how many of `graph`'s nodes have already been
  // handed to `session` — confirm in the implementation.
  int last_num_graph_nodes;

  // NOTE(ashankar): Experimental fields to help keep the
  // buffers of a TF_Tensor pinned in device memory.
  const tensorflow::DeviceMgr* device_mgr;   // Owned by session.
  std::vector<tensorflow::Device*> devices;  // Owned by device_mgr.
};
137 
// Internal definition of the C API's TF_ImportGraphDefOptions.
struct TF_ImportGraphDefOptions {
  tensorflow::ImportGraphDefOptions opts;

  // Backing memory for TensorId fields in opts.
  // A std::list is used because appending to it never moves existing
  // elements, so views into earlier strings stay valid as more are added.
  // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
  std::list<tensorflow::string> tensor_id_data;
};
145 
// Internal definition of the C API's TF_ImportGraphDefResults: outputs
// reported after importing a GraphDef.
struct TF_ImportGraphDefResults {
  std::vector<TF_Output> return_tensors;
  std::vector<TF_Operation*> return_nodes;
  // C strings pointing into missing_unused_key_names_data; they remain valid
  // only as long as this struct (and that list) is alive.
  std::vector<const char*> missing_unused_key_names;
  std::vector<int> missing_unused_key_indexes;

  // Backing memory for missing_unused_key_names values.
  // A std::list keeps existing elements in place as new ones are appended, so
  // the const char* pointers above are not invalidated.
  std::list<tensorflow::string> missing_unused_key_names_data;
};
155 
// Internal definition of the C API's TF_DeviceList: a vector of device
// attributes. NOTE(review): presumably filled from a session's device listing
// — confirm.
struct TF_DeviceList {
  std::vector<tensorflow::DeviceAttributes> response;
};
159 
// Internal definition of the C API's TF_Function: wraps a
// tensorflow::FunctionDef by value.
struct TF_Function {
  tensorflow::FunctionDef fdef;
};
163 
164 struct TF_ApiDefMap {
TF_ApiDefMapTF_ApiDefMap165   explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
166       :
167 #ifndef __ANDROID__
168         api_def_map(op_list),
169 #endif
170         update_docs_called(false) {
171   }
172 
173 #ifndef __ANDROID__
174   tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
175 #endif
176   bool update_docs_called GUARDED_BY(lock);
177   tensorflow::mutex lock;
178 };
179 
180 namespace tensorflow {
181 
182 class TensorCApi {
183  public:
Buffer(const Tensor & tensor)184   static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
MakeTensor(TF_DataType type,const TensorShape & shape,TensorBuffer * buf)185   static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
186                            TensorBuffer* buf) {
187     return Tensor(static_cast<DataType>(type), shape, buf);
188   }
189 };
190 
191 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
192 
193 TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
194 
195 Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
196 
197 // Set the shapes and types of the output's handle.
198 //
199 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
200 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
201 // rank is known), then it must be equal to the length of `shapes[i]`; if
202 // `ranks[i] == 1`, then `shapes[i]` may be nullptr.
203 //
204 // TODO(akshayka): Implement a corresponding getter method.
205 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
206                                            int num_shapes_and_types,
207                                            const int64_t** shapes,
208                                            const int* ranks,
209                                            const TF_DataType* types,
210                                            TF_Status* status);
211 
212 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
213                     const char* mutation_type);
214 
215 }  // end namespace tensorflow
216 
217 #endif  // TENSORFLOW_C_C_API_INTERNAL_H_
218