1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
18 
19 #include <list>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
27 #include "tensorflow/compiler/xla/service/session.pb.h"
28 #include "tensorflow/compiler/xla/service/user_computation.h"
29 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 
40 // Tracks computations for the XLA service; computations can be registered
41 // with a UserComputation instance and can be resolved from a handle for later
42 // use.
43 //
44 // This class is also capable of serializing/deserializing computations that it
45 // tracks (and to serialize properly you need to serialize all referred-to
46 // computations as well).
47 class ComputationTracker {
48  public:
49   ComputationTracker();
50 
51   // Creates a new UserComputation object and returns the corresponding
52   // ComputationHandle for it.
53   //
54   // Precondition: user_computation is not already present in the map.
55   ComputationHandle NewComputation(const string& computation_name);
56 
57   // Restores session data for a computation that has been serialized, and
58   // allocates a new computation handle for it.
59   StatusOr<ComputationHandle> LoadSessionModule(
60       const SessionModule& session_module);
61 
62   // Snapshots a computation (referenced by the provided handle) at its latest
63   // version, returning a module where it is the entry, and any referred-to
64   // computations are entrained as "embedded" (non-entry) computations.
65   StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation(
66       const ComputationHandle& computation);
67 
68   // Resolves a ComputationHandle to a UserComputation that is present in the
69   // map.
70   StatusOr<UserComputation*> Resolve(
71       const ComputationHandle& computation) const;
72 
73   // Builds an HLO module using the specified computation as the entry. The
74   // module will include the entry computation as well as all computations which
75   // are called directly or indirectly from the entry computation via operations
76   // like "map". config is the HLO module configuration to use for the
77   // constructed module.
78   // If include_unreachable_instructions is true, then instructions
79   // which are not reachable from the root are lowered into HloInstructions
80   // including unreachable parameters. This ensures the entry HloComputation has
81   // the same program shape (ProgramShape) as the entry UserComputation.
82   StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
83       const VersionedComputationHandle& entry_handle,
84       const HloModuleConfig& config,
85       bool include_unreachable_instructions = true) const;
86 
87   string ToString() const;
88 
89  private:
90   // Bumps the next_computation_ number and returns the allocated number wrapped
91   // in a ComputationHandle.
92   ComputationHandle AllocateHandle()
93       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
94 
95   // Loads a session computation into a UserComputation, registers it, and
96   // returns the computation handle of the registered computation. If old_to_new
97   // is provided, it is used for remapping references to computations present in
98   // session_computation.
99   //
100   // old_to_new will be updated with the mapping from session_computation's old
101   // handle to the returned handle value, and may not be null.
102   StatusOr<ComputationHandle> LoadSessionComputation(
103       const SessionComputation& session_computation,
104       std::map<int64, ComputationHandle>* old_to_new)
105       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
106 
107   // Internal implementation of Resolve method which requires, but does not
108   // acquire the mutex.
109   StatusOr<UserComputation*> ResolveInternal(
110       const ComputationHandle& computation) const
111       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
112 
113   // Builds a post order sort of a computation ("entry") and all of its embedded
114   // computations including all transitively embedded computations. An embedded
115   // computation (the callee) will always appear in the sort before the
116   // computation which calls the embedded computation (the caller). Necessarily,
117   // the entry computation is the last element in the sort. visited and
118   // post_order should be empty when calling. post_order contains the post order
119   // sort when the function return.
120   void ComputeComputationPostOrder(
121       const VersionedComputationHandle& versioned_handle,
122       std::set<VersionedComputationHandle>* visited,
123       std::list<VersionedComputationHandle>* post_order) const
124       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
125 
126   string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
127 
128   // Guards the computation mapping. Marked mutable so that the Resolve method
129   // can remain const; Resolve does't really modify the tracker in any way, but
130   // it has to lock the mutex for safety.
131   mutable tensorflow::mutex computation_mutex_;
132 
133   // The next sequence number to assign to a computation, guarded by the same
134   // mutex as the mapping as they'll be mutated at the same time.
135   int64 next_computation_ GUARDED_BY(computation_mutex_);
136 
137   // Mapping from ComputationHandle value to the corresponding registered
138   // UserComputation object.
139   std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
140       GUARDED_BY(computation_mutex_);
141 
142   TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
143 };
144 
145 }  // namespace xla
146 
147 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
148