1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17 
18 namespace tensorflow {
19 
EagerNode(tensorflow::uint64 id)20 EagerNode::EagerNode(tensorflow::uint64 id) : id(id) {}
21 
~EagerExecutor()22 EagerExecutor::~EagerExecutor() {
23   tensorflow::mutex_lock l(node_queue_mutex_);
24   thread_done_ = true;
25   nodes_pending_.notify_all();
26 }
27 
NextId()28 tensorflow::uint64 EagerExecutor::NextId() {
29   tensorflow::mutex_lock l(next_id_mutex_);
30   return next_id_++;
31 }
32 
EnableAsync()33 void EagerExecutor::EnableAsync() {
34   tensorflow::mutex_lock l(node_queue_mutex_);
35   if (thread_ == nullptr) {
36     thread_.reset(tensorflow::Env::Default()->StartThread(
37         tensorflow::ThreadOptions(), "eager_async_executor",
38         std::bind(&EagerExecutor::Run, this)));
39   }
40 }
41 
Add(EagerNode * node)42 void EagerExecutor::Add(EagerNode* node) {
43   tensorflow::mutex_lock l(node_queue_mutex_);
44   DCHECK(thread_) << "EnableAsync should have been called before Add";
45   if (!status_.ok()) {
46     delete node;
47     return;
48   }
49   int64 qlen = node_queue_.size();
50   if (qlen > 0) {
51     if (node_queue_.back()->id >= node->id) {
52       status_ = tensorflow::errors::InvalidArgument(
53           "Inserting EagerNode with non-increasing ids:",
54           node_queue_.back()->id, " vs ", node->id);
55       delete node;
56       return;
57     }
58     node_queue_.push(node);
59   } else {
60     node_queue_.push(node);
61     nodes_pending_.notify_all();
62   }
63 }
64 
WaitFor(tensorflow::uint64 node_id)65 tensorflow::Status EagerExecutor::WaitFor(tensorflow::uint64 node_id) {
66   return WaitImpl(false, node_id);
67 }
68 
WaitForAllPendingNodes()69 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
70   return WaitImpl(true, 0);
71 }
72 
WaitImpl(bool wait_all,tensorflow::uint64 node_id)73 tensorflow::Status EagerExecutor::WaitImpl(bool wait_all,
74                                            tensorflow::uint64 node_id) {
75   tensorflow::condition_variable cond;
76   tensorflow::mutex_lock l(node_queue_mutex_);
77   // Don't wait if an error is already set.
78   if (!status_.ok()) return status_;
79   if (node_queue_.empty()) return tensorflow::Status::OK();
80   if (wait_all) {
81     node_id = node_queue_.back()->id;
82   } else if (node_id < node_queue_.front()->id) {
83     // Note that we are relying on the ops being dispatched sequentially from
84     // the queue.
85     return tensorflow::Status::OK();
86   }
87   node_done_notifications_.insert(std::make_pair(node_id, &cond));
88   cond.wait(l);
89   // Note that we could be woken up if an error occurs, even though the node has
90   // not actually executed.
91   return status_;
92 }
93 
ClearError()94 void EagerExecutor::ClearError() {
95   tensorflow::mutex_lock l(node_queue_mutex_);
96   if (status_.ok()) return;
97   // If an error was set, node_done_notifications_ and node_queue_ should have
98   // been cleared, and no new entries should have been added since.
99   DCHECK(node_done_notifications_.empty());
100   DCHECK(node_queue_.empty());
101   status_ = tensorflow::Status::OK();
102   nodes_pending_.notify_all();
103 }
104 
status()105 tensorflow::Status EagerExecutor::status() {
106   tensorflow::mutex_lock l(node_queue_mutex_);
107   return status_;
108 }
109 
Run()110 void EagerExecutor::Run() {
111   while (true) {
112     std::unique_ptr<EagerNode> curr_node;
113     {
114       tensorflow::mutex_lock l(node_queue_mutex_);
115       while (node_queue_.empty() || !status_.ok()) {
116         if (thread_done_) return;
117         nodes_pending_.wait(l);
118       }
119       curr_node.reset(node_queue_.front());
120     }
121     tensorflow::Status status = curr_node->Run();
122     const bool ok = status.ok();
123     tensorflow::mutex_lock l(node_queue_mutex_);
124     node_queue_.pop();
125     if (!ok) {
126       status_ = status;
127       // TODO(agarwal): mark all affected handles as corrupted before clearing
128       // this queue.
129       // We remove any pending ops so that we don't try to execute them if
130       // ClearError is called.
131       for (int i = 0; i < node_queue_.size(); ++i) {
132         delete node_queue_.front();
133         node_queue_.pop();
134       }
135     }
136     if (!node_done_notifications_.empty()) {
137       tensorflow::uint64 node_id = curr_node->id;
138       // Note that we notify all waiting threads in case an error has occurred.
139       // These calling threads are responsible for checking status_ before
140       // proceeding.
141       const auto range = ok ? node_done_notifications_.equal_range(node_id)
142                             : make_pair(node_done_notifications_.begin(),
143                                         node_done_notifications_.end());
144       for (auto it = range.first; it != range.second; ++it) {
145         it->second->notify_all();
146       }
147       node_done_notifications_.erase(range.first, range.second);
148     }
149   }
150 }
151 
152 }  // namespace tensorflow
153