1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <brillo/streams/stream.h>
6 
7 #include <algorithm>
8 
9 #include <base/bind.h>
10 #include <brillo/message_loops/message_loop.h>
11 #include <brillo/pointer_utils.h>
12 #include <brillo/streams/stream_errors.h>
13 #include <brillo/streams/stream_utils.h>
14 
15 namespace brillo {
16 
TruncateBlocking(ErrorPtr * error)17 bool Stream::TruncateBlocking(ErrorPtr* error) {
18   return SetSizeBlocking(GetPosition(), error);
19 }
20 
SetPosition(uint64_t position,ErrorPtr * error)21 bool Stream::SetPosition(uint64_t position, ErrorPtr* error) {
22   if (!stream_utils::CheckInt64Overflow(FROM_HERE, position, 0, error))
23     return false;
24   return Seek(position, Whence::FROM_BEGIN, nullptr, error);
25 }
26 
ReadAsync(void * buffer,size_t size_to_read,const base::Callback<void (size_t)> & success_callback,const ErrorCallback & error_callback,ErrorPtr * error)27 bool Stream::ReadAsync(void* buffer,
28                        size_t size_to_read,
29                        const base::Callback<void(size_t)>& success_callback,
30                        const ErrorCallback& error_callback,
31                        ErrorPtr* error) {
32   if (is_async_read_pending_) {
33     Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
34                  errors::stream::kOperationNotSupported,
35                  "Another asynchronous operation is still pending");
36     return false;
37   }
38 
39   auto callback = base::Bind(&Stream::IgnoreEOSCallback, success_callback);
40   // If we can read some data right away non-blocking we should still run the
41   // callback from the main loop, so we pass true here for force_async_callback.
42   return ReadAsyncImpl(buffer, size_to_read, callback, error_callback, error,
43                        true);
44 }
45 
ReadAllAsync(void * buffer,size_t size_to_read,const base::Closure & success_callback,const ErrorCallback & error_callback,ErrorPtr * error)46 bool Stream::ReadAllAsync(void* buffer,
47                           size_t size_to_read,
48                           const base::Closure& success_callback,
49                           const ErrorCallback& error_callback,
50                           ErrorPtr* error) {
51   if (is_async_read_pending_) {
52     Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
53                  errors::stream::kOperationNotSupported,
54                  "Another asynchronous operation is still pending");
55     return false;
56   }
57 
58   auto callback = base::Bind(&Stream::ReadAllAsyncCallback,
59                              weak_ptr_factory_.GetWeakPtr(), buffer,
60                              size_to_read, success_callback, error_callback);
61   return ReadAsyncImpl(buffer, size_to_read, callback, error_callback, error,
62                        true);
63 }
64 
ReadBlocking(void * buffer,size_t size_to_read,size_t * size_read,ErrorPtr * error)65 bool Stream::ReadBlocking(void* buffer,
66                           size_t size_to_read,
67                           size_t* size_read,
68                           ErrorPtr* error) {
69   for (;;) {
70     bool eos = false;
71     if (!ReadNonBlocking(buffer, size_to_read, size_read, &eos, error))
72       return false;
73 
74     if (*size_read > 0 || eos)
75       break;
76 
77     if (!WaitForDataBlocking(AccessMode::READ, base::TimeDelta::Max(), nullptr,
78                              error)) {
79       return false;
80     }
81   }
82   return true;
83 }
84 
ReadAllBlocking(void * buffer,size_t size_to_read,ErrorPtr * error)85 bool Stream::ReadAllBlocking(void* buffer,
86                              size_t size_to_read,
87                              ErrorPtr* error) {
88   while (size_to_read > 0) {
89     size_t size_read = 0;
90     if (!ReadBlocking(buffer, size_to_read, &size_read, error))
91       return false;
92 
93     if (size_read == 0)
94       return stream_utils::ErrorReadPastEndOfStream(FROM_HERE, error);
95 
96     size_to_read -= size_read;
97     buffer = AdvancePointer(buffer, size_read);
98   }
99   return true;
100 }
101 
WriteAsync(const void * buffer,size_t size_to_write,const base::Callback<void (size_t)> & success_callback,const ErrorCallback & error_callback,ErrorPtr * error)102 bool Stream::WriteAsync(const void* buffer,
103                         size_t size_to_write,
104                         const base::Callback<void(size_t)>& success_callback,
105                         const ErrorCallback& error_callback,
106                         ErrorPtr* error) {
107   if (is_async_write_pending_) {
108     Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
109                  errors::stream::kOperationNotSupported,
110                  "Another asynchronous operation is still pending");
111     return false;
112   }
113   // If we can read some data right away non-blocking we should still run the
114   // callback from the main loop, so we pass true here for force_async_callback.
115   return WriteAsyncImpl(buffer, size_to_write, success_callback, error_callback,
116                         error, true);
117 }
118 
WriteAllAsync(const void * buffer,size_t size_to_write,const base::Closure & success_callback,const ErrorCallback & error_callback,ErrorPtr * error)119 bool Stream::WriteAllAsync(const void* buffer,
120                            size_t size_to_write,
121                            const base::Closure& success_callback,
122                            const ErrorCallback& error_callback,
123                            ErrorPtr* error) {
124   if (is_async_write_pending_) {
125     Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
126                  errors::stream::kOperationNotSupported,
127                  "Another asynchronous operation is still pending");
128     return false;
129   }
130 
131   auto callback = base::Bind(&Stream::WriteAllAsyncCallback,
132                              weak_ptr_factory_.GetWeakPtr(), buffer,
133                              size_to_write, success_callback, error_callback);
134   return WriteAsyncImpl(buffer, size_to_write, callback, error_callback, error,
135                         true);
136 }
137 
WriteBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)138 bool Stream::WriteBlocking(const void* buffer,
139                            size_t size_to_write,
140                            size_t* size_written,
141                            ErrorPtr* error) {
142   for (;;) {
143     if (!WriteNonBlocking(buffer, size_to_write, size_written, error))
144       return false;
145 
146     if (*size_written > 0 || size_to_write == 0)
147       break;
148 
149     if (!WaitForDataBlocking(AccessMode::WRITE, base::TimeDelta::Max(), nullptr,
150                              error)) {
151       return false;
152     }
153   }
154   return true;
155 }
156 
WriteAllBlocking(const void * buffer,size_t size_to_write,ErrorPtr * error)157 bool Stream::WriteAllBlocking(const void* buffer,
158                               size_t size_to_write,
159                               ErrorPtr* error) {
160   while (size_to_write > 0) {
161     size_t size_written = 0;
162     if (!WriteBlocking(buffer, size_to_write, &size_written, error))
163       return false;
164 
165     if (size_written == 0) {
166       Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
167                    errors::stream::kPartialData,
168                    "Failed to write all the data");
169       return false;
170     }
171     size_to_write -= size_written;
172     buffer = AdvancePointer(buffer, size_written);
173   }
174   return true;
175 }
176 
FlushAsync(const base::Closure & success_callback,const ErrorCallback & error_callback,ErrorPtr *)177 bool Stream::FlushAsync(const base::Closure& success_callback,
178                         const ErrorCallback& error_callback,
179                         ErrorPtr* /* error */) {
180   auto callback = base::Bind(&Stream::FlushAsyncCallback,
181                              weak_ptr_factory_.GetWeakPtr(),
182                              success_callback, error_callback);
183   MessageLoop::current()->PostTask(FROM_HERE, callback);
184   return true;
185 }
186 
IgnoreEOSCallback(const base::Callback<void (size_t)> & success_callback,size_t bytes,bool)187 void Stream::IgnoreEOSCallback(
188     const base::Callback<void(size_t)>& success_callback,
189     size_t bytes,
190     bool /* eos */) {
191   success_callback.Run(bytes);
192 }
193 
ReadAsyncImpl(void * buffer,size_t size_to_read,const base::Callback<void (size_t,bool)> & success_callback,const ErrorCallback & error_callback,ErrorPtr * error,bool force_async_callback)194 bool Stream::ReadAsyncImpl(
195     void* buffer,
196     size_t size_to_read,
197     const base::Callback<void(size_t, bool)>& success_callback,
198     const ErrorCallback& error_callback,
199     ErrorPtr* error,
200     bool force_async_callback) {
201   CHECK(!is_async_read_pending_);
202   // We set this value to true early in the function so calling others will
203   // prevent us from calling WaitForData() to make calls to
204   // ReadAsync() fail while we run WaitForData().
205   is_async_read_pending_ = true;
206 
207   size_t read = 0;
208   bool eos = false;
209   if (!ReadNonBlocking(buffer, size_to_read, &read, &eos, error))
210     return false;
211 
212   if (read > 0 || eos) {
213     if (force_async_callback) {
214       MessageLoop::current()->PostTask(
215           FROM_HERE,
216           base::Bind(&Stream::OnReadAsyncDone, weak_ptr_factory_.GetWeakPtr(),
217                      success_callback, read, eos));
218     } else {
219       is_async_read_pending_ = false;
220       success_callback.Run(read, eos);
221     }
222     return true;
223   }
224 
225   is_async_read_pending_ = WaitForData(
226       AccessMode::READ,
227       base::Bind(&Stream::OnReadAvailable, weak_ptr_factory_.GetWeakPtr(),
228                  buffer, size_to_read, success_callback, error_callback),
229       error);
230   return is_async_read_pending_;
231 }
232 
OnReadAsyncDone(const base::Callback<void (size_t,bool)> & success_callback,size_t bytes_read,bool eos)233 void Stream::OnReadAsyncDone(
234     const base::Callback<void(size_t, bool)>& success_callback,
235     size_t bytes_read,
236     bool eos) {
237   is_async_read_pending_ = false;
238   success_callback.Run(bytes_read, eos);
239 }
240 
OnReadAvailable(void * buffer,size_t size_to_read,const base::Callback<void (size_t,bool)> & success_callback,const ErrorCallback & error_callback,AccessMode mode)241 void Stream::OnReadAvailable(
242     void* buffer,
243     size_t size_to_read,
244     const base::Callback<void(size_t, bool)>& success_callback,
245     const ErrorCallback& error_callback,
246     AccessMode mode) {
247   CHECK(stream_utils::IsReadAccessMode(mode));
248   CHECK(is_async_read_pending_);
249   is_async_read_pending_ = false;
250   ErrorPtr error;
251   // Just reschedule the read operation but don't need to run the callback from
252   // the main loop since we are already running on a callback.
253   if (!ReadAsyncImpl(buffer, size_to_read, success_callback, error_callback,
254                      &error, false)) {
255     error_callback.Run(error.get());
256   }
257 }
258 
WriteAsyncImpl(const void * buffer,size_t size_to_write,const base::Callback<void (size_t)> & success_callback,const ErrorCallback & error_callback,ErrorPtr * error,bool force_async_callback)259 bool Stream::WriteAsyncImpl(
260     const void* buffer,
261     size_t size_to_write,
262     const base::Callback<void(size_t)>& success_callback,
263     const ErrorCallback& error_callback,
264     ErrorPtr* error,
265     bool force_async_callback) {
266   CHECK(!is_async_write_pending_);
267   // We set this value to true early in the function so calling others will
268   // prevent us from calling WaitForData() to make calls to
269   // ReadAsync() fail while we run WaitForData().
270   is_async_write_pending_ = true;
271 
272   size_t written = 0;
273   if (!WriteNonBlocking(buffer, size_to_write, &written, error))
274     return false;
275 
276   if (written > 0) {
277     if (force_async_callback) {
278       MessageLoop::current()->PostTask(
279           FROM_HERE,
280           base::Bind(&Stream::OnWriteAsyncDone, weak_ptr_factory_.GetWeakPtr(),
281                      success_callback, written));
282     } else {
283       is_async_write_pending_ = false;
284       success_callback.Run(written);
285     }
286     return true;
287   }
288   is_async_write_pending_ = WaitForData(
289       AccessMode::WRITE,
290       base::Bind(&Stream::OnWriteAvailable, weak_ptr_factory_.GetWeakPtr(),
291                  buffer, size_to_write, success_callback, error_callback),
292       error);
293   return is_async_write_pending_;
294 }
295 
OnWriteAsyncDone(const base::Callback<void (size_t)> & success_callback,size_t size_written)296 void Stream::OnWriteAsyncDone(
297     const base::Callback<void(size_t)>& success_callback,
298     size_t size_written) {
299   is_async_write_pending_ = false;
300   success_callback.Run(size_written);
301 }
302 
OnWriteAvailable(const void * buffer,size_t size,const base::Callback<void (size_t)> & success_callback,const ErrorCallback & error_callback,AccessMode mode)303 void Stream::OnWriteAvailable(
304     const void* buffer,
305     size_t size,
306     const base::Callback<void(size_t)>& success_callback,
307     const ErrorCallback& error_callback,
308     AccessMode mode) {
309   CHECK(stream_utils::IsWriteAccessMode(mode));
310   CHECK(is_async_write_pending_);
311   is_async_write_pending_ = false;
312   ErrorPtr error;
313   // Just reschedule the read operation but don't need to run the callback from
314   // the main loop since we are already running on a callback.
315   if (!WriteAsyncImpl(buffer, size, success_callback, error_callback, &error,
316                       false)) {
317     error_callback.Run(error.get());
318   }
319 }
320 
ReadAllAsyncCallback(void * buffer,size_t size_to_read,const base::Closure & success_callback,const ErrorCallback & error_callback,size_t size_read,bool eos)321 void Stream::ReadAllAsyncCallback(void* buffer,
322                                   size_t size_to_read,
323                                   const base::Closure& success_callback,
324                                   const ErrorCallback& error_callback,
325                                   size_t size_read,
326                                   bool eos) {
327   ErrorPtr error;
328   size_to_read -= size_read;
329   if (size_to_read != 0 && eos) {
330     stream_utils::ErrorReadPastEndOfStream(FROM_HERE, &error);
331     error_callback.Run(error.get());
332     return;
333   }
334 
335   if (size_to_read) {
336     buffer = AdvancePointer(buffer, size_read);
337     auto callback = base::Bind(&Stream::ReadAllAsyncCallback,
338                                weak_ptr_factory_.GetWeakPtr(), buffer,
339                                size_to_read, success_callback, error_callback);
340     if (!ReadAsyncImpl(buffer, size_to_read, callback, error_callback, &error,
341                        false)) {
342       error_callback.Run(error.get());
343     }
344   } else {
345     success_callback.Run();
346   }
347 }
348 
WriteAllAsyncCallback(const void * buffer,size_t size_to_write,const base::Closure & success_callback,const ErrorCallback & error_callback,size_t size_written)349 void Stream::WriteAllAsyncCallback(const void* buffer,
350                                    size_t size_to_write,
351                                    const base::Closure& success_callback,
352                                    const ErrorCallback& error_callback,
353                                    size_t size_written) {
354   ErrorPtr error;
355   if (size_to_write != 0 && size_written == 0) {
356     Error::AddTo(&error, FROM_HERE, errors::stream::kDomain,
357                  errors::stream::kPartialData, "Failed to write all the data");
358     error_callback.Run(error.get());
359     return;
360   }
361   size_to_write -= size_written;
362   if (size_to_write) {
363     buffer = AdvancePointer(buffer, size_written);
364     auto callback = base::Bind(&Stream::WriteAllAsyncCallback,
365                                weak_ptr_factory_.GetWeakPtr(), buffer,
366                                size_to_write, success_callback, error_callback);
367     if (!WriteAsyncImpl(buffer, size_to_write, callback, error_callback, &error,
368                         false)) {
369       error_callback.Run(error.get());
370     }
371   } else {
372     success_callback.Run();
373   }
374 }
375 
FlushAsyncCallback(const base::Closure & success_callback,const ErrorCallback & error_callback)376 void Stream::FlushAsyncCallback(const base::Closure& success_callback,
377                                 const ErrorCallback& error_callback) {
378   ErrorPtr error;
379   if (FlushBlocking(&error)) {
380     success_callback.Run();
381   } else {
382     error_callback.Run(error.get());
383   }
384 }
385 
CancelPendingAsyncOperations()386 void Stream::CancelPendingAsyncOperations() {
387   weak_ptr_factory_.InvalidateWeakPtrs();
388   is_async_read_pending_ = false;
389   is_async_write_pending_ = false;
390 }
391 
392 }  // namespace brillo
393