1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Test for parse_flags_from_env.cc
17
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
33
34 namespace xla {
35
36 // Test that XLA flags can be set from the environment.
37 // Failure messages are accompanied by the text in msg[].
TestParseFlagsFromEnv(const char * msg)38 static void TestParseFlagsFromEnv(const char* msg) {
39 // Initialize module under test.
40 int* pargc;
41 std::vector<char*>* pargv;
42 ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv);
43
44 // Check that actual flags can be parsed.
45 bool simple = false;
46 string with_value;
47 string embedded_quotes;
48 string single_quoted;
49 string double_quoted;
50 std::vector<tensorflow::Flag> flag_list = {
51 tensorflow::Flag("simple", &simple, ""),
52 tensorflow::Flag("with_value", &with_value, ""),
53 tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
54 tensorflow::Flag("single_quoted", &single_quoted, ""),
55 tensorflow::Flag("double_quoted", &double_quoted, ""),
56 };
57 bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
58 CHECK_EQ(*pargc, 1) << msg;
59 const std::vector<char*>& argv_second = *pargv;
60 CHECK_NE(argv_second[0], nullptr) << msg;
61 CHECK_EQ(argv_second[1], nullptr) << msg;
62 CHECK(parsed_ok) << msg;
63 CHECK(simple) << msg;
64 CHECK_EQ(with_value, "a_value") << msg;
65 CHECK_EQ(embedded_quotes, "single'double\"") << msg;
66 CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
67 CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
68 }
69
// The flags settings to test.  The same text is used both directly as the
// value of TF_XLA_FLAGS and as the contents of a flags file.  The values that
// TestParseFlagsFromEnv() expects for each flag are noted below.
static const char kTestFlagString[] =
    // Flag given without a value; expected to come out true.
    "--simple "
    // Unquoted value; expected to be "a_value".
    "--with_value=a_value "
    // Quote characters in the middle of an unquoted value; expected to be
    // kept as part of the value.
    "--embedded_quotes=single'double\" "
    // Single-quoted value; expected to keep both backslashes, the newline
    // and the double quote unchanged.
    "--single_quoted='single quoted \\\\ \n \"' "
    // Double-quoted value; expected to yield one backslash for the escaped
    // pair and a plain double quote for the escaped quote.
    "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
77
78 // Test that the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,Basic)79 TEST(ParseFlagsFromEnv, Basic) {
80 // Prepare environment.
81 tensorflow::setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
82 TestParseFlagsFromEnv("(flags in environment variable)");
83 }
84
85 // Test that a file named by the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,File)86 TEST(ParseFlagsFromEnv, File) {
87 // environment variables where tmp dir may be specified.
88 static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
89 static const char kTempDir[] = "/tmp"; // default temp dir if all else fails.
90 const char* tmp_dir = nullptr;
91 for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
92 tmp_dir = getenv(kTempVars[i]);
93 }
94 if (tmp_dir == nullptr) {
95 tmp_dir = kTempDir;
96 }
97 string tmp_file =
98 absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
99 FILE* fp = fopen(tmp_file.c_str(), "w");
100 CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
101 for (int i = 0; kTestFlagString[i] != '\0'; i++) {
102 putc(kTestFlagString[i], fp);
103 }
104 fflush(fp);
105 CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
106 fclose(fp);
107 // Prepare environment.
108 tensorflow::setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
109 TestParseFlagsFromEnv("(flags in file)");
110 unlink(tmp_file.c_str());
111 }
112
// Name of the test binary, as given in argv[0].  main() stores it here before
// the tests run so that the EnvAndFlag test can launch the binary again as a
// subprocess.
static const char* binary_name;
115
// Test that when we use both the environment variable and actual
// command line flags (when the latter is possible), the latter win.
TEST(ParseFlagsFromEnv, EnvAndFlag) {
  // Each case gives the value for TF_XLA_FLAGS (nullptr: leave it unset), an
  // optional extra command-line argument (nullptr: none), and the text the
  // child process is expected to print.
  struct Case {
    const char* env;
    const char* arg;
    const char* expected_value;
  };
  static const Case kCases[] = {
      {nullptr, nullptr, "1\n"},
      {nullptr, "--int_flag=2", "2\n"},
      {"--int_flag=3", nullptr, "3\n"},
      {"--int_flag=3", "--int_flag=2", "2\n"},  // flag beats environment
  };

  int index = 0;  // Position of the current case, reported on failure.
  for (const Case& c : kCases) {
    if (c.env != nullptr) {
      tensorflow::setenv("TF_XLA_FLAGS", c.env, /*overwrite=*/true);
    } else {
      // An earlier test may have left the variable set.
      tensorflow::unsetenv("TF_XLA_FLAGS");
    }

    // Run this binary again with --recursing, which makes main() print the
    // value of int_flag instead of running the tests.
    std::vector<string> child_args;
    child_args.push_back(binary_name);
    child_args.push_back("--recursing");
    if (c.arg != nullptr) {
      child_args.push_back(c.arg);
    }

    tensorflow::SubProcess child;
    child.SetProgram(binary_name, child_args);
    child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
    child.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
    CHECK(child.Start()) << "test " << index;

    string child_out;
    string child_err;
    const int status = child.Communicate(nullptr, &child_out, &child_err);
    CHECK_EQ(status, 0) << "test " << index << "\nstdout\n"
                        << child_out << "\nstderr\n"
                        << child_err;

    // On windows, the output contains CR characters. Strip them before
    // comparing.
    child_out.erase(std::remove(child_out.begin(), child_out.end(), '\r'),
                    child_out.end());
    CHECK_EQ(child_out, c.expected_value) << "test " << index;

    ++index;
  }
}
159
160 } // namespace xla
161
main(int argc,char * argv[])162 int main(int argc, char* argv[]) {
163 // Save name of binary so that it may invoke itself.
164 xla::binary_name = argv[0];
165 bool recursing = false;
166 xla::int32 int_flag = 1;
167 const std::vector<tensorflow::Flag> flag_list = {
168 tensorflow::Flag("recursing", &recursing,
169 "Whether the binary is being invoked recursively."),
170 tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
171 };
172 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
173 bool parse_ok =
174 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
175 if (!parse_ok) {
176 LOG(QFATAL) << "can't parse from environment\n" << usage;
177 }
178 parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
179 if (!parse_ok) {
180 LOG(QFATAL) << usage;
181 }
182 if (recursing) {
183 printf("%d\n", int_flag);
184 exit(0);
185 }
186 testing::InitGoogleTest(&argc, argv);
187 return RUN_ALL_TESTS();
188 }
189