1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/memory_types.h"
16 
17 #include <utility>
18 
19 #include "tensorflow/core/framework/memory_types.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 namespace tensorflow {
28 
// Identifies one slot of a graph node by the node's id and the slot's index.
// Despite the field name, this file uses the index for input slots as well as
// output slots when the type serves as a hash-map key.
struct Endpoint {
  int node_id;       // Id of the node that owns the slot.
  int output_index;  // Index of the slot on that node.
};
33 
34 struct EndpointHash {
operator ()tensorflow::EndpointHash35   uint32 operator()(const Endpoint& x) const {
36     return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
37                   x.output_index);
38   }
39 };
40 
41 struct EndpointEq {
operator ()tensorflow::EndpointEq42   uint32 operator()(const Endpoint& x, const Endpoint& y) const {
43     return (x.node_id == y.node_id) && (x.output_index == y.output_index);
44   }
45 };
46 
ProcessMemoryTypes(const DeviceType & device_type,const Graph * g,const std::function<Status (const Edge *,MemoryType,MemoryType)> & fn)47 static Status ProcessMemoryTypes(
48     const DeviceType& device_type, const Graph* g,
49     const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
50   if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) {
51     // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
52     // compatible.
53     return Status::OK();
54   }
55   // For GPU and SYCL device, HOST_MEMORY and DEVICE_MEMORY is not
56   // compatible. I.e., a conversion/transfer must be done.
57   //
58   // {node id, slot id} -> memory type.
59   typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
60       MemTypeMap;
61   MemTypeMap inp;
62   MemTypeMap out;
63   MemoryTypeVector inp_mvec;
64   MemoryTypeVector out_mvec;
65   for (const Node* n : g->nodes()) {
66     TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
67                                           n->def(), &inp_mvec, &out_mvec));
68     for (size_t i = 0; i < inp_mvec.size(); ++i) {
69       VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
70       inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
71     }
72     for (size_t i = 0; i < out_mvec.size(); ++i) {
73       VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
74       out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
75     }
76   }
77   for (const Edge* e : g->edges()) {
78     if (e->IsControlEdge()) {
79       continue;
80     }
81     MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()},
82                                          DEVICE_MEMORY);
83     MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()},
84                                          DEVICE_MEMORY);
85     VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
86             << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
87             << dm;
88     TF_RETURN_IF_ERROR(fn(e, sm, dm));
89   }
90   return Status::OK();
91 }
92 
ValidateMemoryTypes(const DeviceType & device_type,const Graph * g)93 Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
94   return ProcessMemoryTypes(
95       device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
96         if (sm == dm) {
97           return Status::OK();
98         }
99         return errors::Internal("Memory type mismatch (", sm, " ", dm,
100                                 ") between :", e->src()->id(), ":",
101                                 e->src_output(), " and ", e->dst()->id(), ":",
102                                 e->dst_input(), " : from ",
103                                 FormatNodeForError(*e->src()), " to ",
104                                 FormatNodeForError(*e->dst()));
105       });
106 }
107 
108 // Given an Edge whose two endpoints have different memory types and
109 // are gonna to insert a pair of HostSend/Recv or Send/HostRecv nodes,
110 // GetTensorName() returns a unique string that we can use as part of
111 // the rendezvous key. The return string is guaranteed to be unique
112 // within this process. That is sufficient because EnsureMemoryTypes
113 // is only used on a TensorFlow graph that is gonna to be executed in
114 // a single tf device (hence within a single process).
GetTensorName(const Edge * edge)115 static string GetTensorName(const Edge* edge) {
116   static std::atomic<int64> counter(0);
117   return strings::StrCat("memtype_", counter.fetch_add(1), "_",
118                          edge->src()->name());
119 }
120 
Send(Graph * g,const string & tensor_name,const string & device_name,bool host,const Edge * edge)121 static Node* Send(Graph* g, const string& tensor_name,
122                   const string& device_name, bool host, const Edge* edge) {
123   Node* ret;
124   TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
125                   .Input(edge->src(), edge->src_output())
126                   .Attr("tensor_name", tensor_name)
127                   .Attr("send_device", device_name)
128                   .Attr("send_device_incarnation", 0)  // Do not care.
129                   .Attr("recv_device", device_name)
130                   .Attr("_hostmem_sendrecv", true)
131                   .Finalize(g, &ret));
132   return ret;
133 }
134 
Recv(Graph * g,const string & tensor_name,const string & device_name,bool host,const Edge * edge)135 static Node* Recv(Graph* g, const string& tensor_name,
136                   const string& device_name, bool host, const Edge* edge) {
137   Node* ret;
138   TF_CHECK_OK(
139       NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
140           .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
141           .Attr("tensor_name", tensor_name)
142           .Attr("send_device", device_name)
143           .Attr("send_device_incarnation", 0)
144           .Attr("recv_device", device_name)
145           .Attr("_hostmem_sendrecv", true)
146           .Finalize(g, &ret));
147   return ret;
148 }
149 
EnsureMemoryTypes(const DeviceType & device_type,const string & device_name,Graph * g)150 Status EnsureMemoryTypes(const DeviceType& device_type,
151                          const string& device_name, Graph* g) {
152   struct Item {
153     const Edge* edge;
154     MemoryType sm;
155     MemoryType dm;
156   };
157   std::vector<Item> edges;
158   TF_RETURN_IF_ERROR(ProcessMemoryTypes(
159       device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
160         if (sm == dm) {
161           return Status::OK();
162         }
163         if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
164             ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
165           edges.push_back({e, sm, dm});
166           return Status::OK();
167         }
168         return errors::Internal("Unexpected memory type pair on an edge: ", sm,
169                                 " vs. ", dm);
170       }));
171 
172   // edges contains edges in 'g' that memtype is not
173   // compatible. Therefore, if we found any, we need to insert
174   // HostSend/Recv and Send/HostRecv pairs.  recv_nodes records all
175   // nodes we added so that we don't copy the same tensor more than
176   // once.
177   if (!edges.empty()) {
178     std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
179     for (const auto& item : edges) {
180       const Edge* e = item.edge;
181       const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
182       Node* recv = nullptr;
183       Endpoint key{e->src()->id(), e->src_output()};
184       auto iter = recv_nodes.find(key);
185       if (iter == recv_nodes.end()) {
186         const string tensor_name = GetTensorName(e);
187         Node* send =
188             Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
189         recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
190         if (!has_ref) {
191           // We only cache if there is no ref is involved.
192           recv_nodes[key] = recv;
193         }
194         g->AddControlEdge(send, recv);
195       } else {
196         recv = iter->second;
197       }
198       g->AddEdge(recv, 0, e->dst(), e->dst_input());
199       g->RemoveEdge(e);
200     }
201   }
202   return ValidateMemoryTypes(device_type, g);
203 }
204 
MemoryTypeForOutput(const DeviceType & device_type,const Graph * g,const Node * n,int index,MemoryType * memory_type)205 Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
206                            const Node* n, int index, MemoryType* memory_type) {
207   MemoryTypeVector inp_mvec;
208   MemoryTypeVector out_mvec;
209   TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
210                                         &inp_mvec, &out_mvec));
211   if (out_mvec.size() <= index) {
212     return errors::Internal("Trying to get the memory type for ", index,
213                             "'th output of node ", FormatNodeForError(*n),
214                             " that has only ", out_mvec.size(), " outputs");
215   }
216   *memory_type = out_mvec[index];
217   return Status::OK();
218 }
219 
220 }  // end namespace tensorflow
221