1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils.async;
18 
19 import android.os.ParcelFileDescriptor;
20 import android.system.StructPollfd;
21 import android.util.Log;
22 
23 import com.android.net.module.util.async.CircularByteBuffer;
24 import com.android.net.module.util.async.OsAccess;
25 
26 import java.io.FileDescriptor;
27 import java.io.InterruptedIOException;
28 import java.io.IOException;
29 import java.lang.reflect.Constructor;
30 import java.lang.reflect.Method;
31 import java.lang.reflect.Field;
32 import java.util.HashMap;
33 import java.util.concurrent.TimeUnit;
34 
35 public class FakeOsAccess extends OsAccess {
36     public static final boolean ENABLE_FINE_DEBUG = true;
37 
38     public static final int DEFAULT_FILE_DATA_QUEUE_SIZE = 8 * 1024;
39 
40     private enum FileType { PAIR, PIPE }
41 
42     // Common poll() constants:
43     private static final short POLLIN  = 0x0001;
44     private static final short POLLOUT = 0x0004;
45     private static final short POLLERR = 0x0008;
46     private static final short POLLHUP = 0x0010;
47 
48     private static final Constructor<FileDescriptor> FD_CONSTRUCTOR;
49     private static final Field FD_FIELD_DESCRIPTOR;
50     private static final Field PFD_FIELD_DESCRIPTOR;
51     private static final Field PFD_FIELD_GUARD;
52     private static final Method CLOSE_GUARD_METHOD_CLOSE;
53 
54     private final int mReadQueueSize = DEFAULT_FILE_DATA_QUEUE_SIZE;
55     private final int mWriteQueueSize = DEFAULT_FILE_DATA_QUEUE_SIZE;
56     private final HashMap<Integer, File> mFiles = new HashMap<>();
57     private final byte[] mTmpBuffer = new byte[1024];
58     private final long mStartTime;
59     private final String mLogTag;
60     private int mFileNumberGen = 3;
61     private boolean mHasRateLimitedData;
62 
FakeOsAccess(String logTag)63     public FakeOsAccess(String logTag) {
64         mLogTag = logTag;
65         mStartTime = monotonicTimeMillis();
66     }
67 
68     @Override
monotonicTimeMillis()69     public long monotonicTimeMillis() {
70         return System.nanoTime() / 1000000;
71     }
72 
73     @Override
getInnerFileDescriptor(ParcelFileDescriptor fd)74     public FileDescriptor getInnerFileDescriptor(ParcelFileDescriptor fd) {
75         try {
76             return (FileDescriptor) PFD_FIELD_DESCRIPTOR.get(fd);
77         } catch (Exception e) {
78             throw new RuntimeException(e);
79         }
80     }
81 
82     @Override
close(ParcelFileDescriptor fd)83     public void close(ParcelFileDescriptor fd) {
84         if (fd != null) {
85             close(getInnerFileDescriptor(fd));
86 
87             try {
88                 // Reduce CloseGuard warnings.
89                 Object guard = PFD_FIELD_GUARD.get(fd);
90                 CLOSE_GUARD_METHOD_CLOSE.invoke(guard);
91             } catch (Exception e) {
92                 throw new RuntimeException(e);
93             }
94         }
95     }
96 
close(FileDescriptor fd)97     public synchronized void close(FileDescriptor fd) {
98         if (fd != null) {
99             File file = getFileOrNull(fd);
100             if (file != null) {
101                 file.decreaseRefCount();
102                 mFiles.remove(getFileDescriptorNumber(fd));
103                 setFileDescriptorNumber(fd, -1);
104                 notifyAll();
105             }
106         }
107     }
108 
getFile(String func, FileDescriptor fd)109     private File getFile(String func, FileDescriptor fd) throws IOException {
110         File file = getFileOrNull(fd);
111         if (file == null) {
112             throw newIOException(func, "Unknown file descriptor: " + getFileDebugName(fd));
113         }
114         return file;
115     }
116 
getFileOrNull(FileDescriptor fd)117     private File getFileOrNull(FileDescriptor fd) {
118         return mFiles.get(getFileDescriptorNumber(fd));
119     }
120 
121     @Override
getFileDebugName(ParcelFileDescriptor fd)122     public String getFileDebugName(ParcelFileDescriptor fd) {
123         return (fd != null ? getFileDebugName(getInnerFileDescriptor(fd)) : "null");
124     }
125 
getFileDebugName(FileDescriptor fd)126     public String getFileDebugName(FileDescriptor fd) {
127         if (fd == null) {
128             return "null";
129         }
130 
131         final int fdNumber = getFileDescriptorNumber(fd);
132         File file = mFiles.get(fdNumber);
133 
134         StringBuilder sb = new StringBuilder();
135         if (file != null) {
136             if (file.name != null) {
137                 sb.append(file.name);
138                 sb.append("/");
139             }
140             sb.append(file.type);
141             sb.append("/");
142         } else {
143             sb.append("BADFD/");
144         }
145         sb.append(fdNumber);
146         return sb.toString();
147     }
148 
setFileName(FileDescriptor fd, String name)149     public synchronized void setFileName(FileDescriptor fd, String name) {
150         File file = getFileOrNull(fd);
151         if (file != null) {
152             file.name = name;
153         }
154     }
155 
156     @Override
setNonBlocking(FileDescriptor fd)157     public synchronized void setNonBlocking(FileDescriptor fd) throws IOException {
158         File file = getFile("fcntl", fd);
159         file.isBlocking = false;
160     }
161 
162     @Override
read(FileDescriptor fd, byte[] buffer, int pos, int len)163     public synchronized int read(FileDescriptor fd, byte[] buffer, int pos, int len)
164             throws IOException {
165         checkBoundaries("read", buffer, pos, len);
166 
167         File file = getFile("read", fd);
168         if (file.readQueue == null) {
169             throw newIOException("read", "File not readable");
170         }
171         file.checkNonBlocking("read");
172 
173         if (len == 0) {
174             return 0;
175         }
176 
177         final int availSize = file.readQueue.size();
178         if (availSize == 0) {
179             if (file.isEndOfStream) {
180                 // Java convention uses -1 to indicate end of stream.
181                 return -1;
182             }
183             return 0;  // EAGAIN
184         }
185 
186         final int readCount = Math.min(len, availSize);
187         file.readQueue.readBytes(buffer, pos, readCount);
188         maybeTransferData(file);
189         return readCount;
190     }
191 
192     @Override
write(FileDescriptor fd, byte[] buffer, int pos, int len)193     public synchronized int write(FileDescriptor fd, byte[] buffer, int pos, int len)
194             throws IOException {
195         checkBoundaries("write", buffer, pos, len);
196 
197         File file = getFile("write", fd);
198         if (file.writeQueue == null) {
199             throw newIOException("read", "File not writable");
200         }
201         if (file.type == FileType.PIPE && file.sink.openCount == 0) {
202             throw newIOException("write", "The other end of pipe is closed");
203         }
204         file.checkNonBlocking("write");
205 
206         if (len == 0) {
207             return 0;
208         }
209 
210         final int originalFreeSize = file.writeQueue.freeSize();
211         if (originalFreeSize == 0) {
212             return 0;  // EAGAIN
213         }
214 
215         final int writeCount = Math.min(len, originalFreeSize);
216         file.writeQueue.writeBytes(buffer, pos, writeCount);
217         maybeTransferData(file);
218 
219         if (file.writeQueue.freeSize() < originalFreeSize) {
220             final int additionalQueuedCount = originalFreeSize - file.writeQueue.freeSize();
221             Log.i(mLogTag, logStr("Delaying transfer of " + additionalQueuedCount
222                 + " bytes, queued=" + file.writeQueue.size() + ", type=" + file.type
223                 + ", src_red=" + file.outboundLimiter + ", dst_red=" + file.sink.inboundLimiter));
224         }
225 
226         return writeCount;
227     }
228 
maybeTransferData(File file)229     private void maybeTransferData(File file) {
230         boolean hasChanges = copyFileBuffers(file, file.sink);
231         hasChanges = copyFileBuffers(file.source, file) || hasChanges;
232 
233         if (hasChanges) {
234             // TODO(b/245971639): Avoid notifying if no-one is polling.
235             notifyAll();
236         }
237     }
238 
copyFileBuffers(File src, File dst)239     private boolean copyFileBuffers(File src, File dst) {
240         if (src.writeQueue == null || dst.readQueue == null) {
241             return false;
242         }
243 
244         final int originalCopyCount = Math.min(mTmpBuffer.length,
245             Math.min(src.writeQueue.size(), dst.readQueue.freeSize()));
246 
247         final int allowedCopyCount = RateLimiter.limit(
248             src.outboundLimiter, dst.inboundLimiter, originalCopyCount);
249 
250         if (allowedCopyCount < originalCopyCount) {
251             if (ENABLE_FINE_DEBUG) {
252                 Log.i(mLogTag, logStr("Delaying transfer of "
253                     + (originalCopyCount - allowedCopyCount) + " bytes, original="
254                     + originalCopyCount + ", allowed=" + allowedCopyCount
255                     + ", type=" + src.type));
256             }
257             if (originalCopyCount > 0) {
258                 mHasRateLimitedData = true;
259             }
260             if (allowedCopyCount == 0) {
261                 return false;
262             }
263         }
264 
265         boolean hasChanges = false;
266         if (allowedCopyCount > 0) {
267             if (dst.readQueue.size() == 0 || src.writeQueue.freeSize() == 0) {
268                 hasChanges = true;  // Read queue had no data, or write queue was full.
269             }
270             src.writeQueue.readBytes(mTmpBuffer, 0, allowedCopyCount);
271             dst.readQueue.writeBytes(mTmpBuffer, 0, allowedCopyCount);
272         }
273 
274         if (!dst.isEndOfStream && src.openCount == 0
275                 && src.writeQueue.size() == 0 && dst.readQueue.size() == 0) {
276             dst.isEndOfStream = true;
277             hasChanges = true;
278         }
279 
280         return hasChanges;
281     }
282 
clearInboundRateLimit(FileDescriptor fd)283     public void clearInboundRateLimit(FileDescriptor fd) {
284         setInboundRateLimit(fd, Integer.MAX_VALUE);
285     }
286 
clearOutboundRateLimit(FileDescriptor fd)287     public void clearOutboundRateLimit(FileDescriptor fd) {
288         setOutboundRateLimit(fd, Integer.MAX_VALUE);
289     }
290 
setInboundRateLimit(FileDescriptor fd, int bytesPerSecond)291     public synchronized void setInboundRateLimit(FileDescriptor fd, int bytesPerSecond) {
292         File file = getFileOrNull(fd);
293         if (file != null) {
294             file.inboundLimiter.setBytesPerSecond(bytesPerSecond);
295             maybeTransferData(file);
296         }
297     }
298 
setOutboundRateLimit(FileDescriptor fd, int bytesPerSecond)299     public synchronized void setOutboundRateLimit(FileDescriptor fd, int bytesPerSecond) {
300         File file = getFileOrNull(fd);
301         if (file != null) {
302             file.outboundLimiter.setBytesPerSecond(bytesPerSecond);
303             maybeTransferData(file);
304         }
305     }
306 
socketpair()307     public synchronized ParcelFileDescriptor[] socketpair() throws IOException {
308         int fdNumber1 = getNextFd("socketpair");
309         int fdNumber2 = getNextFd("socketpair");
310 
311         File file1 = new File(FileType.PAIR, mReadQueueSize, mWriteQueueSize);
312         File file2 = new File(FileType.PAIR, mReadQueueSize, mWriteQueueSize);
313 
314         return registerFilePair(fdNumber1, file1, fdNumber2, file2);
315     }
316 
317     @Override
pipe()318     public synchronized ParcelFileDescriptor[] pipe() throws IOException {
319         int fdNumber1 = getNextFd("pipe");
320         int fdNumber2 = getNextFd("pipe");
321 
322         File file1 = new File(FileType.PIPE, mReadQueueSize, 0);
323         File file2 = new File(FileType.PIPE, 0, mWriteQueueSize);
324 
325         return registerFilePair(fdNumber1, file1, fdNumber2, file2);
326     }
327 
registerFilePair( int fdNumber1, File file1, int fdNumber2, File file2)328     private ParcelFileDescriptor[] registerFilePair(
329             int fdNumber1, File file1, int fdNumber2, File file2) {
330         file1.sink = file2;
331         file1.source = file2;
332         file2.sink = file1;
333         file2.source = file1;
334 
335         mFiles.put(fdNumber1, file1);
336         mFiles.put(fdNumber2, file2);
337         return new ParcelFileDescriptor[] {
338             newParcelFileDescriptor(fdNumber1), newParcelFileDescriptor(fdNumber2)};
339     }
340 
341     @Override
getPollInMask()342     public short getPollInMask() {
343         return POLLIN;
344     }
345 
346     @Override
getPollOutMask()347     public short getPollOutMask() {
348         return POLLOUT;
349     }
350 
351     @Override
poll(StructPollfd[] fds, int timeoutMs)352     public synchronized int poll(StructPollfd[] fds, int timeoutMs) throws IOException {
353         if (timeoutMs < 0) {
354             timeoutMs = (int) TimeUnit.HOURS.toMillis(1);  // Make "infinite" equal to 1 hour.
355         }
356 
357         if (fds == null || fds.length > 1000) {
358             throw newIOException("poll", "Invalid fds param");
359         }
360         for (StructPollfd pollFd : fds) {
361             getFile("poll", pollFd.fd);
362         }
363 
364         int waitCallCount = 0;
365         final long deadline = monotonicTimeMillis() + timeoutMs;
366         while (true) {
367             if (mHasRateLimitedData) {
368                 mHasRateLimitedData = false;
369                 for (File file : mFiles.values()) {
370                     if (file.inboundLimiter.getLastRequestReduction() != 0) {
371                         copyFileBuffers(file.source, file);
372                     }
373                     if (file.outboundLimiter.getLastRequestReduction() != 0) {
374                         copyFileBuffers(file, file.sink);
375                     }
376                 }
377             }
378 
379             final int readyCount = calculateReadyCount(fds);
380             if (readyCount > 0) {
381                 if (ENABLE_FINE_DEBUG) {
382                     Log.v(mLogTag, logStr("Poll returns " + readyCount
383                             + " after " + waitCallCount + " wait calls"));
384                 }
385                 return readyCount;
386             }
387 
388             long remainingTimeoutMs = deadline - monotonicTimeMillis();
389             if (remainingTimeoutMs <= 0) {
390                 if (ENABLE_FINE_DEBUG) {
391                     Log.v(mLogTag, logStr("Poll timeout " + timeoutMs
392                             + "ms after " + waitCallCount + " wait calls"));
393                 }
394                 return 0;
395             }
396 
397             if (mHasRateLimitedData) {
398                 remainingTimeoutMs = Math.min(RateLimiter.BUCKET_DURATION_MS, remainingTimeoutMs);
399             }
400 
401             try {
402                 wait(remainingTimeoutMs);
403             } catch (InterruptedException e) {
404                 // Ignore and retry
405             }
406             waitCallCount++;
407         }
408     }
409 
calculateReadyCount(StructPollfd[] fds)410     private int calculateReadyCount(StructPollfd[] fds) {
411         int fdCount = 0;
412         for (StructPollfd pollFd : fds) {
413             pollFd.revents = 0;
414 
415             File file = getFileOrNull(pollFd.fd);
416             if (file == null) {
417                 Log.w(mLogTag, logStr("Ignoring FD concurrently closed by a buggy app: "
418                         + getFileDebugName(pollFd.fd)));
419                 continue;
420             }
421 
422             if (ENABLE_FINE_DEBUG) {
423                 Log.v(mLogTag, logStr("calculateReadyCount fd=" + getFileDebugName(pollFd.fd)
424                         + ", events=" + pollFd.events + ", eof=" + file.isEndOfStream
425                         + ", r=" + (file.readQueue != null ? file.readQueue.size() : -1)
426                         + ", w=" + (file.writeQueue != null ? file.writeQueue.freeSize() : -1)));
427             }
428 
429             if ((pollFd.events & POLLIN) != 0) {
430                 if (file.readQueue != null && file.readQueue.size() != 0) {
431                     pollFd.revents |= POLLIN;
432                 }
433                 if (file.isEndOfStream) {
434                     pollFd.revents |= POLLHUP;
435                 }
436             }
437 
438             if ((pollFd.events & POLLOUT) != 0) {
439                 if (file.type == FileType.PIPE && file.sink.openCount == 0) {
440                     pollFd.revents |= POLLERR;
441                 }
442                 if (file.writeQueue != null && file.writeQueue.freeSize() != 0) {
443                     pollFd.revents |= POLLOUT;
444                 }
445             }
446 
447             if (pollFd.revents != 0) {
448                 fdCount++;
449             }
450         }
451         return fdCount;
452     }
453 
getNextFd(String func)454     private int getNextFd(String func) throws IOException {
455         if (mFileNumberGen > 100000) {
456             throw newIOException(func, "Too many files open");
457         }
458 
459         return mFileNumberGen++;
460     }
461 
newIOException(String func, String message)462     private static IOException newIOException(String func, String message) {
463         return new IOException(message + ", func=" + func);
464     }
465 
checkBoundaries(String func, byte[] buffer, int pos, int len)466     public static void checkBoundaries(String func, byte[] buffer, int pos, int len)
467             throws IOException {
468         if (((buffer.length | pos | len) < 0 || pos > buffer.length - len)) {
469             throw newIOException(func, "Invalid array bounds");
470         }
471     }
472 
newParcelFileDescriptor(int fdNumber)473     private ParcelFileDescriptor newParcelFileDescriptor(int fdNumber) {
474         try {
475             return new ParcelFileDescriptor(newFileDescriptor(fdNumber));
476         } catch (Exception e) {
477             throw new RuntimeException(e);
478         }
479     }
480 
newFileDescriptor(int fdNumber)481     private FileDescriptor newFileDescriptor(int fdNumber) {
482         try {
483             return FD_CONSTRUCTOR.newInstance(Integer.valueOf(fdNumber));
484         } catch (Exception e) {
485             throw new RuntimeException(e);
486         }
487     }
488 
getFileDescriptorNumber(FileDescriptor fd)489     public int getFileDescriptorNumber(FileDescriptor fd) {
490         try {
491             return (Integer) FD_FIELD_DESCRIPTOR.get(fd);
492         } catch (Exception e) {
493             throw new RuntimeException(e);
494         }
495     }
496 
setFileDescriptorNumber(FileDescriptor fd, int fdNumber)497     private void setFileDescriptorNumber(FileDescriptor fd, int fdNumber) {
498         try {
499             FD_FIELD_DESCRIPTOR.set(fd, Integer.valueOf(fdNumber));
500         } catch (Exception e) {
501             throw new RuntimeException(e);
502         }
503     }
504 
logStr(String message)505     private String logStr(String message) {
506         return "[FakeOs " + (monotonicTimeMillis() - mStartTime) + "] " + message;
507     }
508 
509     private class File {
510         final FileType type;
511         final CircularByteBuffer readQueue;
512         final CircularByteBuffer writeQueue;
513         final RateLimiter inboundLimiter = new RateLimiter(FakeOsAccess.this, Integer.MAX_VALUE);
514         final RateLimiter outboundLimiter = new RateLimiter(FakeOsAccess.this, Integer.MAX_VALUE);
515         String name;
516         int openCount = 1;
517         boolean isBlocking = true;
518         File sink;
519         File source;
520         boolean isEndOfStream;
521 
File(FileType type, int readQueueSize, int writeQueueSize)522         File(FileType type, int readQueueSize, int writeQueueSize) {
523             this.type = type;
524             readQueue = (readQueueSize > 0 ? new CircularByteBuffer(readQueueSize) : null);
525             writeQueue = (writeQueueSize > 0 ? new CircularByteBuffer(writeQueueSize) : null);
526         }
527 
decreaseRefCount()528         void decreaseRefCount() {
529             if (openCount <= 0) {
530                 throw new IllegalStateException();
531             }
532             openCount--;
533         }
534 
checkNonBlocking(String func)535         void checkNonBlocking(String func) throws IOException {
536             if (isBlocking) {
537                 throw newIOException(func, "File in blocking mode");
538             }
539         }
540     }
541 
542     static {
543         try {
544             FD_CONSTRUCTOR = FileDescriptor.class.getDeclaredConstructor(int.class);
545             FD_CONSTRUCTOR.setAccessible(true);
546 
547             Field descriptorIntField;
548             try {
549                 descriptorIntField = FileDescriptor.class.getDeclaredField("descriptor");
550             } catch (NoSuchFieldException e) {
551                 descriptorIntField = FileDescriptor.class.getDeclaredField("fd");
552             }
553             FD_FIELD_DESCRIPTOR = descriptorIntField;
554             FD_FIELD_DESCRIPTOR.setAccessible(true);
555 
556             PFD_FIELD_DESCRIPTOR = ParcelFileDescriptor.class.getDeclaredField("mFd");
557             PFD_FIELD_DESCRIPTOR.setAccessible(true);
558 
559             PFD_FIELD_GUARD = ParcelFileDescriptor.class.getDeclaredField("mGuard");
560             PFD_FIELD_GUARD.setAccessible(true);
561 
562             CLOSE_GUARD_METHOD_CLOSE = Class.forName("dalvik.system.CloseGuard")
563                 .getDeclaredMethod("close");
564         } catch (Exception e) {
565             throw new RuntimeException(e);
566         }
567     }
568 }
569