1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_
17 
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/device_set.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 
27 // Given a `device_set` and a `graph`, partitions the `graph` into
28 // `subgraphs`. `subgraphs` maps device names to the graph assigned to that
29 // device. `graph` must have been placed (e.g. by running Placer),
30 // i.e. all nodes must have an assigned_device set.
31 // `graph` is non-const because the underlying Partition() function transforms
32 // the graph to correctly partition distributed control flow.
33 Status PartitionFunctionGraph(
34     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
35     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs);
36 
37 // Each subgraph produced by partitioning the function body contains a subset
38 // of the original `Arg` and `Retval` nodes. This function performs
39 // bookkeeping to track which `Arg` and `Retval` nodes were placed on a
40 // particular device / subgraph.
41 //
42 // More specifically, this function
43 //  (1) rewrites the indices of the `Arg` and `Retval` nodes placed
44 //      on a particular device.  When a function is parittioned each
45 //      partition, `subgraph`, get a subset of the arguments and
46 //      return values. The `index` attributes of these _Arg and _Retval
47 //      nodes reflect the indices of these parameters in the original
48 //      function. To convert `subgraph` to a function, we need to replace
49 //      there original indices with 0, 1, 2, ... .
50 //
51 //      The argument and return value order in the partitioned function is
52 //      determined by the node iteration order in `subgraph`. This order
53 //      is also used in UpdateArgAndRetvalMetadata. This is fine because the
54 //      node iteration order is deterministic - it follows the node ids.
55 //  (2) records the subsets of `Arg` and `Retval` nodes assigned to the
56 //      device in `*_indices`, and
57 //  (3) records which `Arg` and `Retval` nodes live in host memory in
58 //      `*_alloc_attrs`.
59 Status UpdateArgAndRetvalMetadata(
60     Graph* subgraph, std::vector<int>* arg_indices,
61     std::vector<int>* ret_indices,
62     std::vector<AllocatorAttributes>* arg_alloc_attrs,
63     std::vector<AllocatorAttributes>* ret_alloc_attrs);
64 
65 // Extracts tensors at `indices` from `arguments`.
66 std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
67                                       gtl::ArraySlice<Tensor> arguments);
68 
69 // Utility for generating function names not present in `flib_def`, using
70 // given `name` as the base for the name.
71 class FunctionNameGenerator {
72  public:
73   // `flib_def` must outlive this.
FunctionNameGenerator(const FunctionLibraryDefinition * flib_def,const string & name)74   FunctionNameGenerator(const FunctionLibraryDefinition* flib_def,
75                         const string& name)
76       : flib_def_(flib_def), name_(name), counter_(0) {}
77 
78   // Returns a function name not present in `flib_def` using `name` as
79   // the base and appending a numeric suffix.
80   string GetName();
81 
82  private:
83   const FunctionLibraryDefinition* flib_def_;
84   const string name_;
85   uint32 counter_;
86 };
87 
88 }  // namespace tensorflow
89 
90 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_
91