1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file should not have any dependencies apart from the standard library,
17 // as it will be used in OSS outside of this repository.
18
19 #include <algorithm>
20 #include <cstddef>
21 #include <cstdlib>
22 #include <cstring>
23 #include <fstream>
24 #include <iostream>
25 #include <iterator>
26 #include <map>
27 #include <random>
28 #include <regex> // NOLINT
29 #include <sstream>
30 #include <string>
31 #include <type_traits>
32 #include <vector>
33
// Seed used for every random buffer fill, so repeated runs see identical
// inputs.
static constexpr int kSeed = 42;
// Inclusive range used when generating integer inputs.
static constexpr int kUpperBound = 100;
static constexpr int kLowerBound = -100;
// Range used when generating floating-point inputs (sampled by
// std::uniform_real_distribution, i.e. [kLowerBoundFP, kUpperBoundFP)).
static constexpr double kLowerBoundFP = -0.1;
static constexpr double kUpperBoundFP = 0.1;
// Text printed for --help. The driver prints its results to stdout (see
// Display), while logging goes to stderr.
static const char* const kUsageString = R"(
Driver for executing an HLO reproducer in object form in order to let OSS
users reproduce the miscompiles.

Expected workflow:

1) In the .hlo file, rename the root computation to `EntryModule`.
2) Run the .hlo file with XLA_FLAGS=--xla_dump_to set, to obtain the .ll file.
3) Compile the .ll file from step (2) into an object file, then compile and
link this file with that object file.
4) Run the resulting binary with the buffer assignment table as an argument,
taken from step (2). The driver will print the output to stdout.
5) Compare the outputs produced from the optimized and non-optimized .ll files
from step (2). If the outputs differ, there is a miscompile.

Run with an environment variable VERBOSE set to see logging.
)";
55
// Entry point of the compiled HLO module. It is only declared here; the
// definition comes from the object file this driver is linked with.
extern "C" void EntryModule(char* result_buffer, char* run_opts, char** params,
                            char** buffer_table, int* prof_counters);
61
62 namespace {
63
// Writes `msg` to stderr and terminates the process with exit status 1.
[[noreturn]] void ExitWithMsg(const std::string& msg) {
  std::cerr << msg << '\n';
  std::exit(1);
}
68
Check(bool cond,const std::string & msg="Precondition failed")69 void Check(bool cond, const std::string& msg = "Precondition failed") {
70 if (!cond) {
71 ExitWithMsg(msg);
72 }
73 }
74
// Verbose logging is on whenever the VERBOSE environment variable is set,
// regardless of its value.
bool IsVerbose() {
  const char* const verbose_env = std::getenv("VERBOSE");
  return verbose_env != nullptr;
}
76
Log(const std::string & msg)77 void Log(const std::string& msg) {
78 if (IsVerbose()) {
79 std::cerr << msg << std::endl;
80 }
81 }
82
// Element types understood by this driver. Needs to be kept in sync with
// PrimitiveType in xla_data.proto. The enumerator values double as indices
// into primitive_strings(), so the order here must match that table.
enum PrimitiveType {
  S16 = 0,
  S32 = 1,
  S64 = 2,
  U8 = 3,
  U16 = 4,
  U32 = 5,
  U64 = 6,
  F16 = 7,
  BF16 = 8,
  F32 = 9,
  F64 = 10,
  C64 = 11,
  C128 = 12
};
99
// Lowercase names of the PrimitiveType enumerators, indexed by enumerator
// value. The table is heap-allocated on first use and never freed.
const std::vector<std::string>& primitive_strings() {
  static const std::vector<std::string>* const kNames =
      new std::vector<std::string>{"s16", "s32", "s64", "u8",   "u16",
                                   "u32", "u64", "f16", "bf16", "f32",
                                   "f64", "c64", "c128"};
  return *kNames;
}
106
ToString(PrimitiveType type)107 std::string ToString(PrimitiveType type) { return primitive_strings()[type]; }
108
PrimitiveTypeFromString(const std::string & s)109 PrimitiveType PrimitiveTypeFromString(const std::string& s) {
110 const auto& vec = primitive_strings();
111 return static_cast<PrimitiveType>(
112 std::distance(vec.begin(), std::find(vec.begin(), vec.end(), s)));
113 }
114
ByteSize(PrimitiveType type)115 int ByteSize(PrimitiveType type) {
116 std::string s = ToString(type);
117 s = s.substr(1, s.size());
118 return std::stoi(s) / 8;
119 }
120
121 struct ArrayShape {
122 PrimitiveType type;
123 std::vector<int> dimensions;
124 };
125
126 // We support tuples only for output, and we do not support nested tuples.
127 struct TupleShape {
128 std::vector<ArrayShape> elements;
129 };
130
ArrayShapeToString(ArrayShape shape)131 std::string ArrayShapeToString(ArrayShape shape) {
132 std::ostringstream out;
133 out << ToString(shape.type) << "[";
134 for (int i = 0; i < shape.dimensions.size(); i++) {
135 out << std::to_string(shape.dimensions[i]);
136 if (i != shape.dimensions.size() - 1) {
137 out << ",";
138 }
139 }
140 out << "]";
141 return out.str();
142 }
143
144 // Input: TYPE[D1,D2,...DN]
ArrayShapeFromString(const std::string & s)145 ArrayShape ArrayShapeFromString(const std::string& s) {
146 Log("Array shape from string: " + s);
147 Check(s.find('(') == std::string::npos, "Tuple shape is not supported");
148 std::regex shape_r("([^\\[]+)\\[(.*)\\]");
149 std::smatch match;
150 Check(std::regex_match(s, match, shape_r), "Shape not found");
151 std::string type = match[1];
152 std::string dims = match[2];
153 PrimitiveType ptype = PrimitiveTypeFromString(type);
154 std::istringstream dims_stream(dims);
155 std::string dim;
156 std::vector<int> dimensions;
157 while (std::getline(dims_stream, dim, ',')) {
158 dimensions.push_back(std::stoi(dim));
159 }
160 return {ptype, dimensions};
161 }
162
163 // E.g. (f32[10,20], u32[])
TupleShapeFromString(std::string s)164 TupleShape TupleShapeFromString(std::string s) {
165 Log("Tuple shape from string: " + s);
166 if (s[0] != '(') {
167 return {{ArrayShapeFromString(s)}};
168 }
169 s = s.substr(1, s.size() - 2);
170 std::istringstream sstream(s);
171 std::string subshape;
172 std::vector<ArrayShape> out;
173 while (std::getline(sstream, subshape, ' ')) {
174 if (subshape[subshape.size() - 1] == ',') {
175 subshape = subshape.substr(0, subshape.size() - 1);
176 }
177 out.push_back(ArrayShapeFromString(subshape));
178 }
179 return {out};
180 }
181
TupleShapeToString(TupleShape shape)182 std::string TupleShapeToString(TupleShape shape) {
183 std::ostringstream out;
184 if (shape.elements.size() == 1) {
185 return ArrayShapeToString(shape.elements[0]);
186 }
187 out << "(";
188 for (int idx = 0; idx < shape.elements.size(); idx++) {
189 out << ArrayShapeToString(shape.elements[idx]);
190 if (idx != shape.elements.size() - 1) {
191 out << ", ";
192 }
193 }
194 out << ")";
195 return out.str();
196 }
197
198 // Information about the buffer assignment.
199 struct BufferAssignment {
200 // Mapping from allocation index to buffer size (in bytes).
201 std::vector<int> buffers_size;
202
203 // Mapping from allocation index to its shape.
204 std::map<int, TupleShape> buffers_shape;
205
206 // Mapping from param index to allocation index.
207 std::map<int, int> param_to_alloc_idx;
208
209 // Index of the output parameter.
210 int output_idx = -1;
211 };
212
BufferAssignmentToString(const BufferAssignment & assignment)213 std::string BufferAssignmentToString(const BufferAssignment& assignment) {
214 std::ostringstream out;
215 for (const auto& p : assignment.param_to_alloc_idx) {
216 int param_idx = p.first;
217 int allocation_idx = p.second;
218 out << "Param: " << param_idx << " (allocation " << allocation_idx << "): ";
219 auto p2 = assignment.buffers_shape.find(allocation_idx);
220 Check(p2 != assignment.buffers_shape.end(),
221 "Shape not found for parameter: " + std::to_string(param_idx));
222 out << TupleShapeToString(p2->second)
223 << ", size = " << assignment.buffers_size[allocation_idx] << "\n";
224 }
225 return out.str();
226 }
227
228 // RAII table for the given assignment: mapping from a allocation idx to the
229 // actual allocation.
230 class BufferTable {
231 public:
BufferTable(BufferAssignment assignment)232 explicit BufferTable(BufferAssignment assignment) : assignment_(assignment) {
233 int num_buffers = assignment.buffers_size.size();
234 ptr_ = new char*[num_buffers];
235 for (int buffer_idx = 0; buffer_idx < num_buffers; buffer_idx++) {
236 // Call malloc to ensure alignment up to std::max_align_t.
237 ptr_[buffer_idx] =
238 static_cast<char*>(malloc(assignment.buffers_size[buffer_idx]));
239 }
240 }
241
AsPtr()242 char** AsPtr() { return ptr_; }
243
~BufferTable()244 ~BufferTable() {
245 int num_buffers = assignment_.buffers_size.size();
246 for (int buffer_idx = 0; buffer_idx < num_buffers; buffer_idx++) {
247 free(ptr_[buffer_idx]);
248 }
249 delete[] ptr_;
250 }
251
252 private:
253 BufferAssignment assignment_;
254 char** ptr_;
255 };
256
257 // Parse and populate the buffer table;
258 //
259 // Example of input:
260 //
261 // BufferAssignment:
262 // allocation 0: 0x27017c46b600, size 32768, parameter 0, shape f32[256,32] at
263 // ShapeIndex {}:
264 // value: <3 parameter @0> (size=32768,offset=0): f32[256,32]{1,0}
265 // allocation 1: 0x27017c46b6b0, size 128, output shape is f32[32],
266 // maybe-live-out:
267 // value: <5 reduce @0> (size=128,offset=0): f32[32]{0}
268 // allocation 2: 0x27017c46b760, size 4, constant:
269 // value: <4 init_value @0> (size=4,offset=0): f32[]
270 // allocation 3: 0x27017c46b810, size 4, thread-local:
271 // value: <0 x.1 @0> (size=4,offset=0): f32[]
272 // allocation 4: 0x27017c46b8c0, size 4, thread-local:
273 // value: <1 y.1 @0> (size=4,offset=0): f32[]
274 // allocation 5: 0x27017c46b970, size 4, output shape is f32[], thread-local:
275 // value: <2 add.1 @0> (size=4,offset=0): f32[]
ParseBufferAssignment(const std::string & fname)276 BufferAssignment ParseBufferAssignment(const std::string& fname) {
277 BufferAssignment assignment;
278 std::ifstream infile(fname);
279 std::string line;
280 while (std::getline(infile, line)) {
281 std::regex allocation_line_r(
282 "allocation ([0-9]+): .+, size ([0-9]+), (.+)");
283 std::smatch match;
284 if (std::regex_search(line, match, allocation_line_r)) {
285 Log("Matched allocation description: " + line);
286 int allocation_idx = std::stoi(match[1]);
287 int size = std::stoi(match[2]);
288 Log("Allocation size = " + std::to_string(size));
289 const std::string& postfix = match[3];
290 Check(allocation_idx == assignment.buffers_size.size(),
291 "Unordered allocations in input");
292 assignment.buffers_size.push_back(size);
293
294 std::regex output_r("output shape is \\|([^\\|]+)\\|,");
295 std::smatch output_match;
296 if (std::regex_search(postfix, output_match, output_r)) {
297 Log("Matched out parameter: " + postfix);
298 Check(assignment.output_idx == -1, "Multiple out-parameters");
299 assignment.output_idx = allocation_idx;
300 std::string output_shape = output_match[1];
301 Log("output shape = " + output_shape);
302 TupleShape shape = TupleShapeFromString(output_shape);
303 assignment.buffers_shape[allocation_idx] = shape;
304 Log("parsed output shape = " + TupleShapeToString(shape));
305 }
306
307 std::regex parameter_r("parameter ([0-9]+), shape \\|([^\\|]+)\\|");
308 std::smatch param_match;
309 if (std::regex_search(postfix, param_match, parameter_r)) {
310 Log("Matched parameter description: " + postfix);
311 int param_idx = std::stoi(param_match[1]);
312 assignment.param_to_alloc_idx[param_idx] = allocation_idx;
313 std::string param_shape = param_match[2];
314 TupleShape shape = TupleShapeFromString(param_shape);
315 assignment.buffers_shape[allocation_idx] = shape;
316 Log("parsed parameter shape for param " + std::to_string(param_idx) +
317 " = " + TupleShapeToString(shape));
318 }
319 }
320 }
321 Check(assignment.output_idx != -1, "Output not set");
322 return assignment;
323 }
324
GetNumElements(const ArrayShape & shape)325 int GetNumElements(const ArrayShape& shape) {
326 int num_elements = 1;
327 for (int dim : shape.dimensions) {
328 num_elements *= dim;
329 }
330 return num_elements;
331 }
332
333 template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
FillIntT(void * buffer,int num_elements)334 void FillIntT(void* buffer, int num_elements) {
335 std::mt19937 generator(kSeed);
336 T* casted = static_cast<T*>(buffer);
337 std::uniform_int_distribution<T> distr(kLowerBound, kUpperBound);
338 for (int i = 0; i < num_elements; i++) {
339 casted[i] = distr(generator);
340 }
341 }
342
343 template <typename T,
344 typename = std::enable_if_t<std::is_floating_point<T>::value>>
FillFloatT(void * buffer,int num_elements)345 void FillFloatT(void* buffer, int num_elements) {
346 std::mt19937 generator(kSeed);
347 T* casted = static_cast<T*>(buffer);
348 std::uniform_real_distribution<T> distr(kLowerBoundFP, kUpperBoundFP);
349 for (int i = 0; i < num_elements; i++) {
350 casted[i] = distr(generator);
351 }
352 }
353
Fill(void * buffer,const ArrayShape & shape)354 void Fill(void* buffer, const ArrayShape& shape) {
355 int num_elements = GetNumElements(shape);
356 Log("Number of elements = " + std::to_string(num_elements));
357 Log("Shape type = " + ToString(shape.type) +
358 ", shape = " + ArrayShapeToString(shape));
359 switch (shape.type) {
360 case S16:
361 return FillIntT<short>(buffer, num_elements); // NOLINT
362 case S32:
363 return FillIntT<int>(buffer, num_elements);
364 case S64:
365 return FillIntT<long long>(buffer, num_elements); // NOLINT
366 case U8:
367 return FillIntT<unsigned char>(buffer, num_elements);
368 case U16:
369 return FillIntT<unsigned short>(buffer, num_elements); // NOLINT
370 case U32:
371 return FillIntT<unsigned int>(buffer, num_elements);
372 case U64:
373 return FillIntT<unsigned long long>(buffer, num_elements); // NOLINT
374 case F32:
375 return FillFloatT<float>(buffer, num_elements);
376 case F64:
377 return FillFloatT<double>(buffer, num_elements);
378
379 case F16:
380 case BF16:
381 case C64:
382 case C128:
383 ExitWithMsg("Unsupported type: " + ToString(shape.type));
384 }
385 }
386
// Prints `num_elements` values of type T read from `buffer` to stdout, as
// one comma-separated line.
// NOTE(review): MSan is presumably disabled because the buffers are written
// by uninstrumented object code (EntryModule). Confirm before removing it.
template <typename T>
#if defined(MEMORY_SANITIZER)
__attribute__((no_sanitize_memory))
#endif
void DisplayT(const void* buffer, int num_elements) {
  const T* casted = static_cast<const T*>(buffer);
  for (int i = 0; i < num_elements; i++) {
    // Unary plus promotes character types (unsigned char for U8) to int, so
    // they print as numbers instead of raw characters. Other arithmetic
    // types print unchanged.
    std::cout << +casted[i];
    if (i != num_elements - 1) {
      std::cout << ", ";
    }
  }
  std::cout << std::endl;
}
401
Display(const void * buffer,const ArrayShape & shape)402 void Display(const void* buffer, const ArrayShape& shape) {
403 int num_elements = GetNumElements(shape);
404 switch (shape.type) {
405 case S16:
406 return DisplayT<short>(buffer, num_elements); // NOLINT
407 case S32:
408 return DisplayT<int>(buffer, num_elements);
409 case S64:
410 return DisplayT<long long>(buffer, num_elements); // NOLINT
411 case U8:
412 return DisplayT<unsigned char>(buffer, num_elements);
413 case U16:
414 return DisplayT<unsigned short>(buffer, num_elements); // NOLINT
415 case U32:
416 return DisplayT<unsigned int>(buffer, num_elements);
417 case U64:
418 return DisplayT<unsigned long long>(buffer, num_elements); // NOLINT
419 case F32:
420 return DisplayT<float>(buffer, num_elements);
421 case F64:
422 return DisplayT<double>(buffer, num_elements);
423
424 case F16:
425 case BF16:
426 case C64:
427 case C128:
428 ExitWithMsg("Unsupported type: " + ToString(shape.type));
429 }
430 }
431
Display(const void * buffer,const TupleShape & shape)432 void Display(const void* buffer, const TupleShape& shape) {
433 if (shape.elements.size() == 1) {
434 return Display(buffer, shape.elements[0]);
435 }
436 std::cout << "(" << std::endl;
437 auto casted = static_cast<const void* const*>(buffer);
438 for (int tuple_idx = 0; tuple_idx < shape.elements.size(); tuple_idx++) {
439 ArrayShape array_shape = shape.elements[tuple_idx];
440 Display(casted[tuple_idx], array_shape);
441 if (tuple_idx != shape.elements.size() - 1) {
442 std::cout << ", " << std::endl;
443 }
444 }
445 std::cout << ")" << std::endl;
446 }
447
448 } // end namespace
449
main(int argc,char ** argv)450 int main(int argc, char** argv) {
451 if (argc < 2) {
452 ExitWithMsg(
453 "Please provide buffer table filename as an argument, "
454 "or invoke with --help for usage instructions.");
455 }
456 std::string arg = argv[1];
457 if (arg == "--help") {
458 std::cout << kUsageString << std::endl;
459 return 0;
460 }
461
462 BufferAssignment assignment = ParseBufferAssignment(arg);
463 Log("Buffer assignment: \n" + BufferAssignmentToString(assignment));
464 BufferTable table(assignment);
465
466 // Fill out input parameters.
467 for (const auto& p : assignment.param_to_alloc_idx) {
468 int param_idx = p.first;
469 int allocation_idx = p.second;
470 TupleShape tuple_shape = assignment.buffers_shape[allocation_idx];
471 Check(tuple_shape.elements.size() == 1,
472 "Parameters can not be tuples, got shape: " +
473 TupleShapeToString(tuple_shape));
474 ArrayShape shape = tuple_shape.elements[0];
475 Check(GetNumElements(shape) ==
476 assignment.buffers_size[allocation_idx] / ByteSize(shape.type),
477 "Unexpected number of elements");
478 Fill(table.AsPtr()[allocation_idx], shape);
479
480 if (IsVerbose()) {
481 std::cout << "Filled parameter buffer for param " << param_idx << ": "
482 << std::endl;
483 Display(table.AsPtr()[allocation_idx], shape);
484 }
485 }
486
487 Log("Launching module");
488 EntryModule(/*result_buffer=*/nullptr,
489 /*run_opts=*/nullptr,
490 /*params=*/nullptr, table.AsPtr(),
491 /*prof_counters=*/nullptr);
492
493 std::cout << "Output:" << std::endl;
494 Log("Output shape: " +
495 TupleShapeToString(assignment.buffers_shape[assignment.output_idx]));
496 Display(table.AsPtr()[assignment.output_idx],
497 assignment.buffers_shape[assignment.output_idx]);
498 }
499