1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <binder/SafeInterface.h>
18 
19 #include <binder/IInterface.h>
20 #include <binder/IPCThreadState.h>
21 #include <binder/IServiceManager.h>
22 #include <binder/Parcel.h>
23 #include <binder/Parcelable.h>
24 #include <binder/ProcessState.h>
25 
26 #pragma clang diagnostic push
27 #pragma clang diagnostic ignored "-Weverything"
28 #include <gtest/gtest.h>
29 #pragma clang diagnostic pop
30 
31 #include <utils/Flattenable.h>
32 #include <utils/LightRefBase.h>
33 #include <utils/NativeHandle.h>
34 
35 #include <cutils/native_handle.h>
36 
#include <optional>

#include <errno.h>
#include <inttypes.h>
#include <sys/eventfd.h>
#include <sys/prctl.h>
#include <unistd.h>
42 
43 using namespace std::chrono_literals; // NOLINT - google-build-using-namespace
44 using android::binder::unique_fd;
45 
46 namespace android {
47 namespace tests {
48 
// Name used to look up the remote ISafeInterfaceTest implementation in the service manager
// (see SafeInterfaceTest::getRemoteService).
static const String16 kServiceName("SafeInterfaceTest");

// Enum passed through SafeInterface::modifyEnum to check that enum class values are parceled
// correctly. The numeric values are what crosses the binder boundary.
enum class TestEnum : uint32_t {
    INVALID = 0,
    INITIAL = 1,
    FINAL = 2,
};
56 
57 // This class serves two purposes:
58 //   1) It ensures that the implementation doesn't require copying or moving the data (for
59 //      efficiency purposes)
60 //   2) It tests that Parcelables can be passed correctly
61 class NoCopyNoMove : public Parcelable {
62 public:
63     NoCopyNoMove() = default;
NoCopyNoMove(int32_t value)64     explicit NoCopyNoMove(int32_t value) : mValue(value) {}
65     ~NoCopyNoMove() override = default;
66 
67     // Not copyable
68     NoCopyNoMove(const NoCopyNoMove&) = delete;
69     NoCopyNoMove& operator=(const NoCopyNoMove&) = delete;
70 
71     // Not movable
72     NoCopyNoMove(NoCopyNoMove&&) = delete;
73     NoCopyNoMove& operator=(NoCopyNoMove&&) = delete;
74 
75     // Parcelable interface
writeToParcel(Parcel * parcel) const76     status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
readFromParcel(const Parcel * parcel)77     status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
78 
getValue() const79     int32_t getValue() const { return mValue; }
setValue(int32_t value)80     void setValue(int32_t value) { mValue = value; }
81 
82 private:
83     int32_t mValue = 0;
84     __attribute__((unused)) uint8_t mPadding[4] = {}; // Avoids a warning from -Wpadded
85 };
86 
87 struct TestFlattenable : Flattenable<TestFlattenable> {
88     TestFlattenable() = default;
TestFlattenableandroid::tests::TestFlattenable89     explicit TestFlattenable(int32_t v) : value(v) {}
90 
91     // Flattenable protocol
getFlattenedSizeandroid::tests::TestFlattenable92     size_t getFlattenedSize() const { return sizeof(value); }
getFdCountandroid::tests::TestFlattenable93     size_t getFdCount() const { return 0; }
flattenandroid::tests::TestFlattenable94     status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
95         FlattenableUtils::write(buffer, size, value);
96         return NO_ERROR;
97     }
unflattenandroid::tests::TestFlattenable98     status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
99         FlattenableUtils::read(buffer, size, value);
100         return NO_ERROR;
101     }
102 
103     int32_t value = 0;
104 };
105 
106 struct TestLightFlattenable : LightFlattenablePod<TestLightFlattenable> {
107     TestLightFlattenable() = default;
TestLightFlattenableandroid::tests::TestLightFlattenable108     explicit TestLightFlattenable(int32_t v) : value(v) {}
109     int32_t value = 0;
110 };
111 
112 // It seems like this should be able to inherit from TestFlattenable (to avoid duplicating code),
113 // but the SafeInterface logic can't easily be extended to find an indirect Flattenable<T>
114 // base class
115 class TestLightRefBaseFlattenable : public Flattenable<TestLightRefBaseFlattenable>,
116                                     public LightRefBase<TestLightRefBaseFlattenable> {
117 public:
118     TestLightRefBaseFlattenable() = default;
TestLightRefBaseFlattenable(int32_t v)119     explicit TestLightRefBaseFlattenable(int32_t v) : value(v) {}
120 
121     // Flattenable protocol
getFlattenedSize() const122     size_t getFlattenedSize() const { return sizeof(value); }
getFdCount() const123     size_t getFdCount() const { return 0; }
flatten(void * & buffer,size_t & size,int * &,size_t &) const124     status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
125         FlattenableUtils::write(buffer, size, value);
126         return NO_ERROR;
127     }
unflatten(void const * & buffer,size_t & size,int const * &,size_t &)128     status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
129         FlattenableUtils::read(buffer, size, value);
130         return NO_ERROR;
131     }
132 
133     int32_t value = 0;
134 };
135 
136 class TestParcelable : public Parcelable {
137 public:
138     TestParcelable() = default;
TestParcelable(int32_t value)139     explicit TestParcelable(int32_t value) : mValue(value) {}
TestParcelable(const TestParcelable & other)140     TestParcelable(const TestParcelable& other) : TestParcelable(other.mValue) {}
TestParcelable(TestParcelable && other)141     TestParcelable(TestParcelable&& other) : TestParcelable(other.mValue) {}
142 
143     // Parcelable interface
writeToParcel(Parcel * parcel) const144     status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
readFromParcel(const Parcel * parcel)145     status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
146 
getValue() const147     int32_t getValue() const { return mValue; }
setValue(int32_t value)148     void setValue(int32_t value) { mValue = value; }
149 
150 private:
151     int32_t mValue = 0;
152 };
153 
154 class ExitOnDeath : public IBinder::DeathRecipient {
155 public:
156     ~ExitOnDeath() override = default;
157 
binderDied(const wp<IBinder> &)158     void binderDied(const wp<IBinder>& /*who*/) override {
159         ALOG(LOG_INFO, "ExitOnDeath", "Exiting");
160         exit(0);
161     }
162 };
163 
// This callback class is used to test both one-way transactions and that sp<IInterface> can be
// passed correctly
class ICallback : public IInterface {
public:
    DECLARE_META_INTERFACE(Callback)

    // Transaction codes. Values are assigned in declaration order starting at
    // FIRST_CALL_TRANSACTION and are used both by BpCallback (when sending) and BnCallback (when
    // dispatching), so enumerator order must not change. Last is a sentinel used only for the
    // range check in BnCallback::onTransact.
    enum class Tag : uint32_t {
        OnCallback = IBinder::FIRST_CALL_TRANSACTION,
        Last,
    };

    // Invoked by BnSafeInterfaceTest::callMeBack with its input value plus one.
    virtual void onCallback(int32_t aPlusOne) = 0;
};
177 
178 class BpCallback : public SafeBpInterface<ICallback> {
179 public:
BpCallback(const sp<IBinder> & impl)180     explicit BpCallback(const sp<IBinder>& impl) : SafeBpInterface<ICallback>(impl, getLogTag()) {}
181 
onCallback(int32_t aPlusOne)182     void onCallback(int32_t aPlusOne) override {
183         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
184         return callRemoteAsync<decltype(&ICallback::onCallback)>(Tag::OnCallback, aPlusOne);
185     }
186 
187 private:
getLogTag()188     static constexpr const char* getLogTag() { return "BpCallback"; }
189 };
190 
// Out-of-line meta-interface definitions matching DECLARE_META_INTERFACE(Callback), using
// "android.gfx.tests.ICallback" as the interface descriptor. -Wexit-time-destructors is silenced
// around the expansion, presumably because it defines objects with static storage duration and
// non-trivial destructors (TODO confirm against IInterface.h).
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
IMPLEMENT_META_INTERFACE(Callback, "android.gfx.tests.ICallback")
#pragma clang diagnostic pop
195 
196 class BnCallback : public SafeBnInterface<ICallback> {
197 public:
BnCallback()198     BnCallback() : SafeBnInterface("BnCallback") {}
199 
onTransact(uint32_t code,const Parcel & data,Parcel * reply,uint32_t)200     status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
201                         uint32_t /*flags*/) override {
202         EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
203         EXPECT_LT(code, static_cast<uint32_t>(ICallback::Tag::Last));
204         ICallback::Tag tag = static_cast<ICallback::Tag>(code);
205         switch (tag) {
206             case ICallback::Tag::OnCallback: {
207                 return callLocalAsync(data, reply, &ICallback::onCallback);
208             }
209             case ICallback::Tag::Last:
210                 // Should not be possible because of the asserts at the beginning of the method
211                 [&]() { FAIL(); }();
212                 return UNKNOWN_ERROR;
213         }
214     }
215 };
216 
// Interface exercised by the tests in this file. BnSafeInterfaceTest implements it on the service
// side and BpSafeInterfaceTest proxies it on the client side.
class ISafeInterfaceTest : public IInterface {
public:
    DECLARE_META_INTERFACE(SafeInterfaceTest)

    // Transaction codes, one per method (each increment() overload has its own code). Values are
    // assigned in declaration order starting at FIRST_CALL_TRANSACTION and are shared by the proxy
    // and the stub, so enumerator order must not change. Last is a sentinel used only for the
    // range check in BnSafeInterfaceTest::onTransact.
    enum class Tag : uint32_t {
        SetDeathToken = IBinder::FIRST_CALL_TRANSACTION,
        ReturnsNoMemory,
        LogicalNot,
        ModifyEnum,
        IncrementFlattenable,
        IncrementLightFlattenable,
        IncrementLightRefBaseFlattenable,
        IncrementNativeHandle,
        IncrementNoCopyNoMove,
        IncrementParcelableVector,
        DoubleString,
        CallMeBack,
        IncrementInt32,
        IncrementUint32,
        IncrementInt64,
        IncrementUint64,
        IncrementFloat,
        IncrementTwo,
        Last,
    };

    // This is primarily so that the remote service dies when the test does, but it also serves to
    // test the handling of sp<IBinder> and non-const methods
    virtual status_t setDeathToken(const sp<IBinder>& token) = 0;

    // This is the most basic test since it doesn't require parceling any arguments
    virtual status_t returnsNoMemory() const = 0;

    // These are ordered according to their corresponding methods in SafeInterface::ParcelHandler
    virtual status_t logicalNot(bool a, bool* notA) const = 0;
    virtual status_t modifyEnum(TestEnum a, TestEnum* b) const = 0;
    virtual status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const TestLightFlattenable& a,
                               TestLightFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const sp<TestLightRefBaseFlattenable>& a,
                               sp<TestLightRefBaseFlattenable>* aPlusOne) const = 0;
    virtual status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const = 0;
    virtual status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const = 0;
    virtual status_t increment(const std::vector<TestParcelable>& a,
                               std::vector<TestParcelable>* aPlusOne) const = 0;
    virtual status_t doubleString(const String8& str, String8* doubleStr) const = 0;
    // As mentioned above, sp<IBinder> is already tested by setDeathToken
    // One-way call: the implementation responds by invoking callback->onCallback(a + 1).
    virtual void callMeBack(const sp<ICallback>& callback, int32_t a) const = 0;
    virtual status_t increment(int32_t a, int32_t* aPlusOne) const = 0;
    virtual status_t increment(uint32_t a, uint32_t* aPlusOne) const = 0;
    virtual status_t increment(int64_t a, int64_t* aPlusOne) const = 0;
    virtual status_t increment(uint64_t a, uint64_t* aPlusOne) const = 0;
    virtual status_t increment(float a, float* aPlusOne) const = 0;

    // This tests that input/output parameter interleaving works correctly
    virtual status_t increment(int32_t a, int32_t* aPlusOne, int32_t b,
                               int32_t* bPlusOne) const = 0;
};
275 
276 class BpSafeInterfaceTest : public SafeBpInterface<ISafeInterfaceTest> {
277 public:
BpSafeInterfaceTest(const sp<IBinder> & impl)278     explicit BpSafeInterfaceTest(const sp<IBinder>& impl)
279           : SafeBpInterface<ISafeInterfaceTest>(impl, getLogTag()) {}
280 
setDeathToken(const sp<IBinder> & token)281     status_t setDeathToken(const sp<IBinder>& token) override {
282         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
283         return callRemote<decltype(&ISafeInterfaceTest::setDeathToken)>(Tag::SetDeathToken, token);
284     }
returnsNoMemory() const285     status_t returnsNoMemory() const override {
286         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
287         return callRemote<decltype(&ISafeInterfaceTest::returnsNoMemory)>(Tag::ReturnsNoMemory);
288     }
logicalNot(bool a,bool * notA) const289     status_t logicalNot(bool a, bool* notA) const override {
290         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
291         return callRemote<decltype(&ISafeInterfaceTest::logicalNot)>(Tag::LogicalNot, a, notA);
292     }
modifyEnum(TestEnum a,TestEnum * b) const293     status_t modifyEnum(TestEnum a, TestEnum* b) const override {
294         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
295         return callRemote<decltype(&ISafeInterfaceTest::modifyEnum)>(Tag::ModifyEnum, a, b);
296     }
increment(const TestFlattenable & a,TestFlattenable * aPlusOne) const297     status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
298         using Signature =
299                 status_t (ISafeInterfaceTest::*)(const TestFlattenable&, TestFlattenable*) const;
300         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
301         return callRemote<Signature>(Tag::IncrementFlattenable, a, aPlusOne);
302     }
increment(const TestLightFlattenable & a,TestLightFlattenable * aPlusOne) const303     status_t increment(const TestLightFlattenable& a,
304                        TestLightFlattenable* aPlusOne) const override {
305         using Signature = status_t (ISafeInterfaceTest::*)(const TestLightFlattenable&,
306                                                            TestLightFlattenable*) const;
307         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
308         return callRemote<Signature>(Tag::IncrementLightFlattenable, a, aPlusOne);
309     }
increment(const sp<TestLightRefBaseFlattenable> & a,sp<TestLightRefBaseFlattenable> * aPlusOne) const310     status_t increment(const sp<TestLightRefBaseFlattenable>& a,
311                        sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
312         using Signature = status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
313                                                            sp<TestLightRefBaseFlattenable>*) const;
314         return callRemote<Signature>(Tag::IncrementLightRefBaseFlattenable, a, aPlusOne);
315     }
increment(const sp<NativeHandle> & a,sp<NativeHandle> * aPlusOne) const316     status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
317         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
318         using Signature =
319                 status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&, sp<NativeHandle>*) const;
320         return callRemote<Signature>(Tag::IncrementNativeHandle, a, aPlusOne);
321     }
increment(const NoCopyNoMove & a,NoCopyNoMove * aPlusOne) const322     status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
323         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
324         using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
325                                                            NoCopyNoMove* aPlusOne) const;
326         return callRemote<Signature>(Tag::IncrementNoCopyNoMove, a, aPlusOne);
327     }
increment(const std::vector<TestParcelable> & a,std::vector<TestParcelable> * aPlusOne) const328     status_t increment(const std::vector<TestParcelable>& a,
329                        std::vector<TestParcelable>* aPlusOne) const override {
330         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
331         using Signature = status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
332                                                            std::vector<TestParcelable>*);
333         return callRemote<Signature>(Tag::IncrementParcelableVector, a, aPlusOne);
334     }
doubleString(const String8 & str,String8 * doubleStr) const335     status_t doubleString(const String8& str, String8* doubleStr) const override {
336         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
337         return callRemote<decltype(&ISafeInterfaceTest::doubleString)>(Tag::DoubleString, str,
338                                                                        doubleStr);
339     }
callMeBack(const sp<ICallback> & callback,int32_t a) const340     void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
341         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
342         return callRemoteAsync<decltype(&ISafeInterfaceTest::callMeBack)>(Tag::CallMeBack, callback,
343                                                                           a);
344     }
increment(int32_t a,int32_t * aPlusOne) const345     status_t increment(int32_t a, int32_t* aPlusOne) const override {
346         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
347         using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
348         return callRemote<Signature>(Tag::IncrementInt32, a, aPlusOne);
349     }
increment(uint32_t a,uint32_t * aPlusOne) const350     status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
351         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
352         using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
353         return callRemote<Signature>(Tag::IncrementUint32, a, aPlusOne);
354     }
increment(int64_t a,int64_t * aPlusOne) const355     status_t increment(int64_t a, int64_t* aPlusOne) const override {
356         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
357         using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
358         return callRemote<Signature>(Tag::IncrementInt64, a, aPlusOne);
359     }
increment(uint64_t a,uint64_t * aPlusOne) const360     status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
361         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
362         using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
363         return callRemote<Signature>(Tag::IncrementUint64, a, aPlusOne);
364     }
increment(float a,float * aPlusOne) const365     status_t increment(float a, float* aPlusOne) const override {
366         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
367         using Signature = status_t (ISafeInterfaceTest::*)(float, float*) const;
368         return callRemote<Signature>(Tag::IncrementFloat, a, aPlusOne);
369     }
increment(int32_t a,int32_t * aPlusOne,int32_t b,int32_t * bPlusOne) const370     status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
371         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
372         using Signature =
373                 status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t, int32_t*) const;
374         return callRemote<Signature>(Tag::IncrementTwo, a, aPlusOne, b, bPlusOne);
375     }
376 
377 private:
getLogTag()378     static constexpr const char* getLogTag() { return "BpSafeInterfaceTest"; }
379 };
380 
381 #pragma clang diagnostic push
382 #pragma clang diagnostic ignored "-Wexit-time-destructors"
383 IMPLEMENT_META_INTERFACE(SafeInterfaceTest, "android.gfx.tests.ISafeInterfaceTest")
384 
getDeathRecipient()385 static sp<IBinder::DeathRecipient> getDeathRecipient() {
386     static sp<IBinder::DeathRecipient> recipient = new ExitOnDeath;
387     return recipient;
388 }
389 #pragma clang diagnostic pop
390 
391 class BnSafeInterfaceTest : public SafeBnInterface<ISafeInterfaceTest> {
392 public:
BnSafeInterfaceTest()393     BnSafeInterfaceTest() : SafeBnInterface(getLogTag()) {}
394 
setDeathToken(const sp<IBinder> & token)395     status_t setDeathToken(const sp<IBinder>& token) override {
396         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
397         token->linkToDeath(getDeathRecipient());
398         return NO_ERROR;
399     }
returnsNoMemory() const400     status_t returnsNoMemory() const override {
401         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
402         return NO_MEMORY;
403     }
logicalNot(bool a,bool * notA) const404     status_t logicalNot(bool a, bool* notA) const override {
405         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
406         *notA = !a;
407         return NO_ERROR;
408     }
modifyEnum(TestEnum a,TestEnum * b) const409     status_t modifyEnum(TestEnum a, TestEnum* b) const override {
410         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
411         *b = (a == TestEnum::INITIAL) ? TestEnum::FINAL : TestEnum::INVALID;
412         return NO_ERROR;
413     }
increment(const TestFlattenable & a,TestFlattenable * aPlusOne) const414     status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
415         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
416         aPlusOne->value = a.value + 1;
417         return NO_ERROR;
418     }
increment(const TestLightFlattenable & a,TestLightFlattenable * aPlusOne) const419     status_t increment(const TestLightFlattenable& a,
420                        TestLightFlattenable* aPlusOne) const override {
421         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
422         aPlusOne->value = a.value + 1;
423         return NO_ERROR;
424     }
increment(const sp<TestLightRefBaseFlattenable> & a,sp<TestLightRefBaseFlattenable> * aPlusOne) const425     status_t increment(const sp<TestLightRefBaseFlattenable>& a,
426                        sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
427         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
428         *aPlusOne = new TestLightRefBaseFlattenable(a->value + 1);
429         return NO_ERROR;
430     }
increment(const sp<NativeHandle> & a,sp<NativeHandle> * aPlusOne) const431     status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
432         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
433         native_handle* rawHandle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
434         if (rawHandle == nullptr) return NO_MEMORY;
435 
436         // Copy the fd over directly
437         rawHandle->data[0] = dup(a->handle()->data[0]);
438 
439         // Increment the int
440         rawHandle->data[1] = a->handle()->data[1] + 1;
441 
442         // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
443         // the native_handle when it goes out of scope
444         *aPlusOne = NativeHandle::create(rawHandle, true);
445         return NO_ERROR;
446     }
increment(const NoCopyNoMove & a,NoCopyNoMove * aPlusOne) const447     status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
448         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
449         aPlusOne->setValue(a.getValue() + 1);
450         return NO_ERROR;
451     }
increment(const std::vector<TestParcelable> & a,std::vector<TestParcelable> * aPlusOne) const452     status_t increment(const std::vector<TestParcelable>& a,
453                        std::vector<TestParcelable>* aPlusOne) const override {
454         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
455         aPlusOne->resize(a.size());
456         for (size_t i = 0; i < a.size(); ++i) {
457             (*aPlusOne)[i].setValue(a[i].getValue() + 1);
458         }
459         return NO_ERROR;
460     }
doubleString(const String8 & str,String8 * doubleStr) const461     status_t doubleString(const String8& str, String8* doubleStr) const override {
462         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
463         *doubleStr = str + str;
464         return NO_ERROR;
465     }
callMeBack(const sp<ICallback> & callback,int32_t a) const466     void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
467         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
468         callback->onCallback(a + 1);
469     }
increment(int32_t a,int32_t * aPlusOne) const470     status_t increment(int32_t a, int32_t* aPlusOne) const override {
471         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
472         *aPlusOne = a + 1;
473         return NO_ERROR;
474     }
increment(uint32_t a,uint32_t * aPlusOne) const475     status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
476         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
477         *aPlusOne = a + 1;
478         return NO_ERROR;
479     }
increment(int64_t a,int64_t * aPlusOne) const480     status_t increment(int64_t a, int64_t* aPlusOne) const override {
481         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
482         *aPlusOne = a + 1;
483         return NO_ERROR;
484     }
increment(uint64_t a,uint64_t * aPlusOne) const485     status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
486         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
487         *aPlusOne = a + 1;
488         return NO_ERROR;
489     }
increment(float a,float * aPlusOne) const490     status_t increment(float a, float* aPlusOne) const override {
491         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
492         *aPlusOne = a + 1.0f;
493         return NO_ERROR;
494     }
increment(int32_t a,int32_t * aPlusOne,int32_t b,int32_t * bPlusOne) const495     status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
496         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
497         *aPlusOne = a + 1;
498         *bPlusOne = b + 1;
499         return NO_ERROR;
500     }
501 
502     // BnInterface
onTransact(uint32_t code,const Parcel & data,Parcel * reply,uint32_t)503     status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
504                         uint32_t /*flags*/) override {
505         EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
506         EXPECT_LT(code, static_cast<uint32_t>(Tag::Last));
507         ISafeInterfaceTest::Tag tag = static_cast<ISafeInterfaceTest::Tag>(code);
508         switch (tag) {
509             case ISafeInterfaceTest::Tag::SetDeathToken: {
510                 return callLocal(data, reply, &ISafeInterfaceTest::setDeathToken);
511             }
512             case ISafeInterfaceTest::Tag::ReturnsNoMemory: {
513                 return callLocal(data, reply, &ISafeInterfaceTest::returnsNoMemory);
514             }
515             case ISafeInterfaceTest::Tag::LogicalNot: {
516                 return callLocal(data, reply, &ISafeInterfaceTest::logicalNot);
517             }
518             case ISafeInterfaceTest::Tag::ModifyEnum: {
519                 return callLocal(data, reply, &ISafeInterfaceTest::modifyEnum);
520             }
521             case ISafeInterfaceTest::Tag::IncrementFlattenable: {
522                 using Signature = status_t (ISafeInterfaceTest::*)(const TestFlattenable& a,
523                                                                    TestFlattenable* aPlusOne) const;
524                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
525             }
526             case ISafeInterfaceTest::Tag::IncrementLightFlattenable: {
527                 using Signature =
528                         status_t (ISafeInterfaceTest::*)(const TestLightFlattenable& a,
529                                                          TestLightFlattenable* aPlusOne) const;
530                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
531             }
532             case ISafeInterfaceTest::Tag::IncrementLightRefBaseFlattenable: {
533                 using Signature =
534                         status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
535                                                          sp<TestLightRefBaseFlattenable>*) const;
536                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
537             }
538             case ISafeInterfaceTest::Tag::IncrementNativeHandle: {
539                 using Signature = status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&,
540                                                                    sp<NativeHandle>*) const;
541                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
542             }
543             case ISafeInterfaceTest::Tag::IncrementNoCopyNoMove: {
544                 using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
545                                                                    NoCopyNoMove* aPlusOne) const;
546                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
547             }
548             case ISafeInterfaceTest::Tag::IncrementParcelableVector: {
549                 using Signature =
550                         status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
551                                                          std::vector<TestParcelable>*) const;
552                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
553             }
554             case ISafeInterfaceTest::Tag::DoubleString: {
555                 return callLocal(data, reply, &ISafeInterfaceTest::doubleString);
556             }
557             case ISafeInterfaceTest::Tag::CallMeBack: {
558                 return callLocalAsync(data, reply, &ISafeInterfaceTest::callMeBack);
559             }
560             case ISafeInterfaceTest::Tag::IncrementInt32: {
561                 using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
562                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
563             }
564             case ISafeInterfaceTest::Tag::IncrementUint32: {
565                 using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
566                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
567             }
568             case ISafeInterfaceTest::Tag::IncrementInt64: {
569                 using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
570                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
571             }
572             case ISafeInterfaceTest::Tag::IncrementUint64: {
573                 using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
574                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
575             }
576             case ISafeInterfaceTest::Tag::IncrementFloat: {
577                 using Signature = status_t (ISafeInterfaceTest::*)(float, float*) const;
578                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
579             }
580             case ISafeInterfaceTest::Tag::IncrementTwo: {
581                 using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t,
582                                                                    int32_t*) const;
583                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
584             }
585             case ISafeInterfaceTest::Tag::Last:
586                 // Should not be possible because of the asserts at the beginning of the method
587                 [&]() { FAIL(); }();
588                 return UNKNOWN_ERROR;
589         }
590     }
591 
592 private:
getLogTag()593     static constexpr const char* getLogTag() { return "BnSafeInterfaceTest"; }
594 };
595 
596 class SafeInterfaceTest : public ::testing::Test {
597 public:
SafeInterfaceTest()598     SafeInterfaceTest() : mSafeInterfaceTest(getRemoteService()) {
599         ProcessState::self()->startThreadPool();
600     }
601     ~SafeInterfaceTest() override = default;
602 
603 protected:
604     sp<ISafeInterfaceTest> mSafeInterfaceTest;
605 
606 private:
getLogTag()607     static constexpr const char* getLogTag() { return "SafeInterfaceTest"; }
608 
getRemoteService()609     sp<ISafeInterfaceTest> getRemoteService() {
610 #pragma clang diagnostic push
611 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
612         sp<IBinder> binder = defaultServiceManager()->getService(kServiceName);
613 #pragma clang diagnostic pop
614         sp<ISafeInterfaceTest> iface = interface_cast<ISafeInterfaceTest>(binder);
615         EXPECT_TRUE(iface != nullptr);
616 
617         iface->setDeathToken(new BBinder);
618 
619         return iface;
620     }
621 };
622 
TEST_F(SafeInterfaceTest, TestReturnsNoMemory) {
    // The call must report NO_MEMORY as its status.
    ASSERT_EQ(NO_MEMORY, mSafeInterfaceTest->returnsNoMemory());
}
627 
TEST_F(SafeInterfaceTest, TestLogicalNot) {
    // Exercise both inputs (true first, then false) so a reply that always defaults to one
    // value cannot pass. Each output starts equal to its input, so an untouched out-parameter
    // is detected as a failure.
    const bool inputs[] = {true, false};
    for (const bool input : inputs) {
        bool negated = input;
        ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->logicalNot(input, &negated));
        ASSERT_EQ(!input, negated);
    }
}
641 
TEST_F(SafeInterfaceTest, TestModifyEnum) {
    // Send INITIAL and expect FINAL back. The out-parameter starts as INVALID, so an
    // unwritten reply is caught.
    const TestEnum input = TestEnum::INITIAL;
    TestEnum modified = TestEnum::INVALID;
    const status_t status = mSafeInterfaceTest->modifyEnum(input, &modified);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(TestEnum::FINAL, modified);
}
649 
TEST_F(SafeInterfaceTest, TestIncrementFlattenable) {
    // Round-trips a Flattenable through the remote increment() overload.
    const TestFlattenable input{1};
    TestFlattenable output{0};
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input.value + 1, output.value);
}
657 
TEST_F(SafeInterfaceTest, TestIncrementLightFlattenable) {
    // Round-trips a LightFlattenable through the remote increment() overload.
    const TestLightFlattenable input{1};
    TestLightFlattenable output{0};
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input.value + 1, output.value);
}
665 
TEST_F(SafeInterfaceTest, TestIncrementLightRefBaseFlattenable) {
    // Round-trips a ref-counted flattenable held by sp<>.
    sp<TestLightRefBaseFlattenable> input = new TestLightRefBaseFlattenable{1};
    sp<TestLightRefBaseFlattenable> output;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    // The reply must carry an object; check before dereferencing it.
    ASSERT_NE(nullptr, output.get());
    ASSERT_EQ(input->value + 1, output->value);
}
674 
namespace { // Anonymous namespace

// Returns true when fstat() succeeds on both descriptors and they report the same
// device and inode numbers. Returns false if either fstat() call fails; `b` is not
// examined when the call on `a` fails.
bool fdsAreEquivalent(int a, int b) {
    struct stat infoA {};
    struct stat infoB {};
    const bool bothStatted = fstat(a, &infoA) == 0 && fstat(b, &infoB) == 0;
    if (!bothStatted) {
        return false;
    }
    const bool sameDevice = infoA.st_dev == infoB.st_dev;
    const bool sameInode = infoA.st_ino == infoB.st_ino;
    return sameDevice && sameInode;
}

} // Anonymous namespace
686 
TEST_F(SafeInterfaceTest, TestIncrementNativeHandle) {
    // Create an fd we can use to send and receive from the remote process
    unique_fd eventFd{eventfd(0 /*initval*/, 0 /*flags*/)};
    ASSERT_NE(-1, eventFd);

    // Determine the maximum number of fds this process can have open
    struct rlimit limit {};
    ASSERT_EQ(0, getrlimit(RLIMIT_NOFILE, &limit));
    uint64_t maxFds = limit.rlim_cur;

    ALOG(LOG_INFO, "SafeInterfaceTest", "%s max FDs: %" PRIu64, __PRETTY_FUNCTION__, maxFds);

    // Perform this test enough times to rule out fd leaks: more round trips than the process
    // may hold open descriptors, so a per-iteration leak eventually exhausts the limit.
    // The counter must be as wide as maxFds. A 32-bit counter never reaches a bound
    // above UINT32_MAX and would loop forever.
    for (uint64_t iter = 0; iter < (maxFds + 100); ++iter) {
        native_handle* handle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
        ASSERT_NE(nullptr, handle);
        handle->data[0] = dup(eventFd.get());
        handle->data[1] = 1;

        // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
        // the native_handle when it goes out of scope
        sp<NativeHandle> a = NativeHandle::create(handle, true);
        // Checked only after ownership moved to `a`, so the handle is released on failure.
        ASSERT_NE(-1, a->handle()->data[0]);

        sp<NativeHandle> aPlusOne;
        status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
        ASSERT_EQ(NO_ERROR, result);
        // Guard against a null reply before dereferencing it below.
        ASSERT_NE(nullptr, aPlusOne.get());
        ASSERT_TRUE(fdsAreEquivalent(a->handle()->data[0], aPlusOne->handle()->data[0]));
        ASSERT_EQ(a->handle()->data[1] + 1, aPlusOne->handle()->data[1]);
    }
}
717 
TEST_F(SafeInterfaceTest, TestIncrementNoCopyNoMove) {
    // NoCopyNoMove can only be passed by reference, which proves the transport never copies it.
    const NoCopyNoMove input{1};
    NoCopyNoMove output{0};
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input.getValue() + 1, output.getValue());
}
725 
TEST_F(SafeInterfaceTest, TestIncrementParcelableVector) {
    // Each element of the returned vector must be the matching input element plus one.
    const std::vector<TestParcelable> inputs{TestParcelable{1}, TestParcelable{2}};
    std::vector<TestParcelable> outputs;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(inputs, &outputs));
    ASSERT_EQ(inputs.size(), outputs.size());
    for (size_t idx = 0; idx != inputs.size(); ++idx) {
        ASSERT_EQ(inputs[idx].getValue() + 1, outputs[idx].getValue());
    }
}
736 
TEST_F(SafeInterfaceTest, TestDoubleString) {
    // The reply must be the input string concatenated with itself.
    const String8 input{"asdf"};
    String8 doubled;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->doubleString(input, &doubled));
    const String8 expected{"asdfasdf"};
    ASSERT_TRUE(doubled == expected);
}
744 
TEST_F(SafeInterfaceTest, TestCallMeBack) {
    // Local binder object that records the value passed to onCallback() and lets the test
    // thread block, with a timeout, until that value arrives.
    class CallbackReceiver : public BnCallback {
    public:
        void onCallback(int32_t aPlusOne) override {
            ALOG(LOG_INFO, "CallbackReceiver", "%s", __PRETTY_FUNCTION__);
            std::unique_lock<decltype(mMutex)> guard(mMutex);
            mValue = aPlusOne;
            mCondition.notify_one();
        }

        // Returns the recorded value, or std::nullopt if none arrived within 100ms.
        std::optional<int32_t> waitForCallback() {
            std::unique_lock<decltype(mMutex)> guard(mMutex);
            const bool delivered =
                    mCondition.wait_for(guard, 100ms, [this]() { return mValue.has_value(); });
            if (!delivered) {
                return std::nullopt;
            }
            return mValue;
        }

    private:
        std::mutex mMutex;
        std::condition_variable mCondition;
        std::optional<int32_t> mValue;
    };

    const int32_t a = 1;
    sp<CallbackReceiver> receiver = new CallbackReceiver;
    mSafeInterfaceTest->callMeBack(receiver, a);
    const std::optional<int32_t> received = receiver->waitForCallback();
    ASSERT_TRUE(received.has_value());
    ASSERT_EQ(a + 1, received.value());
}
775 
TEST_F(SafeInterfaceTest, TestIncrementInt32) {
    // Exercises the int32_t increment() overload.
    const int32_t input = 1;
    int32_t output = 0;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input + 1, output);
}
783 
TEST_F(SafeInterfaceTest, TestIncrementUint32) {
    // Exercises the uint32_t increment() overload.
    const uint32_t input = 1;
    uint32_t output = 0;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input + 1, output);
}
791 
TEST_F(SafeInterfaceTest, TestIncrementInt64) {
    // Exercises the int64_t increment() overload.
    const int64_t input = 1;
    int64_t output = 0;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input + 1, output);
}
799 
TEST_F(SafeInterfaceTest, TestIncrementUint64) {
    // Exercises the uint64_t increment() overload.
    const uint64_t input = 1;
    uint64_t output = 0;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input + 1, output);
}
807 
TEST_F(SafeInterfaceTest, TestIncrementFloat) {
    // Exercises the float increment() overload. 1.0f + 1.0f is exactly representable,
    // so exact equality is safe here.
    const float input = 1.0f;
    float output = 0.0f;
    ASSERT_EQ(NO_ERROR, mSafeInterfaceTest->increment(input, &output));
    ASSERT_EQ(input + 1.0f, output);
}
815 
TEST_F(SafeInterfaceTest, TestIncrementTwo) {
    // Two independent in/out pairs in one call; each output must be its own input plus one.
    // Pass the named inputs (not literals) so the call and the expectations cannot diverge.
    const int32_t a = 1;
    int32_t aPlusOne = 0;
    const int32_t b = 2;
    int32_t bPlusOne = 0;
    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne, b, &bPlusOne);
    ASSERT_EQ(NO_ERROR, result);
    ASSERT_EQ(a + 1, aPlusOne);
    ASSERT_EQ(b + 1, bPlusOne);
}
826 
main(int argc,char ** argv)827 extern "C" int main(int argc, char **argv) {
828     testing::InitGoogleTest(&argc, argv);
829 
830     if (fork() == 0) {
831         prctl(PR_SET_PDEATHSIG, SIGHUP);
832         sp<BnSafeInterfaceTest> nativeService = new BnSafeInterfaceTest;
833         status_t status = defaultServiceManager()->addService(kServiceName, nativeService);
834         if (status != OK) {
835             ALOG(LOG_INFO, "SafeInterfaceServer", "could not register");
836             return EXIT_FAILURE;
837         }
838         IPCThreadState::self()->joinThreadPool();
839         return EXIT_FAILURE;
840     }
841 
842     return RUN_ALL_TESTS();
843 }
844 
845 } // namespace tests
846 } // namespace android
847