1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
17 #define TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
18 
19 #include <functional>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/graph/costmodel.h"
27 #include "tensorflow/core/graph/graph.h"
28 
29 namespace tensorflow {
30 
31 struct PartitionOptions {
32   // A function that returns a location for the execution of a given
33   // Node.
34   typedef std::function<string(const Node*)> NodeToLocFunc;
35   NodeToLocFunc node_to_loc = nullptr;
36 
37   // A function that returns a unique graph node name with the given
38   // prefix.
39   typedef std::function<string(const string&)> NewNameFunc;
40   NewNameFunc new_name = nullptr;
41 
42   // A function that returns the incarnation of a device given the
43   // device's fullname. If not found, GetIncarnationFunc should return
44   // kIllegalIncarnation.
45   static constexpr uint64 kIllegalIncarnation = 0;
46   typedef std::function<uint64(const string&)> GetIncarnationFunc;
47   GetIncarnationFunc get_incarnation = nullptr;
48 
49   // If specified, flib_def defines a function library that should be
50   // partitioned and replicated into each resulting partition graphs.
51   const FunctionLibraryDefinition* flib_def = nullptr;
52 
53   // True if all the control flow "code" has already been added. The
54   // control flow code needs to be added when we still have the entire
55   // graph before any partitioning. So this flag should be false for
56   // the first partitioning but true for all subsequent partitioning.
57   //
58   // TODO(yuanbyu): We could also make the addition of the control
59   // flow code incremental based on 'node_to_loc'. This makes the
60   // communication a broadcast tree, which could be more efficient when
61   // the number of participating devices is large.
62   bool control_flow_added = false;
63 
64   // A function that returns the data type into which the tensor
65   // should be cast before sent over the wire.
66   typedef std::function<DataType(const Edge*)> ShouldCastFunc;
67   ShouldCastFunc should_cast = nullptr;
68 
69   // Schedule the execution of the recvs based on their start times
70   // computed by some scheduling algorithm. The recvs are divided into
71   // epochs based on their start times. A recv is enabled only when
72   // execution reaches its epoch - N for some predefined N.
73   bool scheduling_for_recvs = false;
74   // The start time for each node in the graph computed by some scheduling
75   // algorithm. If 'need_to_record_start_times' is true, we record them
76   // in the graph as a node attribute.
77   bool need_to_record_start_times = false;
78   std::vector<Microseconds> start_times;
79 };
80 
81 // Partition "input" graph into a set of graphs, one per location.
82 // The location for node n is derived by calling opts.node_to_loc(n).
83 // New nodes added by Partition use "opts.new_name(old_name)" to
84 // generate node names.
85 //
86 // Stores the partitions in *partitions.
87 Status Partition(const PartitionOptions& opts, Graph* input,
88                  std::unordered_map<string, GraphDef>* partitions);
89 
90 // Add control edges to the partitions to control the ordering
91 // and timing of the recv nodes based on the start times calculated
92 // using some scheduling algorithm.
93 Status AddControlEdges(const PartitionOptions& opts,
94                        std::unordered_map<string, GraphDef>* partitions);
95 
96 }  // namespace tensorflow
97 
98 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
99