1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
17 #define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
18 
#include <functional>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/env.h"
26 
27 namespace xla {
28 
29 // Wrapper class for HloModuleMetadataProto to avoid allowing callers to mutate
30 // arbitrary fields. Specifically, callers cannot set timestamps or ids or
31 // set the fields of any pass not currently running.
32 class HloModuleMetadata {
33  public:
HloModuleMetadata(tensorflow::Env * env)34   explicit HloModuleMetadata(tensorflow::Env* env) : env_(env) {}
35 
proto()36   const HloModuleMetadataProto& proto() const { return module_metadata_; }
37 
38   // Creates a new HloPassMetadata. All calls to RecordPassStart should be
39   // matched by a later call to RecordPassEnd.
40   void RecordPassStart();
41 
42   // Marks the currently running pass as finished. Returns NotFound if metadata
43   // for the currently running pass cannot be found.
44   Status RecordPassEnd();
45 
prepartitioning_metadata()46   const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata()
47       const {
48     return prepartitioning_metadata_;
49   }
50   void set_prepartitioning_metadata(
51       const HloModuleMetadata& prepartitioning_metadata);
52 
53   // Setters for HloModuleMetadataProto.
set_module_group_name(const std::string & name)54   void set_module_group_name(const std::string& name) {
55     module_metadata_.set_module_group_name(name);
56   }
set_canonical_module_id(int64 id)57   void set_canonical_module_id(int64 id) {
58     module_metadata_.set_canonical_module_id(id);
59   }
add_partitioned_module_id(int64 id)60   void add_partitioned_module_id(int64 id) {
61     module_metadata_.add_partitioned_module_ids(id);
62   }
63 
current_pass_id()64   StatusOr<int64> current_pass_id() {
65     TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
66                         GetCurrentHloPassMetadata());
67     return pass_metadata->pass_id();
68   }
69 
70   // Setters for the current HloPassMetadata.
set_current_pass_name(const std::string & pass_name)71   Status set_current_pass_name(const std::string& pass_name) {
72     return MutateCurrentHloPassMetadata(
73         [&pass_name](HloPassMetadata* pass_metadata) {
74           pass_metadata->set_pass_name(pass_name);
75         });
76   }
set_current_pass_pipeline_name(const std::string & pipeline_name)77   Status set_current_pass_pipeline_name(const std::string& pipeline_name) {
78     return MutateCurrentHloPassMetadata(
79         [&pipeline_name](HloPassMetadata* pass_metadata) {
80           pass_metadata->set_pipeline_name(pipeline_name);
81         });
82   }
add_current_pass_dump_filename(const std::string & dump_filename)83   Status add_current_pass_dump_filename(const std::string& dump_filename) {
84     return MutateCurrentHloPassMetadata(
85         [&dump_filename](HloPassMetadata* pass_metadata) {
86           pass_metadata->add_dump_filenames(dump_filename);
87         });
88   }
set_current_pass_module_changed(bool module_changed)89   Status set_current_pass_module_changed(bool module_changed) {
90     return MutateCurrentHloPassMetadata(
91         [&module_changed](HloPassMetadata* pass_metadata) {
92           pass_metadata->set_module_changed(module_changed);
93         });
94   }
set_current_pass_module_id(int64 module_id)95   Status set_current_pass_module_id(int64 module_id) {
96     return MutateCurrentHloPassMetadata(
97         [&module_id](HloPassMetadata* pass_metadata) {
98           pass_metadata->set_module_id(module_id);
99         });
100   }
add_current_pass_module_group_module_id(int64 module_id)101   Status add_current_pass_module_group_module_id(int64 module_id) {
102     return MutateCurrentHloPassMetadata(
103         [&module_id](HloPassMetadata* pass_metadata) {
104           pass_metadata->add_module_group_module_ids(module_id);
105         });
106   }
107 
108  private:
109   // Gets mutable metadata for the currently running pass. If passes are nested,
110   // finds the deepest one still running. Returns NotFound if metadata for the
111   // currently running pass cannot be found.
112   StatusOr<HloPassMetadata*> GetCurrentHloPassMetadata();
113 
114   Status MutateCurrentHloPassMetadata(
115       const std::function<void(HloPassMetadata*)>& mutator);
116 
117   HloModuleMetadataProto module_metadata_;
118   tensorflow::Env* env_;
119   int64 next_pass_id_ = 1;
120 
121   // Stack of metadata for passes that are currently running. Size > 1 iff
122   // passes are nested.
123   std::vector<HloPassMetadata*> running_passes_;
124 
125   // Metadata from before the module was partitioned, if applicable.
126   absl::optional<HloModuleMetadataProto> prepartitioning_metadata_ =
127       absl::nullopt;
128 };
129 
130 }  // namespace xla
131 
132 #endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
133