1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define ATRACE_TAG ATRACE_TAG_ADB
18 #define LOG_TAG "PackageManagerShellCommandDataLoader-jni"
19 #include <android-base/file.h>
20 #include <android-base/logging.h>
21 #include <android-base/no_destructor.h>
22 #include <android-base/stringprintf.h>
23 #include <android-base/unique_fd.h>
24 #include <core_jni_helpers.h>
25 #include <cutils/multiuser.h>
26 #include <cutils/trace.h>
27 #include <endian.h>
28 #include <nativehelper/JNIHelp.h>
29 #include <sys/eventfd.h>
30 #include <sys/poll.h>
31 
32 #include <charconv>
33 #include <chrono>
34 #include <span>
35 #include <string>
36 #include <thread>
37 #include <unordered_map>
38 #include <unordered_set>
39 
40 #include "dataloader.h"
41 
42 // #define VERBOSE_READ_LOGS
43 
44 namespace android {
45 
46 namespace {
47 
48 using android::base::borrowed_fd;
49 using android::base::unique_fd;
50 
51 using namespace std::literals;
52 
53 using BlockSize = int16_t;
54 using FileIdx = int16_t;
55 using BlockIdx = int32_t;
56 using NumBlocks = int32_t;
57 using BlockType = int8_t;
58 using CompressionType = int8_t;
59 using RequestType = int16_t;
60 using MagicType = uint32_t;
61 
62 static constexpr int BUFFER_SIZE = 256 * 1024;
63 static constexpr int BLOCKS_COUNT = BUFFER_SIZE / INCFS_DATA_FILE_BLOCK_SIZE;
64 
65 static constexpr int COMMAND_SIZE = 4 + 2 + 2 + 4; // bytes
66 static constexpr int HEADER_SIZE = 2 + 1 + 1 + 4 + 2; // bytes
67 static constexpr std::string_view OKAY = "OKAY"sv;
68 static constexpr MagicType INCR = 0x52434e49; // BE INCR
69 
70 static constexpr auto PollTimeoutMs = 5000;
71 static constexpr auto TraceTagCheckInterval = 1s;
72 
73 static constexpr auto WaitOnEofMinInterval = 10ms;
74 static constexpr auto WaitOnEofMaxInterval = 1s;
75 
76 struct JniIds {
77     jclass packageManagerShellCommandDataLoader;
78     jmethodID pmscdLookupShellCommand;
79     jmethodID pmscdGetStdIn;
80     jmethodID pmscdGetLocalFile;
81 
JniIdsandroid::__anon5cfad2040111::JniIds82     JniIds(JNIEnv* env) {
83         packageManagerShellCommandDataLoader = (jclass)env->NewGlobalRef(
84                 FindClassOrDie(env, "com/android/server/pm/PackageManagerShellCommandDataLoader"));
85         pmscdLookupShellCommand =
86                 GetStaticMethodIDOrDie(env, packageManagerShellCommandDataLoader,
87                                        "lookupShellCommand",
88                                        "(Ljava/lang/String;)Landroid/os/ShellCommand;");
89         pmscdGetStdIn = GetStaticMethodIDOrDie(env, packageManagerShellCommandDataLoader,
90                                                "getStdIn", "(Landroid/os/ShellCommand;)I");
91         pmscdGetLocalFile =
92                 GetStaticMethodIDOrDie(env, packageManagerShellCommandDataLoader, "getLocalFile",
93                                        "(Landroid/os/ShellCommand;Ljava/lang/String;)I");
94     }
95 };
96 
jniIds(JNIEnv * env)97 const JniIds& jniIds(JNIEnv* env) {
98     static const JniIds ids(env);
99     return ids;
100 }
101 
102 struct BlockHeader {
103     FileIdx fileIdx = -1;
104     BlockType blockType = -1;
105     CompressionType compressionType = -1;
106     BlockIdx blockIdx = -1;
107     BlockSize blockSize = -1;
108 } __attribute__((packed));
109 
110 static_assert(sizeof(BlockHeader) == HEADER_SIZE);
111 
112 static constexpr RequestType EXIT = 0;
113 static constexpr RequestType BLOCK_MISSING = 1;
114 static constexpr RequestType PREFETCH = 2;
115 
116 struct RequestCommand {
117     MagicType magic;
118     RequestType requestType;
119     FileIdx fileIdx;
120     BlockIdx blockIdx;
121 } __attribute__((packed));
122 
123 static_assert(COMMAND_SIZE == sizeof(RequestCommand));
124 
sendRequest(int fd,RequestType requestType,FileIdx fileIdx=-1,BlockIdx blockIdx=-1)125 static bool sendRequest(int fd, RequestType requestType, FileIdx fileIdx = -1,
126                         BlockIdx blockIdx = -1) {
127     const RequestCommand command{.magic = INCR,
128                                  .requestType = static_cast<int16_t>(be16toh(requestType)),
129                                  .fileIdx = static_cast<int16_t>(be16toh(fileIdx)),
130                                  .blockIdx = static_cast<int32_t>(be32toh(blockIdx))};
131     return android::base::WriteFully(fd, &command, sizeof(command));
132 }
133 
readChunk(int fd,std::vector<uint8_t> & data)134 static bool readChunk(int fd, std::vector<uint8_t>& data) {
135     int32_t size;
136     if (!android::base::ReadFully(fd, &size, sizeof(size))) {
137         return false;
138     }
139     size = int32_t(be32toh(size));
140     if (size <= 0) {
141         return false;
142     }
143     data.resize(size);
144     return android::base::ReadFully(fd, data.data(), data.size());
145 }
146 
147 BlockHeader readHeader(std::span<uint8_t>& data);
148 
readLEInt32(borrowed_fd fd)149 static inline int32_t readLEInt32(borrowed_fd fd) {
150     int32_t result;
151     ReadFully(fd, &result, sizeof(result));
152     result = int32_t(le32toh(result));
153     return result;
154 }
155 
skipBytes(borrowed_fd fd,int * max_size)156 static inline bool skipBytes(borrowed_fd fd, int* max_size) {
157     int32_t size = std::min(readLEInt32(fd), *max_size);
158     if (size <= 0) {
159         return false;
160     }
161     *max_size -= size;
162     return (TEMP_FAILURE_RETRY(lseek64(fd.get(), size, SEEK_CUR)) >= 0);
163 }
164 
skipIdSigHeaders(borrowed_fd fd)165 static inline int32_t skipIdSigHeaders(borrowed_fd fd) {
166     // version
167     auto version = readLEInt32(fd);
168     int max_size = INCFS_MAX_SIGNATURE_SIZE - sizeof(version);
169     // hashingInfo and signingInfo
170     if (!skipBytes(fd, &max_size) || !skipBytes(fd, &max_size)) {
171         return -1;
172     }
173     return readLEInt32(fd); // size of the verity tree
174 }
175 
verityTreeSizeForFile(IncFsSize fileSize)176 static inline IncFsSize verityTreeSizeForFile(IncFsSize fileSize) {
177     constexpr int SHA256_DIGEST_SIZE = 32;
178     constexpr int digest_size = SHA256_DIGEST_SIZE;
179     constexpr int hash_per_block = INCFS_DATA_FILE_BLOCK_SIZE / digest_size;
180 
181     IncFsSize total_tree_block_count = 0;
182 
183     auto block_count = 1 + (fileSize - 1) / INCFS_DATA_FILE_BLOCK_SIZE;
184     auto hash_block_count = block_count;
185     while (hash_block_count > 1) {
186         hash_block_count = (hash_block_count + hash_per_block - 1) / hash_per_block;
187         total_tree_block_count += hash_block_count;
188     }
189     return total_tree_block_count * INCFS_DATA_FILE_BLOCK_SIZE;
190 }
191 
192 enum MetadataMode : int8_t {
193     STDIN = 0,
194     LOCAL_FILE = 1,
195     DATA_ONLY_STREAMING = 2,
196     STREAMING = 3,
197 };
198 
199 struct InputDesc {
200     unique_fd fd;
201     IncFsSize size;
202     IncFsBlockKind kind = INCFS_BLOCK_KIND_DATA;
203     bool waitOnEof = false;
204     bool streaming = false;
205     MetadataMode mode = STDIN;
206 };
207 using InputDescs = std::vector<InputDesc>;
208 
209 template <class T>
read(IncFsSpan & data)210 std::optional<T> read(IncFsSpan& data) {
211     if (data.size < (int32_t)sizeof(T)) {
212         return {};
213     }
214     T res;
215     memcpy(&res, data.data, sizeof(res));
216     data.data += sizeof(res);
217     data.size -= sizeof(res);
218     return res;
219 }
220 
openLocalFile(JNIEnv * env,const JniIds & jni,jobject shellCommand,const std::string & path)221 static inline unique_fd openLocalFile(JNIEnv* env, const JniIds& jni, jobject shellCommand,
222                                       const std::string& path) {
223     if (shellCommand) {
224         return unique_fd{env->CallStaticIntMethod(jni.packageManagerShellCommandDataLoader,
225                                                   jni.pmscdGetLocalFile, shellCommand,
226                                                   env->NewStringUTF(path.c_str()))};
227     }
228     auto fd = unique_fd(::open(path.c_str(), O_RDONLY | O_CLOEXEC));
229     if (!fd.ok()) {
230         PLOG(ERROR) << "Failed to open file: " << path << ", error code: " << fd.get();
231     }
232     return fd;
233 }
234 
openLocalFile(JNIEnv * env,const JniIds & jni,jobject shellCommand,IncFsSize size,const std::string & filePath)235 static inline InputDescs openLocalFile(JNIEnv* env, const JniIds& jni, jobject shellCommand,
236                                        IncFsSize size, const std::string& filePath) {
237     InputDescs result;
238     result.reserve(2);
239 
240     const std::string idsigPath = filePath + ".idsig";
241 
242     unique_fd idsigFd = openLocalFile(env, jni, shellCommand, idsigPath);
243     if (idsigFd.ok()) {
244         auto actualTreeSize = skipIdSigHeaders(idsigFd);
245         if (actualTreeSize < 0) {
246             ALOGE("Error reading .idsig file: wrong format.");
247             return {};
248         }
249         auto treeSize = verityTreeSizeForFile(size);
250         if (treeSize != actualTreeSize) {
251             ALOGE("Verity tree size mismatch: %d vs .idsig: %d.", int(treeSize),
252                   int(actualTreeSize));
253             return {};
254         }
255         result.push_back(InputDesc{
256                 .fd = std::move(idsigFd),
257                 .size = treeSize,
258                 .kind = INCFS_BLOCK_KIND_HASH,
259         });
260     }
261 
262     unique_fd fileFd = openLocalFile(env, jni, shellCommand, filePath);
263     if (fileFd.ok()) {
264         result.push_back(InputDesc{
265                 .fd = std::move(fileFd),
266                 .size = size,
267         });
268     }
269 
270     return result;
271 }
272 
openInputs(JNIEnv * env,const JniIds & jni,jobject shellCommand,IncFsSize size,IncFsSpan metadata)273 static inline InputDescs openInputs(JNIEnv* env, const JniIds& jni, jobject shellCommand,
274                                     IncFsSize size, IncFsSpan metadata) {
275     auto mode = read<int8_t>(metadata).value_or(STDIN);
276     if (mode == LOCAL_FILE) {
277         // local file and possibly signature
278         auto dataSize = le32toh(read<int32_t>(metadata).value_or(0));
279         return openLocalFile(env, jni, shellCommand, size, std::string(metadata.data, dataSize));
280     }
281 
282     if (!shellCommand) {
283         ALOGE("Missing shell command.");
284         return {};
285     }
286 
287     unique_fd fd{env->CallStaticIntMethod(jni.packageManagerShellCommandDataLoader,
288                                           jni.pmscdGetStdIn, shellCommand)};
289     if (!fd.ok()) {
290         return {};
291     }
292 
293     InputDescs result;
294     switch (mode) {
295         case STDIN: {
296             result.push_back(InputDesc{
297                     .fd = std::move(fd),
298                     .size = size,
299                     .waitOnEof = true,
300             });
301             break;
302         }
303         case DATA_ONLY_STREAMING: {
304             // verity tree from stdin, rest is streaming
305             auto treeSize = verityTreeSizeForFile(size);
306             result.push_back(InputDesc{
307                     .fd = std::move(fd),
308                     .size = treeSize,
309                     .kind = INCFS_BLOCK_KIND_HASH,
310                     .waitOnEof = true,
311                     .streaming = true,
312                     .mode = DATA_ONLY_STREAMING,
313             });
314             break;
315         }
316         case STREAMING: {
317             result.push_back(InputDesc{
318                     .fd = std::move(fd),
319                     .size = 0,
320                     .streaming = true,
321                     .mode = STREAMING,
322             });
323             break;
324         }
325     }
326     return result;
327 }
328 
329 class PMSCDataLoader;
330 
331 struct OnTraceChanged {
332     OnTraceChanged();
~OnTraceChangedandroid::__anon5cfad2040111::OnTraceChanged333     ~OnTraceChanged() {
334         mRunning = false;
335         mChecker.join();
336     }
337 
registerCallbackandroid::__anon5cfad2040111::OnTraceChanged338     void registerCallback(PMSCDataLoader* callback) {
339         std::unique_lock lock(mMutex);
340         mCallbacks.insert(callback);
341     }
342 
unregisterCallbackandroid::__anon5cfad2040111::OnTraceChanged343     void unregisterCallback(PMSCDataLoader* callback) {
344         std::unique_lock lock(mMutex);
345         mCallbacks.erase(callback);
346     }
347 
348 private:
349     std::mutex mMutex;
350     std::unordered_set<PMSCDataLoader*> mCallbacks;
351     std::atomic<bool> mRunning{true};
352     std::thread mChecker;
353 };
354 
onTraceChanged()355 static OnTraceChanged& onTraceChanged() {
356     static android::base::NoDestructor<OnTraceChanged> instance;
357     return *instance;
358 }
359 
360 class PMSCDataLoader : public android::dataloader::DataLoader {
361 public:
// |jvm| must be non-null. It is used to obtain a JNIEnv in onPrepareImage.
PMSCDataLoader(JavaVM* jvm) : mJvm(jvm) {
    CHECK(mJvm);
}
// Stops listening for trace changes, then waits for the streaming receiver
// thread, if one was started. The stop flag and eventfd are signalled here too,
// so destruction cannot block forever when onStop() was not called first.
~PMSCDataLoader() {
    onTraceChanged().unregisterCallback(this);
    if (mReceiverThread.joinable()) {
        mStopReceiving = true;
        if (mEventFd >= 0) {
            eventfd_write(mEventFd, 1);
        }
        mReceiverThread.join();
    }
}
369 
updateReadLogsState(const bool enabled)370     void updateReadLogsState(const bool enabled) {
371         if (enabled != mReadLogsEnabled.exchange(enabled)) {
372             mIfs->setParams({.readLogsEnabled = enabled});
373         }
374     }
375 
376 private:
377     // Bitmask of supported features.
getFeatures() const378     DataLoaderFeatures getFeatures() const final { return DATA_LOADER_FEATURE_UID; }
379 
380     // Lifecycle.
onCreate(const android::dataloader::DataLoaderParams & params,android::dataloader::FilesystemConnectorPtr ifs,android::dataloader::StatusListenerPtr statusListener,android::dataloader::ServiceConnectorPtr,android::dataloader::ServiceParamsPtr)381     bool onCreate(const android::dataloader::DataLoaderParams& params,
382                   android::dataloader::FilesystemConnectorPtr ifs,
383                   android::dataloader::StatusListenerPtr statusListener,
384                   android::dataloader::ServiceConnectorPtr,
385                   android::dataloader::ServiceParamsPtr) final {
386         CHECK(ifs) << "ifs can't be null";
387         CHECK(statusListener) << "statusListener can't be null";
388         mArgs = params.arguments();
389         mIfs = ifs;
390         mStatusListener = statusListener;
391         updateReadLogsState(atrace_is_tag_enabled(ATRACE_TAG));
392         onTraceChanged().registerCallback(this);
393         return true;
394     }
onStart()395     bool onStart() final { return true; }
onStop()396     void onStop() final {
397         mStopReceiving = true;
398         eventfd_write(mEventFd, 1);
399         if (mReceiverThread.joinable()) {
400             mReceiverThread.join();
401         }
402     }
onDestroy()403     void onDestroy() final {}
404 
405     // Installation.
onPrepareImage(dataloader::DataLoaderInstallationFiles addedFiles)406     bool onPrepareImage(dataloader::DataLoaderInstallationFiles addedFiles) final {
407         ALOGE("onPrepareImage: start.");
408 
409         JNIEnv* env = GetOrAttachJNIEnvironment(mJvm, JNI_VERSION_1_6);
410         const auto& jni = jniIds(env);
411 
412         jobject shellCommand = env->CallStaticObjectMethod(jni.packageManagerShellCommandDataLoader,
413                                                            jni.pmscdLookupShellCommand,
414                                                            env->NewStringUTF(mArgs.c_str()));
415 
416         std::vector<char> buffer;
417         buffer.reserve(BUFFER_SIZE);
418 
419         std::vector<IncFsDataBlock> blocks;
420         blocks.reserve(BLOCKS_COUNT);
421 
422         unique_fd streamingFd;
423         MetadataMode streamingMode;
424         for (auto&& file : addedFiles) {
425             auto inputs = openInputs(env, jni, shellCommand, file.size, file.metadata);
426             if (inputs.empty()) {
427                 ALOGE("Failed to open an input file for metadata: %.*s, final file name is: %s. "
428                       "Error %d",
429                       int(file.metadata.size), file.metadata.data, file.name, errno);
430                 return false;
431             }
432 
433             const auto fileId = IncFs_FileIdFromMetadata(file.metadata);
434             const base::unique_fd incfsFd(mIfs->openForSpecialOps(fileId).release());
435             if (incfsFd < 0) {
436                 ALOGE("Failed to open an IncFS file for metadata: %.*s, final file name is: %s. "
437                       "Error %d",
438                       int(file.metadata.size), file.metadata.data, file.name, errno);
439                 return false;
440             }
441 
442             for (auto&& input : inputs) {
443                 if (input.streaming && !streamingFd.ok()) {
444                     streamingFd.reset(dup(input.fd));
445                     streamingMode = input.mode;
446                 }
447                 if (!copyToIncFs(incfsFd, input.size, input.kind, input.fd, input.waitOnEof,
448                                  &buffer, &blocks)) {
449                     ALOGE("Failed to copy data to IncFS file for metadata: %.*s, final file name "
450                           "is: %s. "
451                           "Error %d",
452                           int(file.metadata.size), file.metadata.data, file.name, errno);
453                     return false;
454                 }
455             }
456         }
457 
458         if (streamingFd.ok()) {
459             ALOGE("onPrepareImage: done, proceeding to streaming.");
460             return initStreaming(std::move(streamingFd), streamingMode);
461         }
462 
463         ALOGE("onPrepareImage: done.");
464         return true;
465     }
466 
copyToIncFs(borrowed_fd incfsFd,IncFsSize size,IncFsBlockKind kind,borrowed_fd incomingFd,bool waitOnEof,std::vector<char> * buffer,std::vector<IncFsDataBlock> * blocks)467     bool copyToIncFs(borrowed_fd incfsFd, IncFsSize size, IncFsBlockKind kind,
468                      borrowed_fd incomingFd, bool waitOnEof, std::vector<char>* buffer,
469                      std::vector<IncFsDataBlock>* blocks) {
470         IncFsSize remaining = size;
471         IncFsBlockIndex blockIdx = 0;
472         while (remaining > 0) {
473             constexpr auto capacity = BUFFER_SIZE;
474             auto size = buffer->size();
475             if (capacity - size < INCFS_DATA_FILE_BLOCK_SIZE) {
476                 if (!flashToIncFs(incfsFd, kind, false, &blockIdx, buffer, blocks)) {
477                     return false;
478                 }
479                 continue;
480             }
481 
482             auto toRead = std::min<IncFsSize>(remaining, capacity - size);
483             buffer->resize(size + toRead);
484             auto read = ::read(incomingFd.get(), buffer->data() + size, toRead);
485             if (read == 0) {
486                 if (waitOnEof) {
487                     // eof of stdin, waiting...
488                     if (doWaitOnEof()) {
489                         continue;
490                     } else {
491                         return false;
492                     }
493                 }
494                 break;
495             }
496             resetWaitOnEof();
497 
498             if (read < 0) {
499                 return false;
500             }
501 
502             buffer->resize(size + read);
503             remaining -= read;
504         }
505         if (!buffer->empty() && !flashToIncFs(incfsFd, kind, true, &blockIdx, buffer, blocks)) {
506             return false;
507         }
508         return true;
509     }
510 
flashToIncFs(borrowed_fd incfsFd,IncFsBlockKind kind,bool eof,IncFsBlockIndex * blockIdx,std::vector<char> * buffer,std::vector<IncFsDataBlock> * blocks)511     bool flashToIncFs(borrowed_fd incfsFd, IncFsBlockKind kind, bool eof, IncFsBlockIndex* blockIdx,
512                       std::vector<char>* buffer, std::vector<IncFsDataBlock>* blocks) {
513         int consumed = 0;
514         const auto fullBlocks = buffer->size() / INCFS_DATA_FILE_BLOCK_SIZE;
515         for (int i = 0; i < fullBlocks; ++i) {
516             const auto inst = IncFsDataBlock{
517                     .fileFd = incfsFd.get(),
518                     .pageIndex = (*blockIdx)++,
519                     .compression = INCFS_COMPRESSION_KIND_NONE,
520                     .kind = kind,
521                     .dataSize = INCFS_DATA_FILE_BLOCK_SIZE,
522                     .data = buffer->data() + consumed,
523             };
524             blocks->push_back(inst);
525             consumed += INCFS_DATA_FILE_BLOCK_SIZE;
526         }
527         const auto remain = buffer->size() - fullBlocks * INCFS_DATA_FILE_BLOCK_SIZE;
528         if (remain && eof) {
529             const auto inst = IncFsDataBlock{
530                     .fileFd = incfsFd.get(),
531                     .pageIndex = (*blockIdx)++,
532                     .compression = INCFS_COMPRESSION_KIND_NONE,
533                     .kind = kind,
534                     .dataSize = static_cast<uint16_t>(remain),
535                     .data = buffer->data() + consumed,
536             };
537             blocks->push_back(inst);
538             consumed += remain;
539         }
540 
541         auto res = mIfs->writeBlocks({blocks->data(), blocks->size()});
542 
543         blocks->clear();
544         buffer->erase(buffer->begin(), buffer->begin() + consumed);
545 
546         if (res < 0) {
547             ALOGE("Failed to write block to IncFS: %d", int(res));
548             return false;
549         }
550         return true;
551     }
552 
// Outcome of waitForData().
enum class WaitResult { DataAvailable, Timeout, Failure, StopRequested };
559 
waitForData(int fd)560     WaitResult waitForData(int fd) {
561         using Clock = std::chrono::steady_clock;
562         using Milliseconds = std::chrono::milliseconds;
563 
564         auto pollTimeoutMs = PollTimeoutMs;
565         const auto waitEnd = Clock::now() + Milliseconds(pollTimeoutMs);
566         while (!mStopReceiving) {
567             struct pollfd pfds[2] = {{fd, POLLIN, 0}, {mEventFd, POLLIN, 0}};
568             // Wait until either data is ready or stop signal is received
569             int res = poll(pfds, std::size(pfds), pollTimeoutMs);
570 
571             if (res < 0) {
572                 if (errno == EINTR) {
573                     pollTimeoutMs = std::chrono::duration_cast<Milliseconds>(waitEnd - Clock::now())
574                                             .count();
575                     if (pollTimeoutMs < 0) {
576                         return WaitResult::Timeout;
577                     }
578                     continue;
579                 }
580                 ALOGE("Failed to poll. Error %d", errno);
581                 return WaitResult::Failure;
582             }
583 
584             if (res == 0) {
585                 return WaitResult::Timeout;
586             }
587 
588             // First check if there is a stop signal
589             if (pfds[1].revents == POLLIN) {
590                 ALOGE("DataLoader requested to stop.");
591                 return WaitResult::StopRequested;
592             }
593             // Otherwise check if incoming data is ready
594             if (pfds[0].revents == POLLIN) {
595                 return WaitResult::DataAvailable;
596             }
597 
598             // Invalid case, just fail.
599             ALOGE("Failed to poll. Result %d", res);
600             return WaitResult::Failure;
601         }
602 
603         ALOGE("DataLoader requested to stop.");
604         return WaitResult::StopRequested;
605     }
606 
607     // Streaming.
initStreaming(unique_fd inout,MetadataMode mode)608     bool initStreaming(unique_fd inout, MetadataMode mode) {
609         mEventFd.reset(eventfd(0, EFD_CLOEXEC));
610         if (mEventFd < 0) {
611             ALOGE("Failed to create eventfd.");
612             return false;
613         }
614 
615         // Awaiting adb handshake.
616         if (waitForData(inout) != WaitResult::DataAvailable) {
617             ALOGE("Failure waiting for the handshake.");
618             return false;
619         }
620 
621         char okay_buf[OKAY.size()];
622         if (!android::base::ReadFully(inout, okay_buf, OKAY.size())) {
623             ALOGE("Failed to receive OKAY. Abort. Error %d", errno);
624             return false;
625         }
626         if (std::string_view(okay_buf, OKAY.size()) != OKAY) {
627             ALOGE("Received '%.*s', expecting '%.*s'", (int)OKAY.size(), okay_buf, (int)OKAY.size(),
628                   OKAY.data());
629             return false;
630         }
631 
632         {
633             std::lock_guard lock{mOutFdLock};
634             mOutFd.reset(::dup(inout));
635             if (mOutFd < 0) {
636                 ALOGE("Failed to create streaming fd.");
637             }
638         }
639 
640         if (mStopReceiving) {
641             ALOGE("DataLoader requested to stop.");
642             return false;
643         }
644 
645         mReceiverThread = std::thread(
646                 [this, io = std::move(inout), mode]() mutable { receiver(std::move(io), mode); });
647 
648         ALOGI("Started streaming...");
649         return true;
650     }
651 
652     // IFS callbacks.
onPendingReads(dataloader::PendingReads pendingReads)653     void onPendingReads(dataloader::PendingReads pendingReads) final {}
onPageReads(dataloader::PageReads pageReads)654     void onPageReads(dataloader::PageReads pageReads) final {}
655 
onPendingReadsWithUid(dataloader::PendingReadsWithUid pendingReads)656     void onPendingReadsWithUid(dataloader::PendingReadsWithUid pendingReads) final {
657         std::lock_guard lock{mOutFdLock};
658         if (mOutFd < 0) {
659             return;
660         }
661         CHECK(mIfs);
662         for (auto&& pendingRead : pendingReads) {
663             const android::dataloader::FileId& fileId = pendingRead.id;
664             const auto blockIdx = static_cast<BlockIdx>(pendingRead.block);
665             /*
666             ALOGI("Missing: %d", (int) blockIdx);
667             */
668             FileIdx fileIdx = convertFileIdToFileIndex(fileId);
669             if (fileIdx < 0) {
670                 ALOGE("Failed to handle event for fileid=%s. Ignore.",
671                       android::incfs::toString(fileId).c_str());
672                 continue;
673             }
674             if (mRequestedFiles.insert(fileIdx).second &&
675                 !sendRequest(mOutFd, PREFETCH, fileIdx, blockIdx)) {
676                 mRequestedFiles.erase(fileIdx);
677             }
678             sendRequest(mOutFd, BLOCK_MISSING, fileIdx, blockIdx);
679         }
680     }
681 
682     // Read tracing.
683     struct TracedRead {
684         uint64_t timestampUs;
685         android::dataloader::FileId fileId;
686         android::dataloader::Uid uid;
687         uint32_t firstBlockIdx;
688         uint32_t count;
689     };
690 
onPageReadsWithUid(dataloader::PageReadsWithUid pageReads)691     void onPageReadsWithUid(dataloader::PageReadsWithUid pageReads) final {
692         if (!pageReads.size()) {
693             return;
694         }
695 
696         auto trace = atrace_is_tag_enabled(ATRACE_TAG);
697         if (CC_LIKELY(!trace)) {
698             return;
699         }
700 
701         TracedRead last = {};
702         auto lastSerialNo = mLastSerialNo < 0 ? pageReads[0].serialNo : mLastSerialNo;
703         for (auto&& read : pageReads) {
704             const auto expectedSerialNo = lastSerialNo + last.count;
705 #ifdef VERBOSE_READ_LOGS
706             {
707                 FileIdx fileIdx = convertFileIdToFileIndex(read.id);
708 
709                 auto appId = multiuser_get_app_id(read.uid);
710                 auto userId = multiuser_get_user_id(read.uid);
711                 auto trace = android::base::
712                         StringPrintf("verbose_page_read: serialNo=%lld (expected=%lld) index=%lld "
713                                      "file=%d appid=%d userid=%d",
714                                      static_cast<long long>(read.serialNo),
715                                      static_cast<long long>(expectedSerialNo),
716                                      static_cast<long long>(read.block), static_cast<int>(fileIdx),
717                                      static_cast<int>(appId), static_cast<int>(userId));
718 
719                 ATRACE_BEGIN(trace.c_str());
720                 ATRACE_END();
721             }
722 #endif // VERBOSE_READ_LOGS
723 
724             if (read.serialNo == expectedSerialNo && read.id == last.fileId &&
725                 read.uid == last.uid && read.block == last.firstBlockIdx + last.count) {
726                 ++last.count;
727                 continue;
728             }
729 
730             // First, trace the reads.
731             traceRead(last);
732 
733             // Second, report missing reads, if any.
734             if (read.serialNo != expectedSerialNo) {
735                 traceMissingReads(expectedSerialNo, read.serialNo);
736             }
737 
738             last = TracedRead{
739                     .timestampUs = read.bootClockTsUs,
740                     .fileId = read.id,
741                     .uid = read.uid,
742                     .firstBlockIdx = (uint32_t)read.block,
743                     .count = 1,
744             };
745             lastSerialNo = read.serialNo;
746         }
747 
748         traceRead(last);
749         mLastSerialNo = lastSerialNo + last.count;
750     }
751 
traceRead(const TracedRead & read)752     void traceRead(const TracedRead& read) {
753         if (!read.count) {
754             return;
755         }
756 
757         FileIdx fileIdx = convertFileIdToFileIndex(read.fileId);
758 
759         std::string trace;
760         if (read.uid != kIncFsNoUid) {
761             auto appId = multiuser_get_app_id(read.uid);
762             auto userId = multiuser_get_user_id(read.uid);
763             trace = android::base::
764                     StringPrintf("page_read: index=%lld count=%lld file=%d appid=%d userid=%d",
765                                  static_cast<long long>(read.firstBlockIdx),
766                                  static_cast<long long>(read.count), static_cast<int>(fileIdx),
767                                  static_cast<int>(appId), static_cast<int>(userId));
768         } else {
769             trace = android::base::StringPrintf("page_read: index=%lld count=%lld file=%d",
770                                                 static_cast<long long>(read.firstBlockIdx),
771                                                 static_cast<long long>(read.count),
772                                                 static_cast<int>(fileIdx));
773         }
774 
775         ATRACE_BEGIN(trace.c_str());
776         ATRACE_END();
777     }
778 
traceMissingReads(int64_t expectedSerialNo,int64_t readSerialNo)779     void traceMissingReads(int64_t expectedSerialNo, int64_t readSerialNo) {
780         const auto readsMissing = readSerialNo - expectedSerialNo;
781         const auto trace =
782                 android::base::StringPrintf("missing_page_reads: count=%lld, range [%lld,%lld)",
783                                             static_cast<long long>(readsMissing),
784                                             static_cast<long long>(expectedSerialNo),
785                                             static_cast<long long>(readSerialNo));
786         ATRACE_BEGIN(trace.c_str());
787         ATRACE_END();
788     }
789 
receiver(unique_fd inout,MetadataMode mode)790     void receiver(unique_fd inout, MetadataMode mode) {
791         std::vector<uint8_t> data;
792         std::vector<IncFsDataBlock> instructions;
793         std::unordered_map<FileIdx, unique_fd> writeFds;
794         while (!mStopReceiving) {
795             const auto res = waitForData(inout);
796             if (res == WaitResult::Timeout) {
797                 continue;
798             }
799             if (res == WaitResult::Failure) {
800                 mStatusListener->reportStatus(DATA_LOADER_UNRECOVERABLE);
801                 break;
802             }
803             if (res == WaitResult::StopRequested) {
804                 ALOGE("Sending EXIT to server.");
805                 sendRequest(inout, EXIT);
806                 break;
807             }
808             if (!readChunk(inout, data)) {
809                 ALOGE("Failed to read a message. Abort.");
810                 mStatusListener->reportStatus(DATA_LOADER_UNRECOVERABLE);
811                 break;
812             }
813             auto remainingData = std::span(data);
814             while (!remainingData.empty()) {
815                 auto header = readHeader(remainingData);
816                 if (header.fileIdx == -1 && header.blockType == 0 && header.compressionType == 0 &&
817                     header.blockIdx == 0 && header.blockSize == 0) {
818                     ALOGI("Stop command received. Sending exit command (remaining bytes: %d).",
819                           int(remainingData.size()));
820 
821                     sendRequest(inout, EXIT);
822                     mStopReceiving = true;
823                     break;
824                 }
825                 if (header.fileIdx < 0 || header.blockSize <= 0 || header.blockType < 0 ||
826                     header.compressionType < 0 || header.blockIdx < 0) {
827                     ALOGE("Invalid header received. Abort.");
828                     mStopReceiving = true;
829                     break;
830                 }
831 
832                 const FileIdx fileIdx = header.fileIdx;
833                 const android::dataloader::FileId fileId = convertFileIndexToFileId(mode, fileIdx);
834                 if (!android::incfs::isValidFileId(fileId)) {
835                     ALOGE("Unknown data destination for file ID %d. Ignore.", header.fileIdx);
836                     continue;
837                 }
838 
839                 auto& writeFd = writeFds[fileIdx];
840                 if (writeFd < 0) {
841                     writeFd.reset(this->mIfs->openForSpecialOps(fileId).release());
842                     if (writeFd < 0) {
843                         ALOGE("Failed to open file %d for writing (%d). Abort.", header.fileIdx,
844                               -writeFd);
845                         break;
846                     }
847                 }
848 
849                 const auto inst = IncFsDataBlock{
850                         .fileFd = writeFd,
851                         .pageIndex = static_cast<IncFsBlockIndex>(header.blockIdx),
852                         .compression = static_cast<IncFsCompressionKind>(header.compressionType),
853                         .kind = static_cast<IncFsBlockKind>(header.blockType),
854                         .dataSize = static_cast<uint16_t>(header.blockSize),
855                         .data = (const char*)remainingData.data(),
856                 };
857                 instructions.push_back(inst);
858                 remainingData = remainingData.subspan(header.blockSize);
859             }
860             writeInstructions(instructions);
861         }
862         writeInstructions(instructions);
863 
864         {
865             std::lock_guard lock{mOutFdLock};
866             mOutFd.reset();
867         }
868     }
869 
writeInstructions(std::vector<IncFsDataBlock> & instructions)870     void writeInstructions(std::vector<IncFsDataBlock>& instructions) {
871         auto res = this->mIfs->writeBlocks(instructions);
872         if (res != instructions.size()) {
873             ALOGE("Dailed to write data to Incfs (res=%d when expecting %d)", res,
874                   int(instructions.size()));
875         }
876         instructions.clear();
877     }
878 
convertFileIdToFileIndex(android::dataloader::FileId fileId)879     FileIdx convertFileIdToFileIndex(android::dataloader::FileId fileId) {
880         // FileId has format '\2FileIdx'.
881         const char* meta = (const char*)&fileId;
882 
883         int8_t mode = *meta;
884         if (mode != DATA_ONLY_STREAMING && mode != STREAMING) {
885             return -1;
886         }
887 
888         int fileIdx;
889         auto res = std::from_chars(meta + 1, meta + sizeof(fileId), fileIdx);
890         if (res.ec != std::errc{} || fileIdx < std::numeric_limits<FileIdx>::min() ||
891             fileIdx > std::numeric_limits<FileIdx>::max()) {
892             return -1;
893         }
894 
895         return FileIdx(fileIdx);
896     }
897 
convertFileIndexToFileId(MetadataMode mode,FileIdx fileIdx)898     android::dataloader::FileId convertFileIndexToFileId(MetadataMode mode, FileIdx fileIdx) {
899         IncFsFileId fileId = {};
900         char* meta = (char*)&fileId;
901         *meta = mode;
902         if (auto [p, ec] = std::to_chars(meta + 1, meta + sizeof(fileId), fileIdx);
903             ec != std::errc()) {
904             return {};
905         }
906         return fileId;
907     }
908 
909     // Waiting with exponential backoff, maximum total time ~1.2sec.
doWaitOnEof()910     bool doWaitOnEof() {
911         if (mWaitOnEofInterval >= WaitOnEofMaxInterval) {
912             resetWaitOnEof();
913             return false;
914         }
915         auto result = mWaitOnEofInterval;
916         mWaitOnEofInterval =
917                 std::min<std::chrono::milliseconds>(mWaitOnEofInterval * 2, WaitOnEofMaxInterval);
918         std::this_thread::sleep_for(result);
919         return true;
920     }
921 
resetWaitOnEof()922     void resetWaitOnEof() { mWaitOnEofInterval = WaitOnEofMinInterval; }
923 
    // JVM pointer; presumably the one handed to the PMSCDataLoader constructor by the factory
    // registered in DataLoader::initialize() (constructor not visible here — TODO confirm).
    JavaVM* const mJvm;
    // NOTE(review): not referenced in this view; presumably the data loader's argument string.
    std::string mArgs;
    // IncFS connector; used here for openForSpecialOps() and writeBlocks().
    android::dataloader::FilesystemConnectorPtr mIfs = nullptr;
    // Status sink; receiver() reports DATA_LOADER_UNRECOVERABLE through it.
    android::dataloader::StatusListenerPtr mStatusListener = nullptr;
    // Guards mOutFd; receiver() resets mOutFd while holding this lock.
    std::mutex mOutFdLock;
    android::base::unique_fd mOutFd;
    // NOTE(review): not referenced in this view; presumably used to wake/stop the receiver.
    android::base::unique_fd mEventFd;
    // NOTE(review): presumably runs receiver(); started/joined outside this view.
    std::thread mReceiverThread;
    // When true, receiver()'s main loop exits. receiver() also sets it itself on the stop
    // command or an invalid header.
    std::atomic<bool> mStopReceiving = false;
    // NOTE(review): not referenced in this view; name suggests whether read-log tracing is on.
    std::atomic<bool> mReadLogsEnabled = false;
    // Current backoff step used by doWaitOnEof(); restored by resetWaitOnEof().
    std::chrono::milliseconds mWaitOnEofInterval{WaitOnEofMinInterval};
    // Updated after a batch of reads is traced (last serial + count), i.e. presumably the next
    // expected read-log serial number; -1 until the first update.
    int64_t mLastSerialNo{-1};
    /** Tracks which files have been requested */
    std::unordered_set<FileIdx> mRequestedFiles;
938 };
939 
OnTraceChanged()940 OnTraceChanged::OnTraceChanged() {
941     mChecker = std::thread([this]() {
942         bool oldTrace = atrace_is_tag_enabled(ATRACE_TAG);
943         while (mRunning) {
944             bool newTrace = atrace_is_tag_enabled(ATRACE_TAG);
945             if (oldTrace != newTrace) {
946                 std::unique_lock lock(mMutex);
947                 for (auto&& callback : mCallbacks) {
948                     callback->updateReadLogsState(newTrace);
949                 }
950             }
951             oldTrace = newTrace;
952             std::this_thread::sleep_for(TraceTagCheckInterval);
953         }
954     });
955 }
956 
readHeader(std::span<uint8_t> & data)957 BlockHeader readHeader(std::span<uint8_t>& data) {
958     BlockHeader header;
959     if (data.size() < sizeof(header)) {
960         return header;
961     }
962 
963     header.fileIdx = static_cast<FileIdx>(be16toh(*reinterpret_cast<const uint16_t*>(&data[0])));
964     header.blockType = static_cast<BlockType>(data[2]);
965     header.compressionType = static_cast<CompressionType>(data[3]);
966     header.blockIdx = static_cast<BlockIdx>(be32toh(*reinterpret_cast<const uint32_t*>(&data[4])));
967     header.blockSize =
968             static_cast<BlockSize>(be16toh(*reinterpret_cast<const uint16_t*>(&data[8])));
969     data = data.subspan(sizeof(header));
970 
971     return header;
972 }
973 
nativeInitialize(JNIEnv * env,jclass klass)974 static void nativeInitialize(JNIEnv* env, jclass klass) {
975     jniIds(env);
976 }
977 
// JNI methods registered on com.android.server.pm.PackageManagerShellCommandDataLoader:
// name, JNI signature ("()V" = no arguments, void return) and native implementation.
static const JNINativeMethod method_table[] = {
        {"nativeInitialize", "()V", (void*)nativeInitialize},
};
981 
982 } // namespace
983 
register_android_server_com_android_server_pm_PackageManagerShellCommandDataLoader(JNIEnv * env)984 int register_android_server_com_android_server_pm_PackageManagerShellCommandDataLoader(
985         JNIEnv* env) {
986     android::dataloader::DataLoader::initialize(
987             [](auto jvm, const auto& params) -> android::dataloader::DataLoaderPtr {
988                 if (params.type() == DATA_LOADER_TYPE_INCREMENTAL) {
989                     // This DataLoader only supports incremental installations.
990                     return std::make_unique<PMSCDataLoader>(jvm);
991                 }
992                 return {};
993             });
994     return jniRegisterNativeMethods(env,
995                                     "com/android/server/pm/PackageManagerShellCommandDataLoader",
996                                     method_table, NELEM(method_table));
997 }
998 
999 } // namespace android
1000