1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
17 
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/test_helpers.h"
20 #include "tensorflow/stream_executor/lib/statusor.h"
21 
22 namespace xla {
23 namespace {
24 
25 using ::testing::ElementsAre;
26 using ::testing::Property;
27 using ::testing::StrEq;
28 
29 class TestEnv : public tensorflow::EnvWrapper {
30  public:
TestEnv()31   TestEnv() : EnvWrapper(Env::Default()) {}
32 
NowMicros() const33   uint64 NowMicros() const override { return current_micros_; }
34 
SetCurrentMicros(uint64 micros)35   void SetCurrentMicros(uint64 micros) { current_micros_ = micros; }
36 
37  private:
38   uint64 current_micros_ = 1;
39 };
40 
// A started pass is stamped with the env's clock value at the time it starts.
TEST(HloModuleMetadata, RecordsPassStart) {
  constexpr int kStartMicros = 1234;
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);

  fake_env.SetCurrentMicros(kStartMicros);
  metadata.RecordPassStart();

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::start_timestamp_usec,
                                   kStartMicros)));
}
50 
// Ending a pass stamps it with the env's clock value at the time it ends.
TEST(HloModuleMetadata, RecordsPassEnd) {
  constexpr int kEndMicros = 4321;
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);
  metadata.RecordPassStart();

  // Advance the fake clock, then close the pass that was just opened.
  fake_env.SetCurrentMicros(kEndMicros);
  EXPECT_IS_OK(metadata.RecordPassEnd());

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::end_timestamp_usec,
                                   kEndMicros)));
}
61 
// With nested passes, each RecordPassEnd() closes the most recently started
// pass that is still running, leaving earlier entries untouched.
TEST(HloModuleMetadata, RecordsPassEndInNestedMetadata) {
  constexpr int kInnerEndMicros = 111;
  constexpr int kOuterEndMicros = 222;
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);

  // Outer pass, then an inner pass started while the outer one is running.
  metadata.RecordPassStart();
  metadata.RecordPassStart();

  // First end closes the inner (second) entry; the outer one still reads 0.
  fake_env.SetCurrentMicros(kInnerEndMicros);
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_THAT(
      metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 0),
                  Property(&HloPassMetadata::end_timestamp_usec,
                           kInnerEndMicros)));

  // Second end closes the outer (first) entry; the inner one is unchanged.
  fake_env.SetCurrentMicros(kOuterEndMicros);
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_THAT(
      metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::end_timestamp_usec,
                           kOuterEndMicros),
                  Property(&HloPassMetadata::end_timestamp_usec,
                           kInnerEndMicros)));
}
79 
// RecordPassEnd() reports NOT_FOUND whenever no pass is currently running:
// both before any pass was started and after the only started pass ended.
TEST(HloModuleMetadata, RecordPassEndReturnsNotFound) {
  HloModuleMetadata metadata(tensorflow::Env::Default());
  auto end_pass_code = [&metadata] { return metadata.RecordPassEnd().code(); };

  // Nothing has been started yet.
  EXPECT_EQ(end_pass_code(), tensorflow::error::NOT_FOUND);

  // One start/end pair succeeds; a further end has nothing left to close.
  metadata.RecordPassStart();
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_EQ(end_pass_code(), tensorflow::error::NOT_FOUND);
}
90 
// set_current_pass_name() writes the name into the running pass's metadata.
TEST(HloModuleMetadata, SetsHloPassMetadataFields) {
  constexpr char kFakeName[] = "fake name";
  HloModuleMetadata metadata(tensorflow::Env::Default());
  metadata.RecordPassStart();

  EXPECT_IS_OK(metadata.set_current_pass_name(kFakeName));

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::pass_name,
                                   StrEq(kFakeName))));
}
99 
// With nested passes, the setter updates only the most recently started pass.
TEST(HloModuleMetadata, SetsHloPassMetadataFieldsInNestedMetadata) {
  constexpr char kFakeName[] = "fake name";
  HloModuleMetadata metadata(tensorflow::Env::Default());

  // Outer pass, then an inner pass started while the outer one is running.
  metadata.RecordPassStart();
  metadata.RecordPassStart();

  EXPECT_IS_OK(metadata.set_current_pass_name(kFakeName));

  // Outer entry keeps an empty name; the inner entry receives the new one.
  EXPECT_THAT(
      metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("")),
                  Property(&HloPassMetadata::pass_name, StrEq(kFakeName))));
}
110 
// With no pass started there is nothing for the setter to update, so it
// reports NOT_FOUND.
TEST(HloModuleMetadata, SetterReturnsNotFound) {
  HloModuleMetadata metadata(tensorflow::Env::Default());
  const auto status_code = metadata.set_current_pass_name("fake name").code();
  EXPECT_EQ(status_code, tensorflow::error::NOT_FOUND);
}
116 
// set_prepartitioning_metadata() splits the old module's passes: passes still
// running are carried into the new module's own metadata, while passes that
// already finished are stored as the new module's prepartitioning metadata.
TEST(HloModuleMetadata, CopiesRunningPrepartitioningPasses) {
  HloModuleMetadata old_module_metadata(tensorflow::Env::Default());

  // "outer pass" is started and left running.
  old_module_metadata.RecordPassStart();
  EXPECT_IS_OK(old_module_metadata.set_current_pass_name("outer pass"));

  // "finished pass" is started inside "outer pass" and ended.
  old_module_metadata.RecordPassStart();
  EXPECT_IS_OK(old_module_metadata.set_current_pass_name("finished pass"));
  EXPECT_IS_OK(old_module_metadata.RecordPassEnd());

  // "inner pass" is started inside "outer pass" and left running.
  old_module_metadata.RecordPassStart();
  EXPECT_IS_OK(old_module_metadata.set_current_pass_name("inner pass"));

  HloModuleMetadata new_module_metadata(tensorflow::Env::Default());
  new_module_metadata.set_prepartitioning_metadata(old_module_metadata);

  // Passes that are still running go in the new module.
  EXPECT_THAT(
      new_module_metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("outer pass")),
                  Property(&HloPassMetadata::pass_name, StrEq("inner pass"))));

  // Fail cleanly (rather than dereferencing an empty value) if the
  // prepartitioning metadata was not populated.
  ASSERT_TRUE(new_module_metadata.prepartitioning_metadata());

  // Passes that finished go in the prepartitioning metadata.
  EXPECT_THAT(new_module_metadata.prepartitioning_metadata()->pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::pass_name,
                                   StrEq("finished pass"))));
}
143 
144 }  // namespace
145 }  // namespace xla
146