1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/threading/thread_task_runner_handle.h"
6 
7 #include <utility>
8 
9 #include "base/bind.h"
10 #include "base/lazy_instance.h"
11 #include "base/logging.h"
12 #include "base/memory/ptr_util.h"
13 #include "base/threading/sequenced_task_runner_handle.h"
14 #include "base/threading/thread_local.h"
15 
16 namespace base {
17 
18 namespace {
19 
20 base::LazyInstance<base::ThreadLocalPointer<ThreadTaskRunnerHandle>>::Leaky
21     lazy_tls_ptr = LAZY_INSTANCE_INITIALIZER;
22 
23 }  // namespace
24 
25 // static
Get()26 scoped_refptr<SingleThreadTaskRunner> ThreadTaskRunnerHandle::Get() {
27   ThreadTaskRunnerHandle* current = lazy_tls_ptr.Pointer()->Get();
28   DCHECK(current);
29   return current->task_runner_;
30 }
31 
32 // static
IsSet()33 bool ThreadTaskRunnerHandle::IsSet() {
34   return !!lazy_tls_ptr.Pointer()->Get();
35 }
36 
// static
//
// Makes |overriding_task_runner| the task runner reported by
// ThreadTaskRunnerHandle::Get() on the calling thread until the closure held
// by the returned ScopedClosureRunner runs.
// - If no handle is set on this thread, a temporary top-level handle is
//   created and its ownership is handed to the returned closure.
// - Otherwise the existing handle's task runner is swapped out, and the
//   returned closure swaps the previous one back in.
// Nested overrides must be released in LIFO order (DCHECKed on restore).
ScopedClosureRunner ThreadTaskRunnerHandle::OverrideForTesting(
    scoped_refptr<SingleThreadTaskRunner> overriding_task_runner) {
  // OverrideForTesting() is not compatible with a SequencedTaskRunnerHandle
  // being set (but SequencedTaskRunnerHandle::IsSet() includes
  // ThreadTaskRunnerHandle::IsSet() so that's discounted as the only valid
  // excuse for it to be true). Sadly this means that tests that merely need a
  // SequencedTaskRunnerHandle on their main thread can be forced to use a
  // ThreadTaskRunnerHandle if they're also using test task runners (that
  // OverrideForTesting() when running their tasks from said main thread). To
  // solve this, sequenced_task_runner_handle.cc and
  // thread_task_runner_handle.cc would have to be merged into a single impl
  // file and share TLS state. This was deemed unnecessary for now as most
  // tests should use higher-level constructs and not have to instantiate task
  // runner handles on their own.
  DCHECK(!SequencedTaskRunnerHandle::IsSet() || IsSet());

  if (!IsSet()) {
    // No handle on this thread yet: create a top-level one for the override.
    // Ownership of it is passed into the callback; the lambda takes the
    // unique_ptr by value, so running the callback destroys the handle, and
    // ~ThreadTaskRunnerHandle clears the TLS slot.
    std::unique_ptr<ThreadTaskRunnerHandle> top_level_ttrh =
        MakeUnique<ThreadTaskRunnerHandle>(std::move(overriding_task_runner));
    return ScopedClosureRunner(base::Bind(
        [](std::unique_ptr<ThreadTaskRunnerHandle> ttrh_to_release) {},
        base::Passed(&top_level_ttrh)));
  }

  ThreadTaskRunnerHandle* ttrh = lazy_tls_ptr.Pointer()->Get();
  // Swap the two (and below bind |overriding_task_runner|, which is now the
  // previous one, as the |task_runner_to_restore|).
  ttrh->task_runner_.swap(overriding_task_runner);

  return ScopedClosureRunner(base::Bind(
      [](scoped_refptr<SingleThreadTaskRunner> task_runner_to_restore,
         SingleThreadTaskRunner* expected_task_runner_before_restore) {
        // Re-fetch the handle at restore time. TLS lookups are per-thread, so
        // this must run on the thread that requested the override.
        ThreadTaskRunnerHandle* ttrh = lazy_tls_ptr.Pointer()->Get();

        // If a later override has not been released yet, its runner (not the
        // one installed by this call) is still current, which violates LIFO.
        DCHECK_EQ(expected_task_runner_before_restore, ttrh->task_runner_.get())
            << "Nested overrides must expire their ScopedClosureRunners "
               "in LIFO order.";

        ttrh->task_runner_.swap(task_runner_to_restore);
      },
      base::Passed(&overriding_task_runner),
      // After the swap above, |ttrh->task_runner_| is the override. It is
      // recorded as a raw, non-owning pointer used only for the LIFO check.
      base::Unretained(ttrh->task_runner_.get())));
}
80 
ThreadTaskRunnerHandle(scoped_refptr<SingleThreadTaskRunner> task_runner)81 ThreadTaskRunnerHandle::ThreadTaskRunnerHandle(
82     scoped_refptr<SingleThreadTaskRunner> task_runner)
83     : task_runner_(std::move(task_runner)) {
84   DCHECK(task_runner_->BelongsToCurrentThread());
85   // No SequencedTaskRunnerHandle (which includes ThreadTaskRunnerHandles)
86   // should already be set for this thread.
87   DCHECK(!SequencedTaskRunnerHandle::IsSet());
88   lazy_tls_ptr.Pointer()->Set(this);
89 }
90 
~ThreadTaskRunnerHandle()91 ThreadTaskRunnerHandle::~ThreadTaskRunnerHandle() {
92   DCHECK(task_runner_->BelongsToCurrentThread());
93   DCHECK_EQ(lazy_tls_ptr.Pointer()->Get(), this);
94   lazy_tls_ptr.Pointer()->Set(nullptr);
95 }
96 
97 }  // namespace base
98