1 // Copyright (C) 2020 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "GrpcGraph.h"

#include <cstdlib>
#include <tuple>

#include <android-base/logging.h>
#include <grpcpp/grpcpp.h>

#include "ClientConfig.pb.h"
#include "GrpcGraph.h"
#include "InputFrame.h"
#include "RunnerComponent.h"
#include "prebuilt_interface.h"
#include "types/Status.h"
28 
29 namespace android {
30 namespace automotive {
31 namespace computepipe {
32 namespace graph {
33 namespace {
34 constexpr int64_t kRpcDeadlineMilliseconds = 100;
35 
36 template <class ResponseType, class RpcType>
FinishRpcAndGetResult(::grpc::ClientAsyncResponseReader<RpcType> * rpc,::grpc::CompletionQueue * cq,ResponseType * response)37 std::pair<Status, std::string> FinishRpcAndGetResult(
38         ::grpc::ClientAsyncResponseReader<RpcType>* rpc, ::grpc::CompletionQueue* cq,
39         ResponseType* response) {
40     int random_tag = rand();
41     ::grpc::Status grpcStatus;
42     rpc->Finish(response, &grpcStatus, reinterpret_cast<void*>(random_tag));
43     bool ok = false;
44     void* got_tag;
45     if (!cq->Next(&got_tag, &ok)) {
46         LOG(ERROR) << "Unexpected shutdown of the completion queue";
47         return std::pair(Status::FATAL_ERROR, "Unexpected shutdown of the completion queue");
48     }
49 
50     if (!ok) {
51         LOG(ERROR) << "Unable to complete RPC request";
52         return std::pair(Status::FATAL_ERROR, "Unable to complete RPC request");
53     }
54 
55     CHECK_EQ(got_tag, reinterpret_cast<void*>(random_tag));
56     if (!grpcStatus.ok()) {
57         std::string error_message =
58                 std::string("Grpc failed with error: ") + grpcStatus.error_message();
59         LOG(ERROR) << error_message;
60         return std::pair(Status::FATAL_ERROR, std::move(error_message));
61     }
62 
63     return std::pair(Status::SUCCESS, std::string(""));
64 }
65 
66 }  // namespace
67 
GetGraphState() const68 PrebuiltGraphState GrpcGraph::GetGraphState() const {
69     std::lock_guard lock(mLock);
70     return mGraphState;
71 }
72 
GetStatus() const73 Status GrpcGraph::GetStatus() const {
74     std::lock_guard lock(mLock);
75     return mStatus;
76 }
77 
GetErrorMessage() const78 std::string GrpcGraph::GetErrorMessage() const {
79     std::lock_guard lock(mLock);
80     return mErrorMessage;
81 }
82 
initialize(const std::string & address,std::weak_ptr<PrebuiltEngineInterface> engineInterface)83 Status GrpcGraph::initialize(const std::string& address,
84                              std::weak_ptr<PrebuiltEngineInterface> engineInterface) {
85     std::shared_ptr<::grpc::ChannelCredentials> creds = ::grpc::InsecureChannelCredentials();
86     std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateChannel(address, creds);
87     mGraphStub = proto::GrpcGraphService::NewStub(channel);
88     mEngineInterface = engineInterface;
89 
90     ::grpc::ClientContext context;
91     context.set_deadline(std::chrono::system_clock::now() +
92                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
93     ::grpc::CompletionQueue cq;
94 
95     proto::GraphOptionsRequest getGraphOptionsRequest;
96     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::GraphOptionsResponse>> rpc(
97             mGraphStub->AsyncGetGraphOptions(&context, getGraphOptionsRequest, &cq));
98 
99     proto::GraphOptionsResponse response;
100     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
101 
102     if (mStatus != Status::SUCCESS) {
103         LOG(ERROR) << "Failed to get graph options: " << mErrorMessage;
104         return Status::FATAL_ERROR;
105     }
106 
107     std::string serialized_options = response.serialized_options();
108     if (!mGraphConfig.ParseFromString(serialized_options)) {
109         mErrorMessage = "Failed to parse graph options";
110         LOG(ERROR) << "Failed to parse graph options";
111         return Status::FATAL_ERROR;
112     }
113 
114     mGraphState = PrebuiltGraphState::STOPPED;
115     return Status::SUCCESS;
116 }
117 
118 // Function to confirm that there would be no further changes to the graph configuration. This
119 // needs to be called before starting the graph.
handleConfigPhase(const runner::ClientConfig & e)120 Status GrpcGraph::handleConfigPhase(const runner::ClientConfig& e) {
121     std::lock_guard lock(mLock);
122     if (mGraphState == PrebuiltGraphState::UNINITIALIZED) {
123         mStatus = Status::ILLEGAL_STATE;
124         return Status::ILLEGAL_STATE;
125     }
126 
127     // handleConfigPhase is a blocking call, so abort call is pointless for this RunnerEvent.
128     if (e.isAborted()) {
129         mStatus = Status::INVALID_ARGUMENT;
130         return mStatus;
131     } else if (e.isTransitionComplete()) {
132         mStatus = Status::SUCCESS;
133         return mStatus;
134     }
135 
136     ::grpc::ClientContext context;
137     context.set_deadline(std::chrono::system_clock::now() +
138                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
139     ::grpc::CompletionQueue cq;
140 
141     std::string serializedConfig = e.getSerializedClientConfig();
142     proto::SetGraphConfigRequest setGraphConfigRequest;
143     setGraphConfigRequest.set_serialized_config(std::move(serializedConfig));
144 
145     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
146             mGraphStub->AsyncSetGraphConfig(&context, setGraphConfigRequest, &cq));
147 
148     proto::StatusResponse response;
149     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
150     if (mStatus != Status::SUCCESS) {
151         LOG(ERROR) << "Rpc failed while trying to set configuration";
152         return mStatus;
153     }
154 
155     if (response.code() != proto::RemoteGraphStatusCode::SUCCESS) {
156         LOG(ERROR) << "Failed to cofngure remote graph. " << response.message();
157     }
158 
159     mStatus = static_cast<Status>(static_cast<int>(response.code()));
160     mErrorMessage = response.message();
161 
162     mStreamSetObserver = std::make_unique<StreamSetObserver>(e, this);
163 
164     return mStatus;
165 }
166 
167 // Starts the graph.
handleExecutionPhase(const runner::RunnerEvent & e)168 Status GrpcGraph::handleExecutionPhase(const runner::RunnerEvent& e) {
169     std::lock_guard lock(mLock);
170     if (mGraphState != PrebuiltGraphState::STOPPED) {
171         mStatus = Status::ILLEGAL_STATE;
172         return mStatus;
173     }
174 
175     if (e.isAborted()) {
176         // Starting the graph is a blocking call and cannot be aborted in between.
177         mStatus = Status::INVALID_ARGUMENT;
178         return mStatus;
179     } else if (e.isTransitionComplete()) {
180         mStatus = Status::SUCCESS;
181         return mStatus;
182     }
183 
184     // Start observing the output streams
185     mStatus = mStreamSetObserver->startObservingStreams();
186     if (mStatus != Status::SUCCESS) {
187         mErrorMessage = "Failed to observe output streams";
188         return mStatus;
189     }
190 
191     ::grpc::ClientContext context;
192     context.set_deadline(std::chrono::system_clock::now() +
193                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
194 
195     proto::StartGraphExecutionRequest startExecutionRequest;
196     ::grpc::CompletionQueue cq;
197     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
198             mGraphStub->AsyncStartGraphExecution(&context, startExecutionRequest, &cq));
199 
200     proto::StatusResponse response;
201     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
202     if (mStatus != Status::SUCCESS) {
203         LOG(ERROR) << "Failed to start graph execution";
204         return mStatus;
205     }
206 
207     mStatus = static_cast<Status>(static_cast<int>(response.code()));
208     mErrorMessage = response.message();
209 
210     if (mStatus == Status::SUCCESS) {
211         mGraphState = PrebuiltGraphState::RUNNING;
212     }
213 
214     return mStatus;
215 }
216 
217 // Stops the graph while letting the graph flush output packets in flight.
handleStopWithFlushPhase(const runner::RunnerEvent & e)218 Status GrpcGraph::handleStopWithFlushPhase(const runner::RunnerEvent& e) {
219     std::lock_guard lock(mLock);
220     if (mGraphState != PrebuiltGraphState::RUNNING) {
221         return Status::ILLEGAL_STATE;
222     }
223 
224     if (e.isAborted()) {
225         return Status::INVALID_ARGUMENT;
226     } else if (e.isTransitionComplete()) {
227         return Status::SUCCESS;
228     }
229 
230     ::grpc::ClientContext context;
231     context.set_deadline(std::chrono::system_clock::now() +
232                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
233 
234     proto::StopGraphExecutionRequest stopExecutionRequest;
235     stopExecutionRequest.set_stop_immediate(false);
236     ::grpc::CompletionQueue cq;
237     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
238             mGraphStub->AsyncStopGraphExecution(&context, stopExecutionRequest, &cq));
239 
240     proto::StatusResponse response;
241     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
242     if (mStatus != Status::SUCCESS) {
243         LOG(ERROR) << "Failed to stop graph execution";
244         return Status::FATAL_ERROR;
245     }
246 
247     // Stop observing streams immendiately.
248     mStreamSetObserver->stopObservingStreams(false);
249 
250     mStatus = static_cast<Status>(static_cast<int>(response.code()));
251     mErrorMessage = response.message();
252 
253     if (mStatus == Status::SUCCESS) {
254         mGraphState = PrebuiltGraphState::FLUSHING;
255     }
256 
257     return mStatus;
258 }
259 
260 // Stops the graph and cancels all the output packets.
handleStopImmediatePhase(const runner::RunnerEvent & e)261 Status GrpcGraph::handleStopImmediatePhase(const runner::RunnerEvent& e) {
262     std::lock_guard lock(mLock);
263     if (mGraphState != PrebuiltGraphState::RUNNING) {
264         return Status::ILLEGAL_STATE;
265     }
266 
267     if (e.isAborted()) {
268         return Status::INVALID_ARGUMENT;
269     } else if (e.isTransitionComplete()) {
270         return Status::SUCCESS;
271     }
272 
273     ::grpc::ClientContext context;
274     context.set_deadline(std::chrono::system_clock::now() +
275                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
276 
277     proto::StopGraphExecutionRequest stopExecutionRequest;
278     stopExecutionRequest.set_stop_immediate(true);
279     ::grpc::CompletionQueue cq;
280     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
281             mGraphStub->AsyncStopGraphExecution(&context, stopExecutionRequest, &cq));
282 
283     proto::StatusResponse response;
284     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
285     if (mStatus != Status::SUCCESS) {
286         LOG(ERROR) << "Failed to stop graph execution";
287         return Status::FATAL_ERROR;
288     }
289 
290     mStatus = static_cast<Status>(static_cast<int>(response.code()));
291     mErrorMessage = response.message();
292 
293     // Stop observing streams immendiately.
294     mStreamSetObserver->stopObservingStreams(true);
295 
296     if (mStatus == Status::SUCCESS) {
297         mGraphState = PrebuiltGraphState::STOPPED;
298     }
299     return mStatus;
300 }
301 
handleResetPhase(const runner::RunnerEvent & e)302 Status GrpcGraph::handleResetPhase(const runner::RunnerEvent& e) {
303     std::lock_guard lock(mLock);
304     if (mGraphState != PrebuiltGraphState::STOPPED) {
305         return Status::ILLEGAL_STATE;
306     }
307 
308     if (e.isAborted()) {
309         return Status::INVALID_ARGUMENT;
310     } else if (e.isTransitionComplete()) {
311         return Status::SUCCESS;
312     }
313 
314     ::grpc::ClientContext context;
315     context.set_deadline(std::chrono::system_clock::now() +
316                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
317 
318     proto::ResetGraphRequest resetGraphRequest;
319     ::grpc::CompletionQueue cq;
320     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
321             mGraphStub->AsyncResetGraph(&context, resetGraphRequest, &cq));
322 
323     proto::StatusResponse response;
324     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
325     if (mStatus != Status::SUCCESS) {
326         LOG(ERROR) << "Failed to stop graph execution";
327         return Status::FATAL_ERROR;
328     }
329 
330     mStatus = static_cast<Status>(static_cast<int>(response.code()));
331     mErrorMessage = response.message();
332     mStreamSetObserver.reset();
333 
334     return mStatus;
335 }
336 
SetInputStreamData(int,int64_t,const std::string &)337 Status GrpcGraph::SetInputStreamData(int /*streamIndex*/, int64_t /*timestamp*/,
338                                      const std::string& /*streamData*/) {
339     LOG(ERROR) << "Cannot set input stream for remote graphs";
340     return Status::FATAL_ERROR;
341 }
342 
SetInputStreamPixelData(int,int64_t,const runner::InputFrame &)343 Status GrpcGraph::SetInputStreamPixelData(int /*streamIndex*/, int64_t /*timestamp*/,
344                                           const runner::InputFrame& /*inputFrame*/) {
345     LOG(ERROR) << "Cannot set input streams for remote graphs";
346     return Status::FATAL_ERROR;
347 }
348 
StartGraphProfiling()349 Status GrpcGraph::StartGraphProfiling() {
350     std::lock_guard lock(mLock);
351     if (mGraphState != PrebuiltGraphState::RUNNING) {
352         return Status::ILLEGAL_STATE;
353     }
354 
355     ::grpc::ClientContext context;
356     context.set_deadline(std::chrono::system_clock::now() +
357                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
358 
359     proto::StartGraphProfilingRequest startProfilingRequest;
360     ::grpc::CompletionQueue cq;
361     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
362             mGraphStub->AsyncStartGraphProfiling(&context, startProfilingRequest, &cq));
363 
364     proto::StatusResponse response;
365     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
366     if (mStatus != Status::SUCCESS) {
367         LOG(ERROR) << "Failed to start graph profiling";
368         return Status::FATAL_ERROR;
369     }
370 
371     mStatus = static_cast<Status>(static_cast<int>(response.code()));
372     mErrorMessage = response.message();
373 
374     return mStatus;
375 }
376 
StopGraphProfiling()377 Status GrpcGraph::StopGraphProfiling() {
378     // Stopping profiling after graph has already stopped can be a no-op
379     ::grpc::ClientContext context;
380     context.set_deadline(std::chrono::system_clock::now() +
381                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
382 
383     proto::StopGraphProfilingRequest stopProfilingRequest;
384     ::grpc::CompletionQueue cq;
385     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
386             mGraphStub->AsyncStopGraphProfiling(&context, stopProfilingRequest, &cq));
387 
388     proto::StatusResponse response;
389     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
390     if (mStatus != Status::SUCCESS) {
391         LOG(ERROR) << "Failed to stop graph profiling";
392         return Status::FATAL_ERROR;
393     }
394 
395     mStatus = static_cast<Status>(static_cast<int>(response.code()));
396     mErrorMessage = response.message();
397 
398     return mStatus;
399 }
400 
GetDebugInfo()401 std::string GrpcGraph::GetDebugInfo() {
402     ::grpc::ClientContext context;
403     context.set_deadline(std::chrono::system_clock::now() +
404                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
405 
406     proto::ProfilingDataRequest profilingDataRequest;
407     ::grpc::CompletionQueue cq;
408     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::ProfilingDataResponse>> rpc(
409             mGraphStub->AsyncGetProfilingData(&context, profilingDataRequest, &cq));
410 
411     proto::ProfilingDataResponse response;
412     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
413     if (mStatus != Status::SUCCESS) {
414         LOG(ERROR) << "Failed to get profiling info";
415         return "";
416     }
417 
418     return response.data();
419 }
420 
dispatchPixelData(int streamId,int64_t timestamp_us,const runner::InputFrame & frame)421 void GrpcGraph::dispatchPixelData(int streamId, int64_t timestamp_us,
422                                   const runner::InputFrame& frame) {
423     std::shared_ptr<PrebuiltEngineInterface> engineInterface = mEngineInterface.lock();
424     if (engineInterface) {
425         return engineInterface->DispatchPixelData(streamId, timestamp_us, frame);
426     }
427 }
428 
dispatchSerializedData(int streamId,int64_t timestamp_us,std::string && serialized_data)429 void GrpcGraph::dispatchSerializedData(int streamId, int64_t timestamp_us,
430                                        std::string&& serialized_data) {
431     std::shared_ptr<PrebuiltEngineInterface> engineInterface = mEngineInterface.lock();
432     if (engineInterface) {
433         return engineInterface->DispatchSerializedData(streamId, timestamp_us,
434                                                        std::move(serialized_data));
435     }
436 }
437 
dispatchGraphTerminationMessage(Status status,std::string && errorMessage)438 void GrpcGraph::dispatchGraphTerminationMessage(Status status, std::string&& errorMessage) {
439     std::lock_guard lock(mLock);
440     mErrorMessage = std::move(errorMessage);
441     mStatus = status;
442     mGraphState = PrebuiltGraphState::STOPPED;
443     std::shared_ptr<PrebuiltEngineInterface> engineInterface = mEngineInterface.lock();
444     if (engineInterface) {
445         std::string errorMessageTmp = mErrorMessage;
446         engineInterface->DispatchGraphTerminationMessage(mStatus, std::move(errorMessageTmp));
447     }
448 }
449 
GetRemoteGraphFromAddress(const std::string & address,std::weak_ptr<PrebuiltEngineInterface> engineInterface)450 std::unique_ptr<PrebuiltGraph> GetRemoteGraphFromAddress(
451         const std::string& address, std::weak_ptr<PrebuiltEngineInterface> engineInterface) {
452     auto prebuiltGraph = std::make_unique<GrpcGraph>();
453     Status status = prebuiltGraph->initialize(address, engineInterface);
454     if (status != Status::SUCCESS) {
455         return nullptr;
456     }
457 
458     return prebuiltGraph;
459 }
460 
461 }  // namespace graph
462 }  // namespace computepipe
463 }  // namespace automotive
464 }  // namespace android
465