1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/test/sequenced_task_runner_test_template.h"
6 
7 #include <ostream>
8 
9 #include "base/location.h"
10 
11 namespace base {
12 
13 namespace internal {
14 
TaskEvent(int i,Type type)15 TaskEvent::TaskEvent(int i, Type type)
16   : i(i), type(type) {
17 }
18 
SequencedTaskTracker()19 SequencedTaskTracker::SequencedTaskTracker()
20     : next_post_i_(0),
21       task_end_count_(0),
22       task_end_cv_(&lock_) {
23 }
24 
PostWrappedNonNestableTask(SequencedTaskRunner * task_runner,const Closure & task)25 void SequencedTaskTracker::PostWrappedNonNestableTask(
26     SequencedTaskRunner* task_runner,
27     const Closure& task) {
28   AutoLock event_lock(lock_);
29   const int post_i = next_post_i_++;
30   Closure wrapped_task = Bind(&SequencedTaskTracker::RunTask, this,
31                               task, post_i);
32   task_runner->PostNonNestableTask(FROM_HERE, wrapped_task);
33   TaskPosted(post_i);
34 }
35 
PostWrappedNestableTask(SequencedTaskRunner * task_runner,const Closure & task)36 void SequencedTaskTracker::PostWrappedNestableTask(
37     SequencedTaskRunner* task_runner,
38     const Closure& task) {
39   AutoLock event_lock(lock_);
40   const int post_i = next_post_i_++;
41   Closure wrapped_task = Bind(&SequencedTaskTracker::RunTask, this,
42                               task, post_i);
43   task_runner->PostTask(FROM_HERE, wrapped_task);
44   TaskPosted(post_i);
45 }
46 
PostWrappedDelayedNonNestableTask(SequencedTaskRunner * task_runner,const Closure & task,TimeDelta delay)47 void SequencedTaskTracker::PostWrappedDelayedNonNestableTask(
48     SequencedTaskRunner* task_runner,
49     const Closure& task,
50     TimeDelta delay) {
51   AutoLock event_lock(lock_);
52   const int post_i = next_post_i_++;
53   Closure wrapped_task = Bind(&SequencedTaskTracker::RunTask, this,
54                               task, post_i);
55   task_runner->PostNonNestableDelayedTask(FROM_HERE, wrapped_task, delay);
56   TaskPosted(post_i);
57 }
58 
PostNonNestableTasks(SequencedTaskRunner * task_runner,int task_count)59 void SequencedTaskTracker::PostNonNestableTasks(
60     SequencedTaskRunner* task_runner,
61     int task_count) {
62   for (int i = 0; i < task_count; ++i) {
63     PostWrappedNonNestableTask(task_runner, Closure());
64   }
65 }
66 
RunTask(const Closure & task,int task_i)67 void SequencedTaskTracker::RunTask(const Closure& task, int task_i) {
68   TaskStarted(task_i);
69   if (!task.is_null())
70     task.Run();
71   TaskEnded(task_i);
72 }
73 
TaskPosted(int i)74 void SequencedTaskTracker::TaskPosted(int i) {
75   // Caller must own |lock_|.
76   events_.push_back(TaskEvent(i, TaskEvent::POST));
77 }
78 
TaskStarted(int i)79 void SequencedTaskTracker::TaskStarted(int i) {
80   AutoLock lock(lock_);
81   events_.push_back(TaskEvent(i, TaskEvent::START));
82 }
83 
TaskEnded(int i)84 void SequencedTaskTracker::TaskEnded(int i) {
85   AutoLock lock(lock_);
86   events_.push_back(TaskEvent(i, TaskEvent::END));
87   ++task_end_count_;
88   task_end_cv_.Signal();
89 }
90 
91 const std::vector<TaskEvent>&
GetTaskEvents() const92 SequencedTaskTracker::GetTaskEvents() const {
93   return events_;
94 }
95 
WaitForCompletedTasks(int count)96 void SequencedTaskTracker::WaitForCompletedTasks(int count) {
97   AutoLock lock(lock_);
98   while (task_end_count_ < count)
99     task_end_cv_.Wait();
100 }
101 
// Out-of-line defaulted destructor; members are destroyed implicitly.
SequencedTaskTracker::~SequencedTaskTracker() = default;
103 
PrintTo(const TaskEvent & event,std::ostream * os)104 void PrintTo(const TaskEvent& event, std::ostream* os) {
105   *os << "(i=" << event.i << ", type=";
106   switch (event.type) {
107     case TaskEvent::POST: *os << "POST"; break;
108     case TaskEvent::START: *os << "START"; break;
109     case TaskEvent::END: *os << "END"; break;
110   }
111   *os << ")";
112 }
113 
114 namespace {
115 
116 // Returns the task ordinals for the task event type |type| in the order that
117 // they were recorded.
GetEventTypeOrder(const std::vector<TaskEvent> & events,TaskEvent::Type type)118 std::vector<int> GetEventTypeOrder(const std::vector<TaskEvent>& events,
119                                    TaskEvent::Type type) {
120   std::vector<int> tasks;
121   std::vector<TaskEvent>::const_iterator event;
122   for (event = events.begin(); event != events.end(); ++event) {
123     if (event->type == type)
124       tasks.push_back(event->i);
125   }
126   return tasks;
127 }
128 
129 // Returns all task events for task |task_i|.
GetEventsForTask(const std::vector<TaskEvent> & events,int task_i)130 std::vector<TaskEvent::Type> GetEventsForTask(
131     const std::vector<TaskEvent>& events,
132     int task_i) {
133   std::vector<TaskEvent::Type> task_event_orders;
134   std::vector<TaskEvent>::const_iterator event;
135   for (event = events.begin(); event != events.end(); ++event) {
136     if (event->i == task_i)
137       task_event_orders.push_back(event->type);
138   }
139   return task_event_orders;
140 }
141 
142 // Checks that the task events for each task in |events| occur in the order
143 // {POST, START, END}, and that there is only one instance of each event type
144 // per task.
CheckEventOrdersForEachTask(const std::vector<TaskEvent> & events,int task_count)145 ::testing::AssertionResult CheckEventOrdersForEachTask(
146     const std::vector<TaskEvent>& events,
147     int task_count) {
148   std::vector<TaskEvent::Type> expected_order;
149   expected_order.push_back(TaskEvent::POST);
150   expected_order.push_back(TaskEvent::START);
151   expected_order.push_back(TaskEvent::END);
152 
153   // This is O(n^2), but it runs fast enough currently so is not worth
154   // optimizing.
155   for (int i = 0; i < task_count; ++i) {
156     const std::vector<TaskEvent::Type> task_events =
157         GetEventsForTask(events, i);
158     if (task_events != expected_order) {
159       return ::testing::AssertionFailure()
160           << "Events for task " << i << " are out of order; expected: "
161           << ::testing::PrintToString(expected_order) << "; actual: "
162           << ::testing::PrintToString(task_events);
163     }
164   }
165   return ::testing::AssertionSuccess();
166 }
167 
168 // Checks that no two tasks were running at the same time. I.e. the only
169 // events allowed between the START and END of a task are the POSTs of other
170 // tasks.
CheckNoTaskRunsOverlap(const std::vector<TaskEvent> & events)171 ::testing::AssertionResult CheckNoTaskRunsOverlap(
172     const std::vector<TaskEvent>& events) {
173   // If > -1, we're currently inside a START, END pair.
174   int current_task_i = -1;
175 
176   std::vector<TaskEvent>::const_iterator event;
177   for (event = events.begin(); event != events.end(); ++event) {
178     bool spurious_event_found = false;
179 
180     if (current_task_i == -1) {  // Not inside a START, END pair.
181       switch (event->type) {
182         case TaskEvent::POST:
183           break;
184         case TaskEvent::START:
185           current_task_i = event->i;
186           break;
187         case TaskEvent::END:
188           spurious_event_found = true;
189           break;
190       }
191 
192     } else {  // Inside a START, END pair.
193       bool interleaved_task_detected = false;
194 
195       switch (event->type) {
196         case TaskEvent::POST:
197           if (event->i == current_task_i)
198             spurious_event_found = true;
199           break;
200         case TaskEvent::START:
201           interleaved_task_detected = true;
202           break;
203         case TaskEvent::END:
204           if (event->i != current_task_i)
205             interleaved_task_detected = true;
206           else
207             current_task_i = -1;
208           break;
209       }
210 
211       if (interleaved_task_detected) {
212         return ::testing::AssertionFailure()
213             << "Found event " << ::testing::PrintToString(*event)
214             << " between START and END events for task " << current_task_i
215             << "; event dump: " << ::testing::PrintToString(events);
216       }
217     }
218 
219     if (spurious_event_found) {
220       const int event_i = event - events.begin();
221       return ::testing::AssertionFailure()
222           << "Spurious event " << ::testing::PrintToString(*event)
223           << " at position " << event_i << "; event dump: "
224           << ::testing::PrintToString(events);
225     }
226   }
227 
228   return ::testing::AssertionSuccess();
229 }
230 
231 }  // namespace
232 
CheckNonNestableInvariants(const std::vector<TaskEvent> & events,int task_count)233 ::testing::AssertionResult CheckNonNestableInvariants(
234     const std::vector<TaskEvent>& events,
235     int task_count) {
236   const std::vector<int> post_order =
237       GetEventTypeOrder(events, TaskEvent::POST);
238   const std::vector<int> start_order =
239       GetEventTypeOrder(events, TaskEvent::START);
240   const std::vector<int> end_order =
241       GetEventTypeOrder(events, TaskEvent::END);
242 
243   if (start_order != post_order) {
244     return ::testing::AssertionFailure()
245         << "Expected START order (which equals actual POST order): \n"
246         << ::testing::PrintToString(post_order)
247         << "\n Actual START order:\n"
248         << ::testing::PrintToString(start_order);
249   }
250 
251   if (end_order != post_order) {
252     return ::testing::AssertionFailure()
253         << "Expected END order (which equals actual POST order): \n"
254         << ::testing::PrintToString(post_order)
255         << "\n Actual END order:\n"
256         << ::testing::PrintToString(end_order);
257   }
258 
259   const ::testing::AssertionResult result =
260       CheckEventOrdersForEachTask(events, task_count);
261   if (!result)
262     return result;
263 
264   return CheckNoTaskRunsOverlap(events);
265 }
266 
267 }  // namespace internal
268 
269 }  // namespace base
270