/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <atomic>

#include "tensorflow/core/common_runtime/device/device_event_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Subclass EventMgr to access its private constructor.
class TEST_EventMgr : public EventMgr {
 public:
  TEST_EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
      : EventMgr(se, gpu_options) {}
};

class TEST_EventMgrHelper {
 public:
  explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {
    // The polling loop can interfere with the measurements made here, and
    // isn't needed since the member PollEvents() always clears the queue.
    // The tested behavior is slightly different from what may occur in
    // ordinary execution.
    StopPollingLoop();
  }

  size_t queue_size() {
    mutex_lock l(em_->mu_);
    return em_->used_events_.size();
  }

  size_t free_size() {
    mutex_lock l(em_->mu_);
    return em_->free_events_.size();
  }

  void PollEvents() {
    while (queue_size() > 0) {
      // For ordinary tensor frees, this function should synchronously harvest
      // all complete events and execute the corresponding memory frees.
      EventMgr::ToFreeVector to_free;
      {
        mutex_lock l(em_->mu_);
        em_->PollEvents(true, &to_free);
      }
      em_->FreeMemory(to_free);
    }
  }

  void StopPollingLoop() { return em_->StopPollingLoop(); }

  void StartPollingLoop() { return em_->StartPollingLoop(); }

 private:
  EventMgr* em_;
};

static std::atomic_int_fast64_t live_tensor_bytes(0);

// A TensorBuffer that counts live memory usage for testing.
class TestTensorBuffer : public TensorBuffer {
 public:
  explicit TestTensorBuffer(size_t bytes)
      : TensorBuffer(nullptr), bytes_(bytes) {
    live_tensor_bytes += bytes_;
  }
  ~TestTensorBuffer() override { live_tensor_bytes -= bytes_; }

  size_t size() const override { return bytes_; }

  // Not used in this test.
  TensorBuffer* root_buffer() override { return nullptr; }
  void FillAllocationDescription(AllocationDescription* arg) const override {}

 private:
  size_t bytes_;
};

namespace {

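// Verifies that a freshly constructed EventMgr (with its polling loop stopped
// by the test helper) starts with empty used-event and free-event queues.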
TEST(EventMgr, Empty) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, th.queue_size());
  EXPECT_EQ(0, th.free_size());
}

// Tests that WarnIfInCallback() triggers correctly.
TEST(EventMgr, WarnIfInCallback) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  bool hit = false;
  th.StartPollingLoop();
  device_event_mgr::WarnIfInCallback([&hit] { hit = true; });
  EXPECT_FALSE(hit);
  Notification note;
  em.ThenExecute(stream.get(), [&hit, &note]() {
    device_event_mgr::WarnIfInCallback([&hit, &note] {
      hit = true;
      note.Notify();
    });
  });
  note.WaitForNotification();
  EXPECT_TRUE(hit);
}
}  // namespace

// Provides access to private resources of BaseGPUDevice.
class GPUDeviceTestHelper {
 public:
  GPUDeviceTestHelper(size_t memory_limit, int pending_cap) {
    SessionOptions sops;
    device_ =
        DeviceFactory::NewDevice(DEVICE_GPU, sops, "/job:a/replica:0/task:0");
    gpu_.reset(reinterpret_cast<BaseGPUDevice*>(device_.release()));
    gpu_allocator_ = GPUProcessState::singleton()->GetGPUAllocator(
        GPUOptions(), TfGpuId(0), memory_limit, /*peer_gpu_ids=*/{});
    host_allocator_ = GPUProcessState::singleton()->GetGpuHostAllocator(0);
  }

  BaseGPUDevice* gpu() { return gpu_.get(); }
  Allocator* gpu_allocator() { return gpu_allocator_; }
  Allocator* host_allocator() { return host_allocator_; }
  se::Stream* compute_stream() { return gpu_->stream_->compute; }
  se::Stream* h2d_stream() { return gpu_->stream_->host_to_device; }
  se::Stream* d2h_stream() { return gpu_->stream_->device_to_host; }
  se::Stream* d2d_stream() { return gpu_->stream_->device_to_device[0]; }
  EventMgr* event_mgr() { return gpu_->em_; }
  int pending_cap() { return gpu_->pending_cap_; }

 private:
  std::unique_ptr<Device> device_;
  std::unique_ptr<BaseGPUDevice> gpu_;
  Allocator* gpu_allocator_;
  Allocator* host_allocator_;
};

namespace {

// Class that can queue some GPU data transfers and simple kernels.
class EMBenchmarkHelper {
  GPUDeviceTestHelper* gpu_helper_;
  // We need one of these for each Add op in the chain.
  std::vector<std::unique_ptr<OpKernel>> add_kernels_;
  std::vector<OpKernelContext::Params*> add_params_;
  std::vector<std::unique_ptr<OpKernelContext>> add_contexts_;
  // The rest of these are one per chain.
  NodeDef add_node_def_;
  NodeDef id_node_def_;
  gtl::InlinedVector<TensorValue, 4> add_inputs_;
  std::vector<AllocatorAttributes> allocator_attrs_;
  gtl::InlinedVector<Tensor, 4> gpu_inputs_;
  gtl::InlinedVector<Tensor, 4> gpu_outputs_;
  gtl::InlinedVector<Tensor, 4> host_inputs_;
  gtl::InlinedVector<Tensor, 4> host_outputs_;

 public:
  // Length of tensors.  TODO(tucker): make this a variable parameter.
  static constexpr int kTDim = 1024;

  int num_ops() const { return add_kernels_.size(); }
  size_t tensor_size() const {
    return add_inputs_.empty() ? 0 : add_inputs_[0]->NumElements();
  }

  Tensor& host_outputs(int i) { return host_outputs_[i]; }
  Tensor& host_inputs(int i) { return host_inputs_[i]; }

  EMBenchmarkHelper(GPUDeviceTestHelper* h) : gpu_helper_(h) {}

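  // Rebuilds the benchmark state for a chain of num_ops Add kernels over
  // float tensors of tensor_size elements: two GPU inputs, one GPU output,
  // two host inputs filled with deterministic values, and one host output
  // initialized to -1.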
  void ReInit(int num_ops, int tensor_size) {
    gpu_inputs_.clear();
    while (gpu_inputs_.size() < 2) {
      gpu_inputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                   {tensor_size}, AllocationAttributes()));
    }
    gpu_outputs_.clear();
    while (gpu_outputs_.size() < 1) {
      gpu_outputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
    }
    host_inputs_.clear();
    while (host_inputs_.size() < 2) {
      int instance_index = host_inputs_.size();
      host_inputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_inputs_.back().flat<float>()(i) =
            i * (1.0 + (0.5 * instance_index));
      }
    }
    host_outputs_.clear();
    while (host_outputs_.size() < 1) {
      host_outputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                     {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_outputs_.back().flat<float>()(i) = -1;
      }
    }
    add_kernels_.clear();
    add_params_.clear();
    while (add_kernels_.size() < num_ops) {
      MakeAddOp();
    }
  }

  std::unique_ptr<OpKernel> GetOpKernel(const NodeDef& node_def,
                                        Status* status) {
    return CreateOpKernel("GPU", gpu_helper_->gpu(),
                          gpu_helper_->gpu_allocator(), node_def,
                          TF_GRAPH_DEF_VERSION, status);
  }

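  // Appends one Add kernel to the chain, building the shared NodeDef on the
  // first call, and prepares the corresponding OpKernelContext::Params.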
  void MakeAddOp() {
    if (add_kernels_.empty()) {
      TF_ASSERT_OK(NodeDefBuilder("add_op", "Add")
                       .Input(FakeInput(DT_FLOAT))
                       .Input(FakeInput(DT_FLOAT))
                       .Device("/job:a/replica:0/task:0/GPU:0")
                       .Finalize(&add_node_def_));
    }
    Status status;
    add_kernels_.emplace_back(GetOpKernel(add_node_def_, &status));
    TF_ASSERT_OK(status);
    add_params_.push_back(new OpKernelContext::Params);
    PrepOpKernel(add_params_.back(), add_kernels_.back().get());
  }

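  // Sets per-output allocator attributes, marking an output as host-resident
  // when the kernel declares HOST_MEMORY for that output slot.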
  void SetOutputAttrs(OpKernelContext::Params* params,
                      std::vector<AllocatorAttributes>* attrs) {
    attrs->clear();
    for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
      AllocatorAttributes attr;
      const bool on_host =
          (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
      attr.set_on_host(on_host);
      attrs->push_back(attr);
    }
    params->output_attr_array = attrs->data();
    params->forward_from_array = {};
  }

  void PrepOpKernel(OpKernelContext::Params* params, OpKernel* kernel) {
    // This mimics what happens in ExecutorState::Process to run
    // a single graph node.
    params->step_id = 1;
    params->device = gpu_helper_->gpu();
    params->log_memory = false;
    params->rendezvous = nullptr;
    params->collective_executor = nullptr;
    params->session_state = nullptr;  // ???
    params->session_handle = "session_handle";
    params->tensor_store = nullptr;
    params->cancellation_manager = nullptr;

    params->call_frame = nullptr;
    params->function_library = nullptr;
    params->runner = nullptr;
    params->graph_collector = nullptr;

    params->step_container = nullptr;
    params->slice_reader_cache = nullptr;
    params->resource_manager = gpu_helper_->gpu()->resource_manager();

    params->stats_collector = nullptr;
    params->inc_num_deferred_ops_function = nullptr;
    params->dec_num_deferred_ops_function = nullptr;

    params->op_device_context = nullptr;
    params->track_allocations = false;
    params->op_kernel = kernel;
    params->frame_iter = FrameAndIter(0, 0);
    params->is_input_dead = false;

    if (add_inputs_.empty()) {
      add_inputs_.resize(2);
      add_inputs_[0] = TensorValue(&gpu_inputs_[0]);
      add_inputs_[1] = TensorValue(&gpu_inputs_[1]);
    }
    params->inputs = &add_inputs_;
    params->input_alloc_attrs = nullptr;
    SetOutputAttrs(params, &allocator_attrs_);
  }

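  // Per-iteration timestamps in microseconds, recorded via EventMgr callbacks;
  // DisplayTimes() converts them into per-phase durations.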
  struct TimeSet {
    int iter = 0;
    int64 start = 0;
    int64 copy_done = 0;
    int64 compute_done = 0;
    int64 final_copy = 0;
    int64 all_done = 0;
  };

  // Display sampled iteration times giving the approximate breakdown
  // within iterations and the overall curve.
  void DisplayTimes(std::vector<TimeSet>* times) {
    LOG(INFO) << "Summarize set of " << times->size() << " iters";
    for (auto& ts : *times) {
      ts.final_copy = ts.all_done - ts.compute_done;
      ts.compute_done = ts.compute_done - ts.copy_done;
      ts.copy_done = ts.copy_done - ts.start;
      ts.all_done = ts.all_done - ts.start;
    }
    struct TSSort {
      bool operator()(const TimeSet& a, const TimeSet& b) {
        return a.all_done < b.all_done;
      }
    };
    std::sort(times->begin(), times->end(), TSSort());
    int64 last_time = 0;
    // Display the first entry, the last entry, and every entry that is more
    // than 5% slower than the previous one displayed.
    for (int i = 0; i < times->size(); ++i) {
      if (i == (times->size() - 1) ||
          (times->at(i).all_done >= (1.05 * last_time))) {
        LOG(INFO) << "rank " << i << " iter: " << times->at(i).iter
                  << " copy: " << times->at(i).copy_done
                  << " compute: " << times->at(i).compute_done
                  << " copy back: " << times->at(i).final_copy
                  << " sum: " << times->at(i).all_done;
        last_time = times->at(i).all_done;
      }
    }
  }

  // Queue one work unit on the GPU as follows:
  // 1. Copy 2 input tensors from CPU to GPU using h2d stream.
  // 2. Instruct compute stream to wait on h2d stream.
  // 3. Queue a sequence of Add ops on the compute stream, all using
  //    the same input tensors, allocating their own output tensors.
  // 4. Instruct d2h stream to wait on the compute stream.
  // 5. Copy final output tensor back to the CPU.
  // 6. Instruct the EventMgr to execute callback when the final tensor
  //    copy completes.
  // If event_after_add == true then additionally instruct the EventMgr
  // to execute the callback after each Add completes.
  // The optional times parameter is used for gathering detailed timing
  // data.
  void DoAddChain(int adds_per_copy, int rounds, bool event_after_add,
                  std::function<void()> callback, std::vector<TimeSet>* times) {
    // Take an extra ref on the inputs so that the add doesn't compute in place.
    Tensor alias0(gpu_inputs_[0]);
    Tensor alias1(gpu_inputs_[1]);
    for (int r = 0; r < rounds; ++r) {
      if (times) {
        times->at(r).iter = r;
        times->at(r).start = Env::Default()->NowMicros();
      }
      gpu_helper_->h2d_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      // Begin by copying the input values from CPU to GPU.
      const int64 src_bytes = host_inputs_[0].TotalBytes();
      se::DeviceMemoryBase gpu_dst_ptr0(DMAHelper::base(&gpu_inputs_[0]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr0, DMAHelper::base(&host_inputs_[0]), src_bytes);
      se::DeviceMemoryBase gpu_dst_ptr1(DMAHelper::base(&gpu_inputs_[1]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr1, DMAHelper::base(&host_inputs_[1]), src_bytes);
      gpu_helper_->compute_stream()->ThenWaitFor(gpu_helper_->h2d_stream());
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).copy_done = Env::Default()->NowMicros();
            });
      }
      std::unique_ptr<OpKernelContext> ctx;
      for (int apc = 0; apc < adds_per_copy; ++apc) {
        ctx.reset(new OpKernelContext(add_params_[apc], 1));
        gpu_helper_->gpu()->Compute(add_kernels_[apc].get(), ctx.get());
        TF_ASSERT_OK(ctx->status());
        if (event_after_add) {
          gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->compute_stream(),
                                                callback);
        }
      }
      // Finish by copying output back to CPU.
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).compute_done = Env::Default()->NowMicros();
            });
      }
      gpu_helper_->d2h_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      const int64 return_bytes = ctx->mutable_output(0)->TotalBytes();
      se::DeviceMemoryBase gpu_src_ptr(DMAHelper::base(ctx->mutable_output(0)),
                                       return_bytes);
      gpu_helper_->d2h_stream()->ThenMemcpy(DMAHelper::base(&host_outputs_[0]),
                                            gpu_src_ptr, return_bytes);
      gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->d2h_stream(),
                                            callback);
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->d2h_stream(), [times, r]() {
              times->at(r).all_done = Env::Default()->NowMicros();
            });
      }
    }
  }
};

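// Measures EventMgr callback throughput in isolation: `threads` closures each
// enqueue `iters` no-op callbacks on a single stream, and the benchmark waits
// until every callback has fired.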
static void BM_no_ops(int iters, int threads) {
  testing::StopTiming();
#ifdef PLATFORM_GOOGLE
  BenchmarkUseRealTime();
#else
  testing::UseRealTime();
#endif  // PLATFORM_GOOGLE
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  TEST_EventMgr em(stream_exec, GPUOptions());
  testing::StartTiming();
  std::atomic<int> counter;
  counter.store(0, std::memory_order_seq_cst);
  se::Stream* stream_ptr = stream.get();
  auto runner = [&em, &counter, stream_ptr, iters]() {
    auto callback = [&counter]() { counter.fetch_add(1); };
    for (int i = 0; i < iters; ++i) {
      em.ThenExecute(stream_ptr, callback);
    }
  };
  for (int t = 0; t < threads; ++t) {
    Env::Default()->SchedClosure(runner);
  }
  int expected = iters * threads;
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
}
BENCHMARK(BM_no_ops)->Arg(4);
BENCHMARK(BM_no_ops)->Arg(8);
BENCHMARK(BM_no_ops)->Arg(32);

// Benchmark functions are defined at top level.  In order to provide a real,
// persistent GPUDevice to the following function it also needs to be at top
// level.  But then we can't clean it up without a CUDA runtime error, so we
// just leak it.
GPUDeviceTestHelper* gpu_helper = nullptr;
EMBenchmarkHelper* bm_helper = nullptr;
mutex helper_mu;

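// Times DoAddChain end to end.  tensor_size is the element count of the float
// inputs, adds_per_round is the number of chained Add ops per copy round, and
// event_after_add additionally fires the callback after every Add.  A new
// device helper is built whenever pending_cap changes; in the open-source
// build the work is fanned out across `threads` closures.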
#ifdef PLATFORM_GOOGLE
static void BM_chain_ops(int iters, int tensor_size, int adds_per_round,
                         bool event_after_add, int pending_cap) {
#else
static void BM_chain_ops(int iters, int tensor_size, int adds_per_round,
                         bool event_after_add, int pending_cap, int threads) {
#endif
  testing::StopTiming();
#ifdef PLATFORM_GOOGLE
  BenchmarkUseRealTime();
#else
  testing::UseRealTime();
#endif  // PLATFORM_GOOGLE
  {
    mutex_lock l(helper_mu);
    if (gpu_helper && gpu_helper->pending_cap() != pending_cap) {
      delete bm_helper;
      bm_helper = nullptr;
      delete gpu_helper;
      gpu_helper = nullptr;
    }
    if (!gpu_helper) {
      gpu_helper = new GPUDeviceTestHelper(1 << 24, pending_cap);
      bm_helper = new EMBenchmarkHelper(gpu_helper);
    }
    if (bm_helper->num_ops() != adds_per_round ||
        bm_helper->tensor_size() != tensor_size) {
      bm_helper->ReInit(adds_per_round, tensor_size);
    }
  }
  std::vector<EMBenchmarkHelper::TimeSet> times;
  std::vector<EMBenchmarkHelper::TimeSet>* time_ptr = nullptr;
  if (VLOG_IS_ON(1)) {
    times.resize(iters);
    time_ptr = &times;
  }
  std::atomic<int> counter;
  counter.store(0, std::memory_order_seq_cst);
  auto callback = [&counter]() { counter.fetch_add(1); };
  // First iter is always slow, so do one prior to the timed loop.
  int expected = 1 + (event_after_add ? adds_per_round : 0);
  bm_helper->DoAddChain(adds_per_round, 1, event_after_add, callback, nullptr);
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  counter = 0;
  testing::StartTiming();
#ifdef PLATFORM_GOOGLE
  expected = iters * (1 + (event_after_add ? adds_per_round : 0));
  bm_helper->DoAddChain(adds_per_round, iters, event_after_add, callback,
                        time_ptr);
#else
  expected = threads * iters * (1 + (event_after_add ? adds_per_round : 0));
  for (int i = 0; i < threads; ++i) {
    Env::Default()->SchedClosure(
        [callback, iters, adds_per_round, event_after_add, time_ptr]() {
          bm_helper->DoAddChain(adds_per_round, iters, event_after_add,
                                callback, time_ptr);
        });
  }
#endif
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  testing::StopTiming();
  VLOG(1) << "counter = " << counter << " post_execute Output: "
          << bm_helper->host_outputs(0).SummarizeValue(64);
  if (time_ptr) bm_helper->DisplayTimes(time_ptr);
}

#ifdef PLATFORM_GOOGLE
static void BM_chain_1024_1_false(int iters) {
  BM_chain_ops(iters, 1024, 1, false, 0);
}

static void BM_chain_1024_1_true(int iters) {
  BM_chain_ops(iters, 1024, 1, true, 0);
}

static void BM_chain_1024_10_false(int iters) {
  BM_chain_ops(iters, 1024, 10, false, 0);
}

static void BM_chain_1024_10_true(int iters) {
  BM_chain_ops(iters, 1024, 10, true, 0);
}

static void BM_chain_1024_100_false(int iters) {
  BM_chain_ops(iters, 1024, 100, false, 0);
}

static void BM_chain_1024_100_true(int iters) {
  BM_chain_ops(iters, 1024, 100, true, 0);
}

static void BM_chain_1M_1_false(int iters) {
  BM_chain_ops(iters, 1 << 20, 1, false, 0);
}

static void BM_chain_1M_1_true(int iters) {
  BM_chain_ops(iters, 1 << 20, 1, true, 0);
}

static void BM_chain_1M_10_false(int iters) {
  BM_chain_ops(iters, 1 << 20, 10, false, 0);
}

static void BM_chain_1M_10_true(int iters) {
  BM_chain_ops(iters, 1 << 20, 10, true, 0);
}

static void BM_chain_1M_100_false(int iters) {
  BM_chain_ops(iters, 1 << 20, 100, false, 0);
}

static void BM_chain_1M_100_true(int iters) {
  BM_chain_ops(iters, 1 << 20, 100, true, 0);
}

BENCHMARK(BM_chain_1024_1_false)->Threads(1);
BENCHMARK(BM_chain_1024_1_true)->Threads(1);
BENCHMARK(BM_chain_1024_1_false)->Threads(2);
BENCHMARK(BM_chain_1024_1_true)->Threads(2);
BENCHMARK(BM_chain_1024_1_false)->Threads(8);
BENCHMARK(BM_chain_1024_1_true)->Threads(8);
BENCHMARK(BM_chain_1024_10_false)->Threads(1);
BENCHMARK(BM_chain_1024_10_true)->Threads(1);
BENCHMARK(BM_chain_1024_10_false)->Threads(8);
BENCHMARK(BM_chain_1024_10_true)->Threads(8);
BENCHMARK(BM_chain_1024_100_false)->Threads(1);
BENCHMARK(BM_chain_1024_100_true)->Threads(1);
BENCHMARK(BM_chain_1024_100_false)->Threads(2);
BENCHMARK(BM_chain_1024_100_true)->Threads(2);
BENCHMARK(BM_chain_1024_100_false)->Threads(8);
BENCHMARK(BM_chain_1024_100_true)->Threads(8);

BENCHMARK(BM_chain_1M_1_false)->Threads(1);
BENCHMARK(BM_chain_1M_1_true)->Threads(1);
BENCHMARK(BM_chain_1M_1_false)->Threads(2);
BENCHMARK(BM_chain_1M_1_true)->Threads(2);
BENCHMARK(BM_chain_1M_1_false)->Threads(8);
BENCHMARK(BM_chain_1M_1_true)->Threads(8);
BENCHMARK(BM_chain_1M_10_false)->Threads(1);
BENCHMARK(BM_chain_1M_10_true)->Threads(1);
BENCHMARK(BM_chain_1M_10_false)->Threads(8);
BENCHMARK(BM_chain_1M_10_true)->Threads(8);
BENCHMARK(BM_chain_1M_100_false)->Threads(1);
BENCHMARK(BM_chain_1M_100_true)->Threads(1);
BENCHMARK(BM_chain_1M_100_false)->Threads(2);
BENCHMARK(BM_chain_1M_100_true)->Threads(2);
BENCHMARK(BM_chain_1M_100_false)->Threads(8);
BENCHMARK(BM_chain_1M_100_true)->Threads(8);
#else
static void BM_chain_1024_1_false(int iters, int threads) {
  BM_chain_ops(iters, 1024, 1, false, 0, threads);
}

static void BM_chain_1024_1_true(int iters, int threads) {
  BM_chain_ops(iters, 1024, 1, true, 0, threads);
}

static void BM_chain_1024_10_false(int iters, int threads) {
  BM_chain_ops(iters, 1024, 10, false, 0, threads);
}

static void BM_chain_1024_10_true(int iters, int threads) {
  BM_chain_ops(iters, 1024, 10, true, 0, threads);
}

static void BM_chain_1024_100_false(int iters, int threads) {
  BM_chain_ops(iters, 1024, 100, false, 0, threads);
}

static void BM_chain_1024_100_true(int iters, int threads) {
  BM_chain_ops(iters, 1024, 100, true, 0, threads);
}

static void BM_chain_1M_1_false(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 1, false, 0, threads);
}

static void BM_chain_1M_1_true(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 1, true, 0, threads);
}

static void BM_chain_1M_10_false(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 10, false, 0, threads);
}

static void BM_chain_1M_10_true(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 10, true, 0, threads);
}

static void BM_chain_1M_100_false(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 100, false, 0, threads);
}

static void BM_chain_1M_100_true(int iters, int threads) {
  BM_chain_ops(iters, 1 << 20, 100, true, 0, threads);
}

BENCHMARK(BM_chain_1024_1_false)->Arg(1);
BENCHMARK(BM_chain_1024_1_true)->Arg(1);
BENCHMARK(BM_chain_1024_1_false)->Arg(2);
BENCHMARK(BM_chain_1024_1_true)->Arg(2);
BENCHMARK(BM_chain_1024_1_false)->Arg(8);
BENCHMARK(BM_chain_1024_1_true)->Arg(8);
BENCHMARK(BM_chain_1024_10_false)->Arg(1);
BENCHMARK(BM_chain_1024_10_true)->Arg(1);
BENCHMARK(BM_chain_1024_10_false)->Arg(8);
BENCHMARK(BM_chain_1024_10_true)->Arg(8);
BENCHMARK(BM_chain_1024_100_false)->Arg(1);
BENCHMARK(BM_chain_1024_100_true)->Arg(1);
BENCHMARK(BM_chain_1024_100_false)->Arg(2);
BENCHMARK(BM_chain_1024_100_true)->Arg(2);
BENCHMARK(BM_chain_1024_100_false)->Arg(8);
BENCHMARK(BM_chain_1024_100_true)->Arg(8);

BENCHMARK(BM_chain_1M_1_false)->Arg(1);
BENCHMARK(BM_chain_1M_1_true)->Arg(1);
BENCHMARK(BM_chain_1M_1_false)->Arg(2);
BENCHMARK(BM_chain_1M_1_true)->Arg(2);
BENCHMARK(BM_chain_1M_1_false)->Arg(8);
BENCHMARK(BM_chain_1M_1_true)->Arg(8);
BENCHMARK(BM_chain_1M_10_false)->Arg(1);
BENCHMARK(BM_chain_1M_10_true)->Arg(1);
BENCHMARK(BM_chain_1M_10_false)->Arg(8);
BENCHMARK(BM_chain_1M_10_true)->Arg(8);
BENCHMARK(BM_chain_1M_100_false)->Arg(1);
BENCHMARK(BM_chain_1M_100_true)->Arg(1);
BENCHMARK(BM_chain_1M_100_false)->Arg(2);
BENCHMARK(BM_chain_1M_100_true)->Arg(2);
BENCHMARK(BM_chain_1M_100_false)->Arg(8);
BENCHMARK(BM_chain_1M_100_true)->Arg(8);
#endif
}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM