1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
16 
17 #include "tensorflow/core/common_runtime/device.h"
18 #include "tensorflow/core/common_runtime/process_util.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/core/notification.h"
21 
22 namespace tensorflow {
23 
~BufRendezvous()24 BufRendezvous::~BufRendezvous() {
25   mutex_lock l(mu_);
26   if (!hook_table_.empty()) {
27     PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
28                &hook_table_);
29   }
30 }
31 
StartAbort(const Status & s)32 void BufRendezvous::StartAbort(const Status& s) {
33   CHECK(!s.ok());
34   HookTable dummy_table;
35   {
36     mutex_lock l(mu_);
37     status_.Update(s);
38     hook_table_.swap(dummy_table);
39   }
40   PurgeTable(s, &dummy_table);
41 }
42 
PurgeTable(const Status & s,HookTable * table)43 void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
44   for (auto& it : *table) {
45     Hook* h = it.second;
46     if (h->cons_cb != nullptr) {
47       h->cons_cb(s, nullptr);
48     }
49     if (h->prod_cb != nullptr) {
50       h->prod_cb(s);
51     }
52     delete h;
53   }
54   table->clear();
55 }
56 
DebugString() const57 string BufRendezvous::Hook::DebugString() const {
58   return strings::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
59                          ", ctx:", reinterpret_cast<uint64>(prod_ctx),
60                          ", val:", reinterpret_cast<uint64>(prod_value),
61                          ", pcb:", reinterpret_cast<uint64>(&prod_cb),
62                          ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
63 }
64 
ProvideBuf(const string & key,Device * dev,DeviceContext * dev_ctx,const Tensor * v,const AllocatorAttributes & attr,const ProducerCallback & done)65 void BufRendezvous::ProvideBuf(const string& key, Device* dev,
66                                DeviceContext* dev_ctx, const Tensor* v,
67                                const AllocatorAttributes& attr,
68                                const ProducerCallback& done) {
69   Hook* h = nullptr;
70   Status providebuf_status;
71   do {
72     mutex_lock l(mu_);
73     if (!status_.ok()) {
74       providebuf_status = status_;
75       break;
76     } else {
77       auto it = hook_table_.find(key);
78       if (it == hook_table_.end()) {
79         h = new Hook;
80         it = hook_table_.insert(std::make_pair(key, h)).first;
81       } else {
82         if (it->second->prod_cb != nullptr) {
83           providebuf_status = errors::Internal(
84               "BufRendezvous::ProvideBuf already called for key ", key);
85           break;
86         }
87         h = it->second;
88       }
89       // Populate Hook with all of the prod values.
90       h->prod_dev = dev;
91       h->prod_ctx = dev_ctx;
92       h->prod_value = v;
93       h->prod_attr = attr;
94       h->prod_cb = done;
95       // If consumer is waiting, kick off right away, removing Hook from table.
96       if (h->cons_cb != nullptr) {
97         hook_table_.erase(it);
98       } else {
99         h = nullptr;
100       }
101     }
102   } while (false);
103   if (h) {
104     h->cons_cb(Status::OK(), h);
105   }
106   if (!providebuf_status.ok()) {
107     done(providebuf_status);
108   }
109 }
110 
ConsumeBuf(const string & key,const ConsumerCallback & done)111 void BufRendezvous::ConsumeBuf(const string& key,
112                                const ConsumerCallback& done) {
113   Hook* existing_hook = nullptr;
114   Status consumebuf_status;
115   do {
116     mutex_lock l(mu_);
117     if (!status_.ok()) {
118       consumebuf_status = status_;
119       break;
120     }
121     auto it = hook_table_.find(key);
122     if (it != hook_table_.end()) {
123       // Prepare to consume immediately.
124       if (it->second->cons_cb) {
125         consumebuf_status =
126             errors::Internal("Second consumer arrived for key ", key);
127         break;
128       }
129       existing_hook = it->second;
130       hook_table_.erase(it);
131       existing_hook->cons_cb = done;
132     } else {
133       // Hang consumer callback on the Hook.
134       Hook* h = new Hook;
135       hook_table_[key] = h;
136       h->cons_cb = done;
137       return;
138     }
139   } while (false);
140   if (existing_hook) {
141     existing_hook->cons_cb(Status::OK(), existing_hook);
142     return;
143   }
144   if (!consumebuf_status.ok()) {
145     done(consumebuf_status, nullptr);
146     return;
147   }
148 }
149 
150 /*static*/
DoneWithHook(Hook * h)151 void BufRendezvous::DoneWithHook(Hook* h) {
152   h->prod_cb(Status::OK());
153   delete h;
154 }
155 
LogContents()156 void BufRendezvous::LogContents() {
157   mutex_lock l(mu_);
158   LOG(INFO) << strings::StrCat("BufRendezvous ",
159                                strings::Hex(reinterpret_cast<uint64>(this)),
160                                " step_id=", step_id_, " current contents:");
161   for (auto it : hook_table_) {
162     LOG(INFO) << it.first << ":" << it.second->DebugString();
163   }
164 }
165 
166 }  // namespace tensorflow
167