1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <brillo/streams/stream_utils.h>
6 
7 #include <limits>
8 
9 #include <base/bind.h>
10 #include <brillo/message_loops/fake_message_loop.h>
11 #include <brillo/message_loops/message_loop.h>
12 #include <brillo/streams/mock_stream.h>
13 #include <brillo/streams/stream_errors.h>
14 #include <gmock/gmock.h>
15 #include <gtest/gtest.h>
16 
17 using testing::DoAll;
18 using testing::InSequence;
19 using testing::Return;
20 using testing::StrictMock;
21 using testing::_;
22 
// Action for mocked asynchronous calls: instead of invoking the mock's k-th
// argument (a one-argument callback) right away, binds it to |size| and posts
// the result to the current MessageLoop, so it runs when the loop is pumped.
// The mocked method itself returns true.
ACTION_TEMPLATE(InvokeAsyncCallback,
                HAS_1_TEMPLATE_PARAMS(int, k),
                AND_1_VALUE_PARAMS(size)) {
  const auto& success_callback = std::get<k>(args);
  brillo::MessageLoop::current()->PostTask(
      FROM_HERE, base::Bind(success_callback, size));
  return true;
}
30 
// Argument-less overload: posts the mock's k-th argument, used directly as the
// task, to the current MessageLoop and makes the mocked method return true.
ACTION_TEMPLATE(InvokeAsyncCallback,
                HAS_1_TEMPLATE_PARAMS(int, k),
                AND_0_VALUE_PARAMS()) {
  const auto& closure = std::get<k>(args);
  brillo::MessageLoop::current()->PostTask(FROM_HERE, closure);
  return true;
}
37 
// Action for mocked asynchronous calls that should fail.
// - Builds a brillo::Error via Error::AddTo() with "test", |code| and "message".
// - Binds the mock's k-th argument (the error callback) to that error.
// - Posts the bound callback to the current MessageLoop.
// base::Owned() hands ownership of the released error to the bound callback.
// The mocked method itself returns true.
ACTION_TEMPLATE(InvokeAsyncErrorCallback,
                HAS_1_TEMPLATE_PARAMS(int, k),
                AND_1_VALUE_PARAMS(code)) {
  const auto& error_callback = std::get<k>(args);
  brillo::ErrorPtr test_error;
  brillo::Error::AddTo(&test_error, FROM_HERE, "test", code, "message");
  brillo::Error* raw_error = test_error.release();
  brillo::MessageLoop::current()->PostTask(
      FROM_HERE, base::Bind(error_callback, base::Owned(raw_error)));
  return true;
}
47 
48 namespace brillo {
49 
// ErrorStreamClosed() must return false and report the stream domain's
// kStreamClosed error with the "Stream is closed" message.
TEST(StreamUtils, ErrorStreamClosed) {
  ErrorPtr error;
  EXPECT_FALSE(stream_utils::ErrorStreamClosed(FROM_HERE, &error));
  // Fail the test cleanly (rather than crash on a null dereference) if no
  // error object was produced.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(errors::stream::kDomain, error->GetDomain());
  EXPECT_EQ(errors::stream::kStreamClosed, error->GetCode());
  EXPECT_EQ("Stream is closed", error->GetMessage());
}
57 
// ErrorOperationNotSupported() must return false and report the stream
// domain's kOperationNotSupported error.
TEST(StreamUtils, ErrorOperationNotSupported) {
  ErrorPtr error;
  EXPECT_FALSE(stream_utils::ErrorOperationNotSupported(FROM_HERE, &error));
  // Fail the test cleanly (rather than crash on a null dereference) if no
  // error object was produced.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(errors::stream::kDomain, error->GetDomain());
  EXPECT_EQ(errors::stream::kOperationNotSupported, error->GetCode());
  EXPECT_EQ("Stream operation not supported", error->GetMessage());
}
65 
// ErrorReadPastEndOfStream() must return false and report the stream domain's
// kPartialData error.
TEST(StreamUtils, ErrorReadPastEndOfStream) {
  ErrorPtr error;
  EXPECT_FALSE(stream_utils::ErrorReadPastEndOfStream(FROM_HERE, &error));
  // Fail the test cleanly (rather than crash on a null dereference) if no
  // error object was produced.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(errors::stream::kDomain, error->GetDomain());
  EXPECT_EQ(errors::stream::kPartialData, error->GetCode());
  EXPECT_EQ("Reading past the end of stream", error->GetMessage());
}
73 
// CheckInt64Overflow(position, offset) behavior pinned down by these cases:
// - Sums landing in [0, INT64_MAX] are accepted.
// - A negative sum is rejected with kInvalidParameter.
// - A sum or position above INT64_MAX is rejected.
TEST(StreamUtils, CheckInt64Overflow) {
  const int64_t max_int64 = std::numeric_limits<int64_t>::max();
  const uint64_t max_uint64 = std::numeric_limits<uint64_t>::max();
  EXPECT_TRUE(stream_utils::CheckInt64Overflow(FROM_HERE, 0, 0, nullptr));
  EXPECT_TRUE(stream_utils::CheckInt64Overflow(
      FROM_HERE, 0, max_int64, nullptr));
  EXPECT_TRUE(stream_utils::CheckInt64Overflow(
      FROM_HERE, max_int64, 0, nullptr));
  EXPECT_TRUE(stream_utils::CheckInt64Overflow(FROM_HERE, 100, -90, nullptr));
  EXPECT_TRUE(stream_utils::CheckInt64Overflow(
      FROM_HERE, 1000, -1000, nullptr));

  ErrorPtr error;
  EXPECT_FALSE(stream_utils::CheckInt64Overflow(FROM_HERE, 100, -101, &error));
  // Fail the test cleanly (rather than crash on a null dereference) if no
  // error object was produced.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(errors::stream::kDomain, error->GetDomain());
  EXPECT_EQ(errors::stream::kInvalidParameter, error->GetCode());
  EXPECT_EQ("The stream offset value is out of range", error->GetMessage());

  EXPECT_FALSE(stream_utils::CheckInt64Overflow(
      FROM_HERE, max_int64, 1, nullptr));
  EXPECT_FALSE(stream_utils::CheckInt64Overflow(
      FROM_HERE, max_uint64, 0, nullptr));
  EXPECT_FALSE(stream_utils::CheckInt64Overflow(
      FROM_HERE, max_uint64, max_int64, nullptr));
}
99 
// CalculateStreamPosition() resolves |offset| relative to the beginning, the
// current position, or the end of the stream. It rejects results that are
// negative or exceed INT64_MAX.
TEST(StreamUtils, CalculateStreamPosition) {
  using Whence = Stream::Whence;
  const uint64_t current_pos = 1234;
  const uint64_t end_pos = 2000;
  uint64_t pos = 0;

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 0, Whence::FROM_BEGIN, current_pos, end_pos, &pos, nullptr));
  EXPECT_EQ(0u, pos);

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 0, Whence::FROM_CURRENT, current_pos, end_pos, &pos, nullptr));
  EXPECT_EQ(current_pos, pos);

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 0, Whence::FROM_END, current_pos, end_pos, &pos, nullptr));
  EXPECT_EQ(end_pos, pos);

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 10, Whence::FROM_BEGIN, current_pos, end_pos, &pos, nullptr));
  EXPECT_EQ(10u, pos);

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 10, Whence::FROM_CURRENT, current_pos, end_pos, &pos,
      nullptr));
  EXPECT_EQ(current_pos + 10, pos);

  // Seeking past the end is allowed; only int64 range is enforced.
  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 10, Whence::FROM_END, current_pos, end_pos, &pos, nullptr));
  EXPECT_EQ(end_pos + 10, pos);

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, -10, Whence::FROM_CURRENT, current_pos, end_pos, &pos,
      nullptr));
  EXPECT_EQ(current_pos - 10, pos);

  EXPECT_TRUE(stream_utils::CalculateStreamPosition(
      FROM_HERE, -10, Whence::FROM_END, current_pos, end_pos, &pos, nullptr));
  EXPECT_EQ(end_pos - 10, pos);

  ErrorPtr error;
  EXPECT_FALSE(stream_utils::CalculateStreamPosition(
      FROM_HERE, -1, Whence::FROM_BEGIN, current_pos, end_pos, &pos, &error));
  // Fail the test cleanly (rather than crash on a null dereference) if no
  // error object was produced.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(errors::stream::kDomain, error->GetDomain());
  EXPECT_EQ(errors::stream::kInvalidParameter, error->GetCode());
  EXPECT_EQ("The stream offset value is out of range", error->GetMessage());

  EXPECT_FALSE(stream_utils::CalculateStreamPosition(
      FROM_HERE, -1001, Whence::FROM_CURRENT, 1000, end_pos, &pos, nullptr));

  const uint64_t max_int64 = std::numeric_limits<int64_t>::max();
  EXPECT_FALSE(stream_utils::CalculateStreamPosition(
      FROM_HERE, 1, Whence::FROM_CURRENT, max_int64, end_pos, &pos, nullptr));
}
153 
154 class CopyStreamDataTest : public testing::Test {
155  public:
SetUp()156   void SetUp() override {
157     fake_loop_.SetAsCurrent();
158     in_stream_.reset(new StrictMock<MockStream>{});
159     out_stream_.reset(new StrictMock<MockStream>{});
160   }
161 
162   FakeMessageLoop fake_loop_{nullptr};
163   std::unique_ptr<StrictMock<MockStream>> in_stream_;
164   std::unique_ptr<StrictMock<MockStream>> out_stream_;
165   bool succeeded_{false};
166   bool failed_{false};
167 
OnSuccess(uint64_t expected,StreamPtr,StreamPtr,uint64_t copied)168   void OnSuccess(uint64_t expected,
169                  StreamPtr /* in_stream */,
170                  StreamPtr /* out_stream */,
171                  uint64_t copied) {
172     EXPECT_EQ(expected, copied);
173     succeeded_ = true;
174   }
175 
OnError(const std::string & expected_error,StreamPtr,StreamPtr,const Error * error)176   void OnError(const std::string& expected_error,
177                StreamPtr /* in_stream */,
178                StreamPtr /* out_stream */,
179                const Error* error) {
180     EXPECT_EQ(expected_error, error->GetCode());
181     failed_ = true;
182   }
183 
ExpectSuccess()184   void ExpectSuccess() {
185     EXPECT_TRUE(succeeded_);
186     EXPECT_FALSE(failed_);
187   }
188 
ExpectFailure()189   void ExpectFailure() {
190     EXPECT_FALSE(succeeded_);
191     EXPECT_TRUE(failed_);
192   }
193 };
194 
// A 100-byte copy with a 4096-byte buffer needs one read of 100 bytes and one
// write of 100 bytes; the success callback must see 100 bytes copied.
TEST_F(CopyStreamDataTest, CopyAllAtOnce) {
  auto on_success =
      base::Bind(&CopyStreamDataTest::OnSuccess, base::Unretained(this), 100);
  auto on_error =
      base::Bind(&CopyStreamDataTest::OnError, base::Unretained(this), "");

  {
    InSequence seq;
    EXPECT_CALL(*in_stream_, ReadAsync(_, 100, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(100));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 100, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>());
  }

  stream_utils::CopyData(std::move(in_stream_), std::move(out_stream_), 100,
                         4096, std::move(on_success), std::move(on_error));
  fake_loop_.Run();
  ExpectSuccess();
}
210 
// A short read (60 of 100 bytes) is written out, then the remaining 40 bytes
// are requested, read and written; the total reported is 100.
TEST_F(CopyStreamDataTest, CopyInBlocks) {
  auto on_success =
      base::Bind(&CopyStreamDataTest::OnSuccess, base::Unretained(this), 100);
  auto on_error =
      base::Bind(&CopyStreamDataTest::OnError, base::Unretained(this), "");

  {
    InSequence seq;
    // First pass: ask for everything, get 60 bytes.
    EXPECT_CALL(*in_stream_, ReadAsync(_, 100, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(60));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 60, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>());
    // Second pass: ask for the remaining 40 bytes.
    EXPECT_CALL(*in_stream_, ReadAsync(_, 40, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(40));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 40, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>());
  }

  stream_utils::CopyData(std::move(in_stream_), std::move(out_stream_), 100,
                         4096, std::move(on_success), std::move(on_error));
  fake_loop_.Run();
  ExpectSuccess();
}
230 
// A read that returns 0 bytes ends the copy early without another write; the
// success callback sees only the 60 bytes copied before that.
TEST_F(CopyStreamDataTest, CopyTillEndOfStream) {
  auto on_success =
      base::Bind(&CopyStreamDataTest::OnSuccess, base::Unretained(this), 60);
  auto on_error =
      base::Bind(&CopyStreamDataTest::OnError, base::Unretained(this), "");

  {
    InSequence seq;
    EXPECT_CALL(*in_stream_, ReadAsync(_, 100, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(60));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 60, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>());
    // Zero-byte read: no further write is expected.
    EXPECT_CALL(*in_stream_, ReadAsync(_, 40, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(0));
  }

  stream_utils::CopyData(std::move(in_stream_), std::move(out_stream_), 100,
                         4096, std::move(on_success), std::move(on_error));
  fake_loop_.Run();
  ExpectSuccess();
}
248 
// With a 60-byte buffer, a 100-byte copy is split into reads of 60 and then
// 40 bytes, each followed by a matching write.
TEST_F(CopyStreamDataTest, CopyInSmallBlocks) {
  auto on_success =
      base::Bind(&CopyStreamDataTest::OnSuccess, base::Unretained(this), 100);
  auto on_error =
      base::Bind(&CopyStreamDataTest::OnError, base::Unretained(this), "");

  {
    InSequence seq;
    // Buffer size caps the first request at 60 bytes.
    EXPECT_CALL(*in_stream_, ReadAsync(_, 60, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(60));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 60, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>());
    EXPECT_CALL(*in_stream_, ReadAsync(_, 40, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(40));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 40, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>());
  }

  stream_utils::CopyData(std::move(in_stream_), std::move(out_stream_), 100,
                         60, std::move(on_success), std::move(on_error));
  fake_loop_.Run();
  ExpectSuccess();
}
268 
// A failing read is reported through the error callback with the read's error
// code ("read"); the success callback must not run.
TEST_F(CopyStreamDataTest, ErrorRead) {
  auto on_success =
      base::Bind(&CopyStreamDataTest::OnSuccess, base::Unretained(this), 0);
  auto on_error =
      base::Bind(&CopyStreamDataTest::OnError, base::Unretained(this), "read");

  // Only one call is expected, so no ordering constraint is needed.
  EXPECT_CALL(*in_stream_, ReadAsync(_, 60, _, _, _))
      .WillOnce(InvokeAsyncErrorCallback<3>("read"));

  stream_utils::CopyData(std::move(in_stream_), std::move(out_stream_), 100,
                         60, std::move(on_success), std::move(on_error));
  fake_loop_.Run();
  ExpectFailure();
}
282 
// A failing write after a successful 60-byte read is reported through the
// error callback with the write's error code ("write").
TEST_F(CopyStreamDataTest, ErrorWrite) {
  auto on_success =
      base::Bind(&CopyStreamDataTest::OnSuccess, base::Unretained(this), 0);
  auto on_error = base::Bind(&CopyStreamDataTest::OnError,
                             base::Unretained(this), "write");

  {
    InSequence seq;
    EXPECT_CALL(*in_stream_, ReadAsync(_, 60, _, _, _))
        .WillOnce(InvokeAsyncCallback<2>(60));
    EXPECT_CALL(*out_stream_, WriteAllAsync(_, 60, _, _, _))
        .WillOnce(InvokeAsyncErrorCallback<3>("write"));
  }

  stream_utils::CopyData(std::move(in_stream_), std::move(out_stream_), 100,
                         60, std::move(on_success), std::move(on_error));
  fake_loop_.Run();
  ExpectFailure();
}
299 
300 }  // namespace brillo
301