1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
17 #define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/framework/device_attributes.pb.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/graph/node_builder.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 
28 namespace tensorflow {
29 namespace subgraph {
30 
// Information about a graph rewritten by `RewriteGraphForExecution()`.
//
// Both `RewriteGraphForExecution()` overloads take a pointer to this struct
// as their final `out_metadata` parameter.
struct RewriteGraphMetadata {
  // The element type of each tensor fed to this subgraph. The order
  // of types corresponds to the order of tensor names in
  // `fed_outputs` when calling `RewriteGraphForExecution()`.
  DataTypeVector feed_types;
  // The element type of each tensor fetched from this subgraph. The
  // order of types corresponds to the order of tensor names in
  // `fetch_outputs` when calling `RewriteGraphForExecution()`.
  DataTypeVector fetch_types;
};
42 
43 // Describes the action to take on a particular tensor endpoint (described by
44 // a "<node_name>:<output_index>" pair) when pruning the graph.
45 //
46 // The `AddNode()` method must be overridden to describe this action. The method
47 // will be invoked once during `RewriteGraphForExecution()` with tensor endpoint
48 // named by `endpoint_name`, and it may either create a single new node, or fail
49 // with an error if the resulting graph would be invalid.
50 class PruneRewrite {
51  public:
52   // `endpoint_name` and `device_info` must outlive this object.
PruneRewrite(const string * endpoint_name,const DeviceAttributes * device_info)53   PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info)
54       : endpoint_name_(endpoint_name), device_info_(device_info) {}
~PruneRewrite()55   virtual ~PruneRewrite() {}
56 
57   // Creates a new node whose output replaces the given `tensor` in graph `g`.
58   // The node will be assigned to the device named in `device_info`.
59   virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor,
60                          Node** out_node) = 0;
61 
62   // Returns the name of the tensor to which this rewrite applies.
endpoint_name()63   const string& endpoint_name() { return *endpoint_name_; }
64 
65  protected:
66   // The device on which the new node will be created.
device_info()67   const DeviceAttributes& device_info() { return *device_info_; }
68 
69  private:
70   const string* const endpoint_name_;          // Not owned.
71   const DeviceAttributes* const device_info_;  // Not owned.
72 };
73 
// Rewrite the graph structure of "*g" to deal with feeding node
// outputs, fetching node outputs, and only running a subset of the
// graph.  "fed_outputs" and "fetch_outputs" are both lists of
// output tensor identifiers in the form of
// "<name>[:<optional_output_index>]", and "target_node_names" is a
// list of target node names in "*g".
//
// In the resulting graph "*g", output edges in "fed_outputs" have
// been redirected to special "_recv" nodes introduced into the graph.
// If these fed nodes are not needed in order to compute the effects
// of the nodes in "target_node_names" and "fetch_outputs", then these may
// be omitted from the graph.
//
// In the resulting graph "*g", additional "_send" nodes are connected
// to every output in "fetch_outputs".  These "_send" nodes are set up
// to execute on the device described by "device_info".
//
// On success, returns OK, and sets "*g" to a version of "*g"
// that represents the portions of the graph necessary for producing
// the output of all nodes listed in "target_node_names" and fetching the
// specific node outputs specified in "fetch_outputs".
//
// On failure, returns the error status. Possible errors include:
//    - fed output "node:output_index" does not exist in "*g"
//    - fetch output "node:output_index" does not exist in "*g"
//    - target node "node" does not exist in "*g"
//
// NOTE(review): "use_function_convention" is not described by this comment;
// presumably it selects _Arg/_Retval nodes instead of the "_recv"/"_send"
// nodes described above -- confirm against the implementation.
// NOTE(review): "out_metadata" presumably receives the feed and fetch element
// types (see RewriteGraphMetadata) on success -- confirm against the
// implementation.
Status RewriteGraphForExecution(
    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
    const gtl::ArraySlice<string>& fetch_outputs,
    const gtl::ArraySlice<string>& target_node_names,
    const DeviceAttributes& device_info, bool use_function_convention,
    RewriteGraphMetadata* out_metadata);
106 
// A more general version of the above function that supports
// customizable rewriting actions for each fed and fetched tensor.
//
// Instead of lists of tensor names, the caller supplies one `PruneRewrite`
// per fed tensor ("feed_rewrites") and one per fetched tensor
// ("fetch_rewrites"); each rewrite names its own endpoint via
// `PruneRewrite::endpoint_name()`. "target_node_names" is a list of target
// node names, as in the function above.
//
// NOTE(review): "out_metadata" presumably receives the feed and fetch element
// types in the order of "feed_rewrites" and "fetch_rewrites" -- confirm
// against the implementation.
Status RewriteGraphForExecution(
    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
    const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
    const gtl::ArraySlice<string>& target_node_names,
    RewriteGraphMetadata* out_metadata);
114 
115 /////////////////////////////////////////////////////////
116 // Custom rewrite actions for fed and fetched tensors. //
117 /////////////////////////////////////////////////////////
118 
119 // A rewrite action that adds an _Arg node for a fed tensor.
120 class ArgFeedRewrite : public PruneRewrite {
121  public:
ArgFeedRewrite(const string * endpoint_name,const DeviceAttributes * device_info,int32 arg_index)122   ArgFeedRewrite(const string* endpoint_name,
123                  const DeviceAttributes* device_info, int32 arg_index)
124       : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {}
125   Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
126                  Node** out_node) override;
127 
128  private:
129   const int32 arg_index_;
130 };
131 
// A rewrite action that adds a client-terminated _Recv node for a fed tensor.
class RecvFeedRewrite : public PruneRewrite {
 public:
  // Inherits the (endpoint_name, device_info) constructor of `PruneRewrite`;
  // this class adds no state of its own.
  using PruneRewrite::PruneRewrite;
  // Implements `PruneRewrite::AddNode()` for `feed_tensor`; defined out of
  // line.
  Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
                 Node** out_node) override;
};
139 
140 // A rewrite action that adds a _Retval node for a fetched tensor.
141 class RetvalFetchRewrite : public PruneRewrite {
142  public:
RetvalFetchRewrite(const string * endpoint_name,const DeviceAttributes * device_info,int32 retval_index)143   RetvalFetchRewrite(const string* endpoint_name,
144                      const DeviceAttributes* device_info, int32 retval_index)
145       : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {}
146   Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
147                  Node** out_node) override;
148 
149  private:
150   const int32 retval_index_;
151 };
152 
// A rewrite action that adds a client-terminated _Send node for a
// fetched tensor.
class SendFetchRewrite : public PruneRewrite {
 public:
  // Inherits the (endpoint_name, device_info) constructor of `PruneRewrite`;
  // this class adds no state of its own.
  using PruneRewrite::PruneRewrite;
  // Implements `PruneRewrite::AddNode()` for `fetch_tensor`; defined out of
  // line.
  Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
                 Node** out_node) override;
};
161 
162 }  // namespace subgraph
163 }  // namespace tensorflow
164 
165 #endif  // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
166