1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <brillo/streams/stream_utils.h>
6
7 #include <limits>
8
9 #include <base/bind.h>
10 #include <brillo/message_loops/message_loop.h>
11 #include <brillo/streams/stream_errors.h>
12
13 namespace brillo {
14 namespace stream_utils {
15
16 namespace {
17
18 // Status of asynchronous CopyData operation.
19 struct CopyDataState {
20 brillo::StreamPtr in_stream;
21 brillo::StreamPtr out_stream;
22 std::vector<uint8_t> buffer;
23 uint64_t remaining_to_copy;
24 uint64_t size_copied;
25 CopyDataSuccessCallback success_callback;
26 CopyDataErrorCallback error_callback;
27 };
28
29 // Async CopyData I/O error callback.
OnCopyDataError(const std::shared_ptr<CopyDataState> & state,const brillo::Error * error)30 void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
31 const brillo::Error* error) {
32 state->error_callback.Run(std::move(state->in_stream),
33 std::move(state->out_stream), error);
34 }
35
36 // Forward declaration.
37 void PerformRead(const std::shared_ptr<CopyDataState>& state);
38
39 // Callback from read operation for CopyData. Writes the read data to the output
40 // stream and invokes PerformRead when done to restart the copy cycle.
PerformWrite(const std::shared_ptr<CopyDataState> & state,size_t size)41 void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
42 if (size == 0) {
43 state->success_callback.Run(std::move(state->in_stream),
44 std::move(state->out_stream),
45 state->size_copied);
46 return;
47 }
48 state->size_copied += size;
49 CHECK_GE(state->remaining_to_copy, size);
50 state->remaining_to_copy -= size;
51
52 brillo::ErrorPtr error;
53 bool success = state->out_stream->WriteAllAsync(
54 state->buffer.data(), size, base::Bind(&PerformRead, state),
55 base::Bind(&OnCopyDataError, state), &error);
56
57 if (!success)
58 OnCopyDataError(state, error.get());
59 }
60
61 // Performs the read part of asynchronous CopyData operation. Reads the data
62 // from input stream and invokes PerformWrite when done to write the data to
63 // the output stream.
PerformRead(const std::shared_ptr<CopyDataState> & state)64 void PerformRead(const std::shared_ptr<CopyDataState>& state) {
65 brillo::ErrorPtr error;
66 const uint64_t buffer_size = state->buffer.size();
67 // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
68 // also not overflow size_t, so the static_cast below is safe.
69 size_t size_to_read =
70 static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
71 if (size_to_read == 0)
72 return PerformWrite(state, 0); // Nothing more to read. Finish operation.
73 bool success = state->in_stream->ReadAsync(
74 state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
75 base::Bind(OnCopyDataError, state), &error);
76
77 if (!success)
78 OnCopyDataError(state, error.get());
79 }
80
81 } // anonymous namespace
82
ErrorStreamClosed(const base::Location & location,ErrorPtr * error)83 bool ErrorStreamClosed(const base::Location& location,
84 ErrorPtr* error) {
85 Error::AddTo(error,
86 location,
87 errors::stream::kDomain,
88 errors::stream::kStreamClosed,
89 "Stream is closed");
90 return false;
91 }
92
ErrorOperationNotSupported(const base::Location & location,ErrorPtr * error)93 bool ErrorOperationNotSupported(const base::Location& location,
94 ErrorPtr* error) {
95 Error::AddTo(error,
96 location,
97 errors::stream::kDomain,
98 errors::stream::kOperationNotSupported,
99 "Stream operation not supported");
100 return false;
101 }
102
ErrorReadPastEndOfStream(const base::Location & location,ErrorPtr * error)103 bool ErrorReadPastEndOfStream(const base::Location& location,
104 ErrorPtr* error) {
105 Error::AddTo(error,
106 location,
107 errors::stream::kDomain,
108 errors::stream::kPartialData,
109 "Reading past the end of stream");
110 return false;
111 }
112
ErrorOperationTimeout(const base::Location & location,ErrorPtr * error)113 bool ErrorOperationTimeout(const base::Location& location,
114 ErrorPtr* error) {
115 Error::AddTo(error,
116 location,
117 errors::stream::kDomain,
118 errors::stream::kTimeout,
119 "Operation timed out");
120 return false;
121 }
122
CheckInt64Overflow(const base::Location & location,uint64_t position,int64_t offset,ErrorPtr * error)123 bool CheckInt64Overflow(const base::Location& location,
124 uint64_t position,
125 int64_t offset,
126 ErrorPtr* error) {
127 if (offset < 0) {
128 // Subtracting the offset. Make sure we do not underflow.
129 uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
130 if (position >= unsigned_offset)
131 return true;
132 } else {
133 // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
134 if (position <= std::numeric_limits<uint64_t>::max() - offset) {
135 // We definitely will not overflow the unsigned 64 bit integer.
136 // Now check that we end up within the limits of signed 64 bit integer.
137 uint64_t new_position = position + offset;
138 uint64_t max = std::numeric_limits<int64_t>::max();
139 if (new_position <= max)
140 return true;
141 }
142 }
143 Error::AddTo(error,
144 location,
145 errors::stream::kDomain,
146 errors::stream::kInvalidParameter,
147 "The stream offset value is out of range");
148 return false;
149 }
150
CalculateStreamPosition(const base::Location & location,int64_t offset,Stream::Whence whence,uint64_t current_position,uint64_t stream_size,uint64_t * new_position,ErrorPtr * error)151 bool CalculateStreamPosition(const base::Location& location,
152 int64_t offset,
153 Stream::Whence whence,
154 uint64_t current_position,
155 uint64_t stream_size,
156 uint64_t* new_position,
157 ErrorPtr* error) {
158 uint64_t pos = 0;
159 switch (whence) {
160 case Stream::Whence::FROM_BEGIN:
161 pos = 0;
162 break;
163
164 case Stream::Whence::FROM_CURRENT:
165 pos = current_position;
166 break;
167
168 case Stream::Whence::FROM_END:
169 pos = stream_size;
170 break;
171
172 default:
173 Error::AddTo(error,
174 location,
175 errors::stream::kDomain,
176 errors::stream::kInvalidParameter,
177 "Invalid stream position whence");
178 return false;
179 }
180
181 if (!CheckInt64Overflow(location, pos, offset, error))
182 return false;
183
184 *new_position = static_cast<uint64_t>(pos + offset);
185 return true;
186 }
187
CopyData(StreamPtr in_stream,StreamPtr out_stream,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)188 void CopyData(StreamPtr in_stream,
189 StreamPtr out_stream,
190 const CopyDataSuccessCallback& success_callback,
191 const CopyDataErrorCallback& error_callback) {
192 CopyData(std::move(in_stream), std::move(out_stream),
193 std::numeric_limits<uint64_t>::max(), 4096, success_callback,
194 error_callback);
195 }
196
CopyData(StreamPtr in_stream,StreamPtr out_stream,uint64_t max_size_to_copy,size_t buffer_size,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)197 void CopyData(StreamPtr in_stream,
198 StreamPtr out_stream,
199 uint64_t max_size_to_copy,
200 size_t buffer_size,
201 const CopyDataSuccessCallback& success_callback,
202 const CopyDataErrorCallback& error_callback) {
203 auto state = std::make_shared<CopyDataState>();
204 state->in_stream = std::move(in_stream);
205 state->out_stream = std::move(out_stream);
206 state->buffer.resize(buffer_size);
207 state->remaining_to_copy = max_size_to_copy;
208 state->size_copied = 0;
209 state->success_callback = success_callback;
210 state->error_callback = error_callback;
211 brillo::MessageLoop::current()->PostTask(FROM_HERE,
212 base::Bind(&PerformRead, state));
213 }
214
215 } // namespace stream_utils
216 } // namespace brillo
217