1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ 18 19 #include <list> 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 27 #include "tensorflow/compiler/xla/service/session.pb.h" 28 #include "tensorflow/compiler/xla/service/user_computation.h" 29 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/compiler/xla/types.h" 32 #include "tensorflow/compiler/xla/xla_data.pb.h" 33 #include "tensorflow/core/platform/macros.h" 34 #include "tensorflow/core/platform/mutex.h" 35 #include "tensorflow/core/platform/thread_annotations.h" 36 #include "tensorflow/core/platform/types.h" 37 38 namespace xla { 39 40 // Tracks computations for the XLA service; computations can be registered 41 // with a UserComputation instance and can be resolved from a handle for later 42 // use. 43 // 44 // This class is also capable of serializing/deserializing computations that it 45 // tracks (and to serialize properly you need to serialize all referred-to 46 // computations as well). 47 class ComputationTracker { 48 public: 49 ComputationTracker(); 50 51 // Creates a new UserComputation object and returns the corresponding 52 // ComputationHandle for it. 53 // 54 // Precondition: user_computation is not already present in the map. 55 ComputationHandle NewComputation(const string& computation_name); 56 57 // Restores session data for a computation that has been serialized, and 58 // allocates a new computation handle for it. 59 StatusOr<ComputationHandle> LoadSessionModule( 60 const SessionModule& session_module); 61 62 // Snapshots a computation (referenced by the provided handle) at its latest 63 // version, returning a module where it is the entry, and any referred-to 64 // computations are entrained as "embedded" (non-entry) computations. 65 StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation( 66 const ComputationHandle& computation); 67 68 // Resolves a ComputationHandle to a UserComputation that is present in the 69 // map. 70 StatusOr<UserComputation*> Resolve( 71 const ComputationHandle& computation) const; 72 73 // Builds an HLO module using the specified computation as the entry. The 74 // module will include the entry computation as well as all computations which 75 // are called directly or indirectly from the entry computation via operations 76 // like "map". config is the HLO module configuration to use for the 77 // constructed module. 78 // If include_unreachable_instructions is true, then instructions 79 // which are not reachable from the root are lowered into HloInstructions 80 // including unreachable parameters. This ensures the entry HloComputation has 81 // the same program shape (ProgramShape) as the entry UserComputation. 82 StatusOr<std::unique_ptr<HloModule>> BuildHloModule( 83 const VersionedComputationHandle& entry_handle, 84 const HloModuleConfig& config, 85 bool include_unreachable_instructions = true) const; 86 87 string ToString() const; 88 89 private: 90 // Bumps the next_computation_ number and returns the allocated number wrapped 91 // in a ComputationHandle. 92 ComputationHandle AllocateHandle() 93 EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); 94 95 // Loads a session computation into a UserComputation, registers it, and 96 // returns the computation handle of the registered computation. If old_to_new 97 // is provided, it is used for remapping references to computations present in 98 // session_computation. 99 // 100 // old_to_new will be updated with the mapping from session_computation's old 101 // handle to the returned handle value, and may not be null. 102 StatusOr<ComputationHandle> LoadSessionComputation( 103 const SessionComputation& session_computation, 104 std::map<int64, ComputationHandle>* old_to_new) 105 EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); 106 107 // Internal implementation of Resolve method which requires, but does not 108 // acquire the mutex. 109 StatusOr<UserComputation*> ResolveInternal( 110 const ComputationHandle& computation) const 111 EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); 112 113 // Builds a post order sort of a computation ("entry") and all of its embedded 114 // computations including all transitively embedded computations. An embedded 115 // computation (the callee) will always appear in the sort before the 116 // computation which calls the embedded computation (the caller). Necessarily, 117 // the entry computation is the last element in the sort. visited and 118 // post_order should be empty when calling. post_order contains the post order 119 // sort when the function return. 120 void ComputeComputationPostOrder( 121 const VersionedComputationHandle& versioned_handle, 122 std::set<VersionedComputationHandle>* visited, 123 std::list<VersionedComputationHandle>* post_order) const 124 EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); 125 126 string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); 127 128 // Guards the computation mapping. Marked mutable so that the Resolve method 129 // can remain const; Resolve does't really modify the tracker in any way, but 130 // it has to lock the mutex for safety. 131 mutable tensorflow::mutex computation_mutex_; 132 133 // The next sequence number to assign to a computation, guarded by the same 134 // mutex as the mapping as they'll be mutated at the same time. 135 int64 next_computation_ GUARDED_BY(computation_mutex_); 136 137 // Mapping from ComputationHandle value to the corresponding registered 138 // UserComputation object. 139 std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_ 140 GUARDED_BY(computation_mutex_); 141 142 TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); 143 }; 144 145 } // namespace xla 146 147 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ 148