1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Test for parse_flags_from_env.cc
17 
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
33 
34 namespace xla {
35 
36 // Test that XLA flags can be set from the environment.
37 // Failure messages are accompanied by the text in msg[].
TestParseFlagsFromEnv(const char * msg)38 static void TestParseFlagsFromEnv(const char* msg) {
39   // Initialize module under test.
40   int* pargc;
41   std::vector<char*>* pargv;
42   ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv);
43 
44   // Check that actual flags can be parsed.
45   bool simple = false;
46   string with_value;
47   string embedded_quotes;
48   string single_quoted;
49   string double_quoted;
50   std::vector<tensorflow::Flag> flag_list = {
51       tensorflow::Flag("simple", &simple, ""),
52       tensorflow::Flag("with_value", &with_value, ""),
53       tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
54       tensorflow::Flag("single_quoted", &single_quoted, ""),
55       tensorflow::Flag("double_quoted", &double_quoted, ""),
56   };
57   bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
58   CHECK_EQ(*pargc, 1) << msg;
59   const std::vector<char*>& argv_second = *pargv;
60   CHECK_NE(argv_second[0], nullptr) << msg;
61   CHECK_EQ(argv_second[1], nullptr) << msg;
62   CHECK(parsed_ok) << msg;
63   CHECK(simple) << msg;
64   CHECK_EQ(with_value, "a_value") << msg;
65   CHECK_EQ(embedded_quotes, "single'double\"") << msg;
66   CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
67   CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
68 }
69 
70 // The flags settings to test.
71 static const char kTestFlagString[] =
72     "--simple "
73     "--with_value=a_value "
74     "--embedded_quotes=single'double\" "
75     "--single_quoted='single quoted \\\\ \n \"' "
76     "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
77 
78 // Test that the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,Basic)79 TEST(ParseFlagsFromEnv, Basic) {
80   // Prepare environment.
81   tensorflow::setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
82   TestParseFlagsFromEnv("(flags in environment variable)");
83 }
84 
85 // Test that a file named by the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,File)86 TEST(ParseFlagsFromEnv, File) {
87   // environment variables where  tmp dir may be specified.
88   static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
89   static const char kTempDir[] = "/tmp";  // default temp dir if all else fails.
90   const char* tmp_dir = nullptr;
91   for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
92     tmp_dir = getenv(kTempVars[i]);
93   }
94   if (tmp_dir == nullptr) {
95     tmp_dir = kTempDir;
96   }
97   string tmp_file =
98       absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
99   FILE* fp = fopen(tmp_file.c_str(), "w");
100   CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
101   for (int i = 0; kTestFlagString[i] != '\0'; i++) {
102     putc(kTestFlagString[i], fp);
103   }
104   fflush(fp);
105   CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
106   fclose(fp);
107   // Prepare environment.
108   tensorflow::setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
109   TestParseFlagsFromEnv("(flags in file)");
110   unlink(tmp_file.c_str());
111 }
112 
113 // Name of the test binary.
114 static const char* binary_name;
115 
116 // Test that when we use both the environment variable and actual
117 // commend line flags (when the latter is possible), the latter win.
TEST(ParseFlagsFromEnv,EnvAndFlag)118 TEST(ParseFlagsFromEnv, EnvAndFlag) {
119   static struct {
120     const char* env;
121     const char* arg;
122     const char* expected_value;
123   } test[] = {
124       {nullptr, nullptr, "1\n"},
125       {nullptr, "--int_flag=2", "2\n"},
126       {"--int_flag=3", nullptr, "3\n"},
127       {"--int_flag=3", "--int_flag=2", "2\n"},  // flag beats environment
128   };
129   for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
130     if (test[i].env == nullptr) {
131       // Might be set from previous tests.
132       tensorflow::unsetenv("TF_XLA_FLAGS");
133     } else {
134       tensorflow::setenv("TF_XLA_FLAGS", test[i].env, /*overwrite=*/true);
135     }
136     tensorflow::SubProcess child;
137     std::vector<string> argv;
138     argv.push_back(binary_name);
139     argv.push_back("--recursing");
140     if (test[i].arg != nullptr) {
141       argv.push_back(test[i].arg);
142     }
143     child.SetProgram(binary_name, argv);
144     child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
145     child.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
146     CHECK(child.Start()) << "test " << i;
147     string stdout_str;
148     string stderr_str;
149     int child_status = child.Communicate(nullptr, &stdout_str, &stderr_str);
150     CHECK_EQ(child_status, 0) << "test " << i << "\nstdout\n"
151                               << stdout_str << "\nstderr\n"
152                               << stderr_str;
153     // On windows, we get CR characters. Remove them.
154     stdout_str.erase(std::remove(stdout_str.begin(), stdout_str.end(), '\r'),
155                      stdout_str.end());
156     CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
157   }
158 }
159 
160 }  // namespace xla
161 
main(int argc,char * argv[])162 int main(int argc, char* argv[]) {
163   // Save name of binary so that it may invoke itself.
164   xla::binary_name = argv[0];
165   bool recursing = false;
166   xla::int32 int_flag = 1;
167   const std::vector<tensorflow::Flag> flag_list = {
168       tensorflow::Flag("recursing", &recursing,
169                        "Whether the binary is being invoked recursively."),
170       tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
171   };
172   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
173   bool parse_ok =
174       xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
175   if (!parse_ok) {
176     LOG(QFATAL) << "can't parse from environment\n" << usage;
177   }
178   parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
179   if (!parse_ok) {
180     LOG(QFATAL) << usage;
181   }
182   if (recursing) {
183     printf("%d\n", int_flag);
184     exit(0);
185   }
186   testing::InitGoogleTest(&argc, argv);
187   return RUN_ALL_TESTS();
188 }
189