// Copyright 2015 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #include <brillo/streams/stream_utils.h>
6 
7 #include <limits>
8 
9 #include <base/bind.h>
10 #include <brillo/message_loops/message_loop.h>
11 #include <brillo/streams/stream_errors.h>
12 
13 namespace brillo {
14 namespace stream_utils {
15 
16 namespace {
17 
18 // Status of asynchronous CopyData operation.
19 struct CopyDataState {
20   brillo::StreamPtr in_stream;
21   brillo::StreamPtr out_stream;
22   std::vector<uint8_t> buffer;
23   uint64_t remaining_to_copy;
24   uint64_t size_copied;
25   CopyDataSuccessCallback success_callback;
26   CopyDataErrorCallback error_callback;
27 };
28 
29 // Async CopyData I/O error callback.
OnCopyDataError(const std::shared_ptr<CopyDataState> & state,const brillo::Error * error)30 void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
31                      const brillo::Error* error) {
32   state->error_callback.Run(std::move(state->in_stream),
33                             std::move(state->out_stream), error);
34 }
35 
36 // Forward declaration.
37 void PerformRead(const std::shared_ptr<CopyDataState>& state);
38 
39 // Callback from read operation for CopyData. Writes the read data to the output
40 // stream and invokes PerformRead when done to restart the copy cycle.
PerformWrite(const std::shared_ptr<CopyDataState> & state,size_t size)41 void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
42   if (size == 0) {
43     state->success_callback.Run(std::move(state->in_stream),
44                                 std::move(state->out_stream),
45                                 state->size_copied);
46     return;
47   }
48   state->size_copied += size;
49   CHECK_GE(state->remaining_to_copy, size);
50   state->remaining_to_copy -= size;
51 
52   brillo::ErrorPtr error;
53   bool success = state->out_stream->WriteAllAsync(
54       state->buffer.data(), size, base::Bind(&PerformRead, state),
55       base::Bind(&OnCopyDataError, state), &error);
56 
57   if (!success)
58     OnCopyDataError(state, error.get());
59 }
60 
61 // Performs the read part of asynchronous CopyData operation. Reads the data
62 // from input stream and invokes PerformWrite when done to write the data to
63 // the output stream.
PerformRead(const std::shared_ptr<CopyDataState> & state)64 void PerformRead(const std::shared_ptr<CopyDataState>& state) {
65   brillo::ErrorPtr error;
66   const uint64_t buffer_size = state->buffer.size();
67   // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
68   // also not overflow size_t, so the static_cast below is safe.
69   size_t size_to_read =
70       static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
71   if (size_to_read == 0)
72     return PerformWrite(state, 0);  // Nothing more to read. Finish operation.
73   bool success = state->in_stream->ReadAsync(
74       state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
75       base::Bind(OnCopyDataError, state), &error);
76 
77   if (!success)
78     OnCopyDataError(state, error.get());
79 }
80 
81 }  // anonymous namespace
82 
ErrorStreamClosed(const base::Location & location,ErrorPtr * error)83 bool ErrorStreamClosed(const base::Location& location,
84                        ErrorPtr* error) {
85   Error::AddTo(error,
86                location,
87                errors::stream::kDomain,
88                errors::stream::kStreamClosed,
89                "Stream is closed");
90   return false;
91 }
92 
ErrorOperationNotSupported(const base::Location & location,ErrorPtr * error)93 bool ErrorOperationNotSupported(const base::Location& location,
94                                 ErrorPtr* error) {
95   Error::AddTo(error,
96                location,
97                errors::stream::kDomain,
98                errors::stream::kOperationNotSupported,
99                "Stream operation not supported");
100   return false;
101 }
102 
ErrorReadPastEndOfStream(const base::Location & location,ErrorPtr * error)103 bool ErrorReadPastEndOfStream(const base::Location& location,
104                               ErrorPtr* error) {
105   Error::AddTo(error,
106                location,
107                errors::stream::kDomain,
108                errors::stream::kPartialData,
109                "Reading past the end of stream");
110   return false;
111 }
112 
ErrorOperationTimeout(const base::Location & location,ErrorPtr * error)113 bool ErrorOperationTimeout(const base::Location& location,
114                            ErrorPtr* error) {
115   Error::AddTo(error,
116                location,
117                errors::stream::kDomain,
118                errors::stream::kTimeout,
119                "Operation timed out");
120   return false;
121 }
122 
CheckInt64Overflow(const base::Location & location,uint64_t position,int64_t offset,ErrorPtr * error)123 bool CheckInt64Overflow(const base::Location& location,
124                         uint64_t position,
125                         int64_t offset,
126                         ErrorPtr* error) {
127   if (offset < 0) {
128     // Subtracting the offset. Make sure we do not underflow.
129     uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
130     if (position >= unsigned_offset)
131       return true;
132   } else {
133     // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
134     if (position <= std::numeric_limits<uint64_t>::max() - offset) {
135       // We definitely will not overflow the unsigned 64 bit integer.
136       // Now check that we end up within the limits of signed 64 bit integer.
137       uint64_t new_position = position + offset;
138       uint64_t max = std::numeric_limits<int64_t>::max();
139       if (new_position <= max)
140         return true;
141     }
142   }
143   Error::AddTo(error,
144                location,
145                errors::stream::kDomain,
146                errors::stream::kInvalidParameter,
147                "The stream offset value is out of range");
148   return false;
149 }
150 
CalculateStreamPosition(const base::Location & location,int64_t offset,Stream::Whence whence,uint64_t current_position,uint64_t stream_size,uint64_t * new_position,ErrorPtr * error)151 bool CalculateStreamPosition(const base::Location& location,
152                              int64_t offset,
153                              Stream::Whence whence,
154                              uint64_t current_position,
155                              uint64_t stream_size,
156                              uint64_t* new_position,
157                              ErrorPtr* error) {
158   uint64_t pos = 0;
159   switch (whence) {
160     case Stream::Whence::FROM_BEGIN:
161       pos = 0;
162       break;
163 
164     case Stream::Whence::FROM_CURRENT:
165       pos = current_position;
166       break;
167 
168     case Stream::Whence::FROM_END:
169       pos = stream_size;
170       break;
171 
172     default:
173       Error::AddTo(error,
174                    location,
175                    errors::stream::kDomain,
176                    errors::stream::kInvalidParameter,
177                    "Invalid stream position whence");
178       return false;
179   }
180 
181   if (!CheckInt64Overflow(location, pos, offset, error))
182     return false;
183 
184   *new_position = static_cast<uint64_t>(pos + offset);
185   return true;
186 }
187 
CopyData(StreamPtr in_stream,StreamPtr out_stream,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)188 void CopyData(StreamPtr in_stream,
189               StreamPtr out_stream,
190               const CopyDataSuccessCallback& success_callback,
191               const CopyDataErrorCallback& error_callback) {
192   CopyData(std::move(in_stream), std::move(out_stream),
193            std::numeric_limits<uint64_t>::max(), 4096, success_callback,
194            error_callback);
195 }
196 
CopyData(StreamPtr in_stream,StreamPtr out_stream,uint64_t max_size_to_copy,size_t buffer_size,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)197 void CopyData(StreamPtr in_stream,
198               StreamPtr out_stream,
199               uint64_t max_size_to_copy,
200               size_t buffer_size,
201               const CopyDataSuccessCallback& success_callback,
202               const CopyDataErrorCallback& error_callback) {
203   auto state = std::make_shared<CopyDataState>();
204   state->in_stream = std::move(in_stream);
205   state->out_stream = std::move(out_stream);
206   state->buffer.resize(buffer_size);
207   state->remaining_to_copy = max_size_to_copy;
208   state->size_copied = 0;
209   state->success_callback = success_callback;
210   state->error_callback = error_callback;
211   brillo::MessageLoop::current()->PostTask(FROM_HERE,
212                                              base::Bind(&PerformRead, state));
213 }
214 
215 }  // namespace stream_utils
216 }  // namespace brillo
217