1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
17
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/test_helpers.h"
20 #include "tensorflow/stream_executor/lib/statusor.h"
21
22 namespace xla {
23 namespace {
24
25 using ::testing::ElementsAre;
26 using ::testing::Property;
27 using ::testing::StrEq;
28
29 class TestEnv : public tensorflow::EnvWrapper {
30 public:
TestEnv()31 TestEnv() : EnvWrapper(Env::Default()) {}
32
NowMicros() const33 uint64 NowMicros() const override { return current_micros_; }
34
SetCurrentMicros(uint64 micros)35 void SetCurrentMicros(uint64 micros) { current_micros_ = micros; }
36
37 private:
38 uint64 current_micros_ = 1;
39 };
40
// RecordPassStart() appends a single pass entry whose start timestamp is the
// env clock reading at the time of the call.
TEST(HloModuleMetadata, RecordsPassStart) {
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);

  fake_env.SetCurrentMicros(1234);
  metadata.RecordPassStart();

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::start_timestamp_usec,
                                   1234)));
}
50
// RecordPassEnd() stamps the open pass with the env clock reading at the time
// of the call and reports success.
TEST(HloModuleMetadata, RecordsPassEnd) {
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);
  metadata.RecordPassStart();

  fake_env.SetCurrentMicros(4321);
  EXPECT_IS_OK(metadata.RecordPassEnd());

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::end_timestamp_usec,
                                   4321)));
}
61
// With two passes open, RecordPassEnd() closes the most recently started one
// first; a second call then closes the outer pass.
TEST(HloModuleMetadata, RecordsPassEndInNestedMetadata) {
  TestEnv clock;
  HloModuleMetadata metadata(&clock);
  metadata.RecordPassStart();  // Outer pass.
  metadata.RecordPassStart();  // Inner pass.

  // Ending the inner pass leaves the outer entry's end timestamp at 0.
  clock.SetCurrentMicros(111);
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_THAT(
      metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 0),
                  Property(&HloPassMetadata::end_timestamp_usec, 111)));

  // Ending again stamps the outer entry; the inner entry keeps its value.
  clock.SetCurrentMicros(222);
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_THAT(
      metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 222),
                  Property(&HloPassMetadata::end_timestamp_usec, 111)));
}
79
// RecordPassEnd() fails with NOT_FOUND whenever no pass is currently open:
// both before any pass starts and after the only open pass has ended.
TEST(HloModuleMetadata, RecordPassEndReturnsNotFound) {
  tensorflow::Env* const default_env = tensorflow::Env::Default();
  HloModuleMetadata metadata(default_env);

  const auto end_before_any_start = metadata.RecordPassEnd();
  EXPECT_EQ(end_before_any_start.code(), tensorflow::error::NOT_FOUND);

  metadata.RecordPassStart();
  EXPECT_IS_OK(metadata.RecordPassEnd());

  const auto end_after_last_closed = metadata.RecordPassEnd();
  EXPECT_EQ(end_after_last_closed.code(), tensorflow::error::NOT_FOUND);
}
90
// set_current_pass_name() writes the name into the currently open pass entry.
TEST(HloModuleMetadata, SetsHloPassMetadataFields) {
  HloModuleMetadata metadata(tensorflow::Env::Default());
  metadata.RecordPassStart();

  EXPECT_IS_OK(metadata.set_current_pass_name("fake name"));

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::pass_name,
                                   StrEq("fake name"))));
}
99
// With nested passes open, set_current_pass_name() targets only the innermost
// (most recently started) pass; the outer entry's name stays empty.
TEST(HloModuleMetadata, SetsHloPassMetadataFieldsInNestedMetadata) {
  HloModuleMetadata metadata(tensorflow::Env::Default());
  metadata.RecordPassStart();  // Outer pass.
  metadata.RecordPassStart();  // Inner pass.

  EXPECT_IS_OK(metadata.set_current_pass_name("fake name"));

  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(
                  Property(&HloPassMetadata::pass_name, StrEq("")),
                  Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
}
110
// Setting a pass field with no open pass is reported as NOT_FOUND.
TEST(HloModuleMetadata, SetterReturnsNotFound) {
  HloModuleMetadata metadata(tensorflow::Env::Default());

  const auto set_status = metadata.set_current_pass_name("fake name");
  EXPECT_EQ(set_status.code(), tensorflow::error::NOT_FOUND);
}
116
// set_prepartitioning_metadata() splits the source module's pass history:
// passes that are still running are carried over into the new module's own
// pass list, while passes that already finished are kept in the separate
// prepartitioning metadata.
TEST(HloModuleMetadata, CopiesRunningPrepartitioningPasses) {
  HloModuleMetadata old_module_metadata(tensorflow::Env::Default());

  // Outer pass: left running.
  old_module_metadata.RecordPassStart();
  EXPECT_IS_OK(old_module_metadata.set_current_pass_name("outer pass"));

  // Nested pass that starts and finishes before the copy.
  old_module_metadata.RecordPassStart();
  EXPECT_IS_OK(old_module_metadata.set_current_pass_name("finished pass"));
  EXPECT_IS_OK(old_module_metadata.RecordPassEnd());

  // Another nested pass: left running.
  old_module_metadata.RecordPassStart();
  EXPECT_IS_OK(old_module_metadata.set_current_pass_name("inner pass"));

  HloModuleMetadata new_module_metadata(tensorflow::Env::Default());
  new_module_metadata.set_prepartitioning_metadata(old_module_metadata);

  // Passes that are still running go in the new module.
  EXPECT_THAT(
      new_module_metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("outer pass")),
                  Property(&HloPassMetadata::pass_name, StrEq("inner pass"))));

  // Passes that finished go in the prepartitioning metadata. It is
  // dereferenced below, so stop here if it was never populated: a regression
  // then fails the test cleanly instead of dereferencing an empty value.
  ASSERT_TRUE(new_module_metadata.prepartitioning_metadata());
  EXPECT_THAT(new_module_metadata.prepartitioning_metadata()->pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::pass_name,
                                   StrEq("finished pass"))));
}
143
144 } // namespace
145 } // namespace xla
146