1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
// Forward declaration: this header refers to SessionOptions only through
// pointers, so the full definition is not needed here.
struct SessionOptions;

namespace subgraph {
// Forward declaration: this header refers to RewriteGraphMetadata only
// through a raw pointer parameter and a std::unique_ptr member.
struct RewriteGraphMetadata;
}  // namespace subgraph
40 
41 struct GraphExecutionStateOptions {
42   const DeviceSet* device_set = nullptr;
43   const SessionOptions* session_options = nullptr;
44   // Unique session identifier. Can be empty.
45   string session_handle;
46   // A map from node name to device name, representing the unchangeable
47   // placement of stateful nodes.
48   std::unordered_map<string, string> stateful_placements;
49 };
50 
51 // A ClientGraph is simply a sub-graph of the full graph as induced by
52 // BuildGraphOptions.
53 struct ClientGraph {
ClientGraphClientGraph54   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
55                        DataTypeVector feed_types, DataTypeVector fetch_types,
56                        int64 collective_graph_key)
57       : flib_def(std::move(flib)),
58         graph(flib_def.get()),
59         feed_types(std::move(feed_types)),
60         fetch_types(std::move(fetch_types)),
61         collective_graph_key(collective_graph_key) {}
62   // Each client-graph gets its own function library since optimization passes
63   // post rewrite for execution might want to introduce new functions.
64   std::unique_ptr<FunctionLibraryDefinition> flib_def;
65   Graph graph;
66   DataTypeVector feed_types;
67   DataTypeVector fetch_types;
68   int64 collective_graph_key;
69 };
70 
71 // GraphExecutionState is responsible for generating an
72 // executable ClientGraph from the original GraphDef that specifies
73 // the complete graph and from BuildGraphOptions which specifies
74 // input/output nodes.
75 //
76 // An executable Graph differs from a GraphDef by being Placed,
77 // meaning that each Node is assigned to a single Device in the
78 // available set.
79 //
80 // When GraphExecutionState is first constructed it instantiates
81 // a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef.  Nodes without a
// static assignment are currently placed in a very naive way.  Since
// stateful Nodes cannot be moved after initial placement, it is
// important that stateful Nodes get sensible initial device
// assignments in the graph definition.
87 //
// Subsequently, GraphExecutionState generates a ClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph.  MasterSession uses such a ClientGraph to execute one or
// more similar client requests.
92 //
93 // GraphExecutionState is thread-safe.
94 
95 class GraphExecutionState {
96  public:
97   virtual ~GraphExecutionState();
98 
99   // Creates a new `GraphExecutionState` for the given
100   // `graph_def`, which represents the entire graph for a session.
101   //
102   // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
103   // in an undefined state. If it is necessary to use `*graph_def`
104   // after this call, make an explicit copy of the graph before
105   // calling this method.
106   static Status MakeForBaseGraph(
107       GraphDef* graph_def, const GraphExecutionStateOptions& options,
108       std::unique_ptr<GraphExecutionState>* out_state);
109 
110   // Creates a new `GraphExecutionState` and `SimpleClientGraph`
111   // for the subgraph of `original_graph_def` defined by
112   // `subgraph_options`.
113   static Status MakeForPrunedGraph(
114       const FunctionDefLibrary& func_def_lib,
115       const GraphExecutionStateOptions& options,
116       const GraphDef& original_graph_def,
117       const BuildGraphOptions& subgraph_options,
118       std::unique_ptr<GraphExecutionState>* out_state,
119       std::unique_ptr<ClientGraph>* out_client_graph);
120 
121   // Creates a new GraphExecutionState representing the
122   // concatenation of this graph, and the graph defined by
123   // "extension_def". The same name may not be used to define a node
124   // in both this graph and "extension_def".
125   //
126   // If successful, returns OK and the caller takes ownership of "*out".
127   // Otherwise returns an error and does not modify "*out".
128   //
129   // After calling `old_state->Extend()`, `old_state` may no longer be
130   // used.
131   //
132   // NOTE(mrry): This method respects the placement of stateful nodes in
133   // in *this, but currently does not transfer any other placement
134   // or cost model information to the new graph.
135   Status Extend(const GraphDef& extension_def,
136                 std::unique_ptr<GraphExecutionState>* out) const;
137 
138   // Builds a ClientGraph (a sub-graph of the full graph as induced by
139   // the Node set specified in "options").  If successful, returns OK
140   // and the caller takes the ownership of "*out". Otherwise, returns
141   // an error.
142   Status BuildGraph(const BuildGraphOptions& options,
143                     std::unique_ptr<ClientGraph>* out);
144 
145   // The graph returned by BuildGraph may contain only the pruned
146   // graph, whereas some clients may want access to the full graph.
full_graph()147   const Graph* full_graph() { return graph_; }
148 
149   // Returns the node with the given name, or null if it does not exist.
get_node_by_name(const string & name)150   const Node* get_node_by_name(const string& name) const {
151     NodeNameToCostIdMap::const_iterator iter =
152         node_name_to_cost_id_map_.find(name);
153     if (iter != node_name_to_cost_id_map_.end()) {
154       return graph_->FindNodeId(iter->second);
155     } else {
156       return nullptr;
157     }
158   }
159 
160   // Returns a reference to the current graph_def.  Use must
161   // not extend beyond lifetime of GrahExecutionState object.
original_graph_def()162   const GraphDef& original_graph_def() { return original_graph_def_; }
163 
164   // Returns the map of stateful placements as a map of
165   // node name to placement string.
GetStatefulPlacements()166   std::unordered_map<string, string> GetStatefulPlacements() const {
167     return stateful_placements_;
168   }
169 
170  private:
171   GraphExecutionState(GraphDef* graph_def,
172                       const GraphExecutionStateOptions& options);
173 
174   Status InitBaseGraph(const BuildGraphOptions& options);
175 
176   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
177   // is true, such as "params" and "queue" nodes.  Once placed these
178   // nodes can not be moved to a different device.  Maps node names to
179   // device names.
180   std::unordered_map<string, string> stateful_placements_;  // Immutable after
181                                                             // ctor.
182   void SaveStatefulNodes(Graph* graph);
183   void RestoreStatefulNodes(Graph* graph);
184 
185   // Extract the subset of the graph that needs to be run, adding feed/fetch
186   // ops as needed.
187   Status PruneGraph(const BuildGraphOptions& options, Graph* graph,
188                     subgraph::RewriteGraphMetadata* out_rewrite_metadata);
189 
190   Status OptimizeGraph(
191       const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
192       std::unique_ptr<FunctionLibraryDefinition>* optimized_flib);
193 
194   GraphDef original_graph_def_;            // Immutable after ctor.
195   const DeviceSet* device_set_;            // Not owned
196   const SessionOptions* session_options_;  // Not owned
197   // Unique session identifier. Can be empty.
198   string session_handle_;
199 
200   // Map from name to Node for the full graph in placed_.
201   NodeNameToCostIdMap node_name_to_cost_id_map_;
202 
203   // 'flib_def_' is initialized from the initial graph def's library,
204   // and may be updated by a graph optimization pass.
205   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
206 
207   // `rewrite_metadata_` is only set for GraphExecutionState
208   // objects created by `MakeForPrunedGraph()`.
209   std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
210 
211   // The dataflow graph owned by this object.
212   Graph* graph_;
213 
214   TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
215 };
216 
217 }  // namespace tensorflow
218 
219 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
220