//===-- QueueChannel.h - RPC Queue Channel ------------------------*-c++-*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
11 #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
12 
#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
18 
19 namespace llvm {
20 
21 class QueueChannelError : public ErrorInfo<QueueChannelError> {
22 public:
23   static char ID;
24 };
25 
26 class QueueChannelClosedError
27     : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
28 public:
29   static char ID;
convertToErrorCode()30   std::error_code convertToErrorCode() const override {
31     return inconvertibleErrorCode();
32   }
33 
log(raw_ostream & OS)34   void log(raw_ostream &OS) const override {
35     OS << "Queue closed";
36   }
37 };
38 
39 class Queue : public std::queue<char> {
40 public:
41   using ErrorInjector = std::function<Error()>;
42 
Queue()43   Queue()
44     : ReadError([]() { return Error::success(); }),
45       WriteError([]() { return Error::success(); }) {}
46 
47   Queue(const Queue&) = delete;
48   Queue& operator=(const Queue&) = delete;
49   Queue(Queue&&) = delete;
50   Queue& operator=(Queue&&) = delete;
51 
getMutex()52   std::mutex &getMutex() { return M; }
getCondVar()53   std::condition_variable &getCondVar() { return CV; }
checkReadError()54   Error checkReadError() { return ReadError(); }
checkWriteError()55   Error checkWriteError() { return WriteError(); }
setReadError(ErrorInjector NewReadError)56   void setReadError(ErrorInjector NewReadError) {
57     {
58       std::lock_guard<std::mutex> Lock(M);
59       ReadError = std::move(NewReadError);
60     }
61     CV.notify_one();
62   }
setWriteError(ErrorInjector NewWriteError)63   void setWriteError(ErrorInjector NewWriteError) {
64     std::lock_guard<std::mutex> Lock(M);
65     WriteError = std::move(NewWriteError);
66   }
67 private:
68   std::mutex M;
69   std::condition_variable CV;
70   std::function<Error()> ReadError, WriteError;
71 };
72 
73 class QueueChannel : public orc::rpc::RawByteChannel {
74 public:
QueueChannel(std::shared_ptr<Queue> InQueue,std::shared_ptr<Queue> OutQueue)75   QueueChannel(std::shared_ptr<Queue> InQueue,
76                std::shared_ptr<Queue> OutQueue)
77       : InQueue(InQueue), OutQueue(OutQueue) {}
78 
79   QueueChannel(const QueueChannel&) = delete;
80   QueueChannel& operator=(const QueueChannel&) = delete;
81   QueueChannel(QueueChannel&&) = delete;
82   QueueChannel& operator=(QueueChannel&&) = delete;
83 
readBytes(char * Dst,unsigned Size)84   Error readBytes(char *Dst, unsigned Size) override {
85     std::unique_lock<std::mutex> Lock(InQueue->getMutex());
86     while (Size) {
87       {
88         Error Err = InQueue->checkReadError();
89         while (!Err && InQueue->empty()) {
90           InQueue->getCondVar().wait(Lock);
91           Err = InQueue->checkReadError();
92         }
93         if (Err)
94           return Err;
95       }
96       *Dst++ = InQueue->front();
97       --Size;
98       ++NumRead;
99       InQueue->pop();
100     }
101     return Error::success();
102   }
103 
appendBytes(const char * Src,unsigned Size)104   Error appendBytes(const char *Src, unsigned Size) override {
105     std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
106     while (Size--) {
107       if (Error Err = OutQueue->checkWriteError())
108         return Err;
109       OutQueue->push(*Src++);
110       ++NumWritten;
111     }
112     OutQueue->getCondVar().notify_one();
113     return Error::success();
114   }
115 
send()116   Error send() override { return Error::success(); }
117 
close()118   void close() {
119     auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
120     InQueue->setReadError(ChannelClosed);
121     InQueue->setWriteError(ChannelClosed);
122     OutQueue->setReadError(ChannelClosed);
123     OutQueue->setWriteError(ChannelClosed);
124   }
125 
126   uint64_t NumWritten = 0;
127   uint64_t NumRead = 0;
128 
129 private:
130 
131   std::shared_ptr<Queue> InQueue;
132   std::shared_ptr<Queue> OutQueue;
133 };
134 
135 inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
createPairedQueueChannels()136 createPairedQueueChannels() {
137   auto Q1 = std::make_shared<Queue>();
138   auto Q2 = std::make_shared<Queue>();
139   auto C1 = llvm::make_unique<QueueChannel>(Q1, Q2);
140   auto C2 = llvm::make_unique<QueueChannel>(Q2, Q1);
141   return std::make_pair(std::move(C1), std::move(C2));
142 }
143 
144 }
145 
146 #endif
147