1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net;
18 
19 import static android.net.NetworkUtils.getDnsNetwork;
20 import static android.net.NetworkUtils.resNetworkCancel;
21 import static android.net.NetworkUtils.resNetworkQuery;
22 import static android.net.NetworkUtils.resNetworkResult;
23 import static android.net.NetworkUtils.resNetworkSend;
24 import static android.net.util.DnsUtils.haveIpv4;
25 import static android.net.util.DnsUtils.haveIpv6;
26 import static android.net.util.DnsUtils.rfc6724Sort;
27 import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_ERROR;
28 import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
29 import static android.system.OsConstants.ENONET;
30 
31 import android.annotation.CallbackExecutor;
32 import android.annotation.IntDef;
33 import android.annotation.NonNull;
34 import android.annotation.Nullable;
35 import android.os.CancellationSignal;
36 import android.os.Looper;
37 import android.os.MessageQueue;
38 import android.system.ErrnoException;
39 import android.util.Log;
40 
41 import java.io.FileDescriptor;
42 import java.lang.annotation.Retention;
43 import java.lang.annotation.RetentionPolicy;
44 import java.net.InetAddress;
45 import java.net.UnknownHostException;
46 import java.util.ArrayList;
47 import java.util.List;
48 import java.util.concurrent.Executor;
49 
50 /**
51  * Dns resolver class for asynchronous dns querying
52  *
53  * Note that if a client sends a query with more than 1 record in the question section but
54  * the remote dns server does not support this, it may not respond at all, leading to a timeout.
55  *
56  */
57 public final class DnsResolver {
58     private static final String TAG = "DnsResolver";
59     private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
60     private static final int MAXPACKET = 8 * 1024;
61     private static final int SLEEP_TIME_MS = 2;
62 
63     @IntDef(prefix = { "CLASS_" }, value = {
64             CLASS_IN
65     })
66     @Retention(RetentionPolicy.SOURCE)
67     @interface QueryClass {}
68     public static final int CLASS_IN = 1;
69 
70     @IntDef(prefix = { "TYPE_" },  value = {
71             TYPE_A,
72             TYPE_AAAA
73     })
74     @Retention(RetentionPolicy.SOURCE)
75     @interface QueryType {}
76     public static final int TYPE_A = 1;
77     public static final int TYPE_AAAA = 28;
78 
79     @IntDef(prefix = { "FLAG_" }, value = {
80             FLAG_EMPTY,
81             FLAG_NO_RETRY,
82             FLAG_NO_CACHE_STORE,
83             FLAG_NO_CACHE_LOOKUP
84     })
85     @Retention(RetentionPolicy.SOURCE)
86     @interface QueryFlag {}
87     public static final int FLAG_EMPTY = 0;
88     public static final int FLAG_NO_RETRY = 1 << 0;
89     public static final int FLAG_NO_CACHE_STORE = 1 << 1;
90     public static final int FLAG_NO_CACHE_LOOKUP = 1 << 2;
91 
92     @IntDef(prefix = { "ERROR_" }, value = {
93             ERROR_PARSE,
94             ERROR_SYSTEM
95     })
96     @Retention(RetentionPolicy.SOURCE)
97     @interface DnsError {}
98     /**
99      * Indicates that there was an error parsing the response the query.
100      * The cause of this error is available via getCause() and is a ParseException.
101      */
102     public static final int ERROR_PARSE = 0;
103     /**
104      * Indicates that there was an error sending the query.
105      * The cause of this error is available via getCause() and is an ErrnoException.
106      */
107     public static final int ERROR_SYSTEM = 1;
108 
109     private static final int NETID_UNSET = 0;
110 
111     private static final DnsResolver sInstance = new DnsResolver();
112 
113     /**
114      * Get instance for DnsResolver
115      */
getInstance()116     public static @NonNull DnsResolver getInstance() {
117         return sInstance;
118     }
119 
DnsResolver()120     private DnsResolver() {}
121 
122     /**
123      * Base interface for answer callbacks
124      *
125      * @param <T> The type of the answer
126      */
127     public interface Callback<T> {
128         /**
129          * Success response to
130          * {@link android.net.DnsResolver#query query()} or
131          * {@link android.net.DnsResolver#rawQuery rawQuery()}.
132          *
133          * Invoked when the answer to a query was successfully parsed.
134          *
135          * @param answer <T> answer to the query.
136          * @param rcode The response code in the DNS response.
137          *
138          * {@see android.net.DnsResolver#query query()}
139          */
onAnswer(@onNull T answer, int rcode)140         void onAnswer(@NonNull T answer, int rcode);
141         /**
142          * Error response to
143          * {@link android.net.DnsResolver#query query()} or
144          * {@link android.net.DnsResolver#rawQuery rawQuery()}.
145          *
146          * Invoked when there is no valid answer to
147          * {@link android.net.DnsResolver#query query()}
148          * {@link android.net.DnsResolver#rawQuery rawQuery()}.
149          *
150          * @param error a {@link DnsException} object with additional
151          *    detail regarding the failure
152          */
onError(@onNull DnsException error)153         void onError(@NonNull DnsException error);
154     }
155 
156     /**
157      * Class to represent DNS error
158      */
159     public static class DnsException extends Exception {
160        /**
161         * DNS error code as one of the ERROR_* constants
162         */
163         @DnsError public final int code;
164 
DnsException(@nsError int code, @Nullable Throwable cause)165         DnsException(@DnsError int code, @Nullable Throwable cause) {
166             super(cause);
167             this.code = code;
168         }
169     }
170 
171     /**
172      * Send a raw DNS query.
173      * The answer will be provided asynchronously through the provided {@link Callback}.
174      *
175      * @param network {@link Network} specifying which network to query on.
176      *         {@code null} for query on default network.
177      * @param query blob message to query
178      * @param flags flags as a combination of the FLAGS_* constants
179      * @param executor The {@link Executor} that the callback should be executed on.
180      * @param cancellationSignal used by the caller to signal if the query should be
181      *    cancelled. May be {@code null}.
182      * @param callback a {@link Callback} which will be called to notify the caller
183      *    of the result of dns query.
184      */
rawQuery(@ullable Network network, @NonNull byte[] query, @QueryFlag int flags, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull Callback<? super byte[]> callback)185     public void rawQuery(@Nullable Network network, @NonNull byte[] query, @QueryFlag int flags,
186             @NonNull @CallbackExecutor Executor executor,
187             @Nullable CancellationSignal cancellationSignal,
188             @NonNull Callback<? super byte[]> callback) {
189         if (cancellationSignal != null && cancellationSignal.isCanceled()) {
190             return;
191         }
192         final Object lock = new Object();
193         final FileDescriptor queryfd;
194         try {
195             queryfd = resNetworkSend((network != null)
196                     ? network.getNetIdForResolv() : NETID_UNSET, query, query.length, flags);
197         } catch (ErrnoException e) {
198             executor.execute(() -> callback.onError(new DnsException(ERROR_SYSTEM, e)));
199             return;
200         }
201 
202         synchronized (lock)  {
203             registerFDListener(executor, queryfd, callback, cancellationSignal, lock);
204             if (cancellationSignal == null) return;
205             addCancellationSignal(cancellationSignal, queryfd, lock);
206         }
207     }
208 
209     /**
210      * Send a DNS query with the specified name, class and query type.
211      * The answer will be provided asynchronously through the provided {@link Callback}.
212      *
213      * @param network {@link Network} specifying which network to query on.
214      *         {@code null} for query on default network.
215      * @param domain domain name to query
216      * @param nsClass dns class as one of the CLASS_* constants
217      * @param nsType dns resource record (RR) type as one of the TYPE_* constants
218      * @param flags flags as a combination of the FLAGS_* constants
219      * @param executor The {@link Executor} that the callback should be executed on.
220      * @param cancellationSignal used by the caller to signal if the query should be
221      *    cancelled. May be {@code null}.
222      * @param callback a {@link Callback} which will be called to notify the caller
223      *    of the result of dns query.
224      */
rawQuery(@ullable Network network, @NonNull String domain, @QueryClass int nsClass, @QueryType int nsType, @QueryFlag int flags, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull Callback<? super byte[]> callback)225     public void rawQuery(@Nullable Network network, @NonNull String domain,
226             @QueryClass int nsClass, @QueryType int nsType, @QueryFlag int flags,
227             @NonNull @CallbackExecutor Executor executor,
228             @Nullable CancellationSignal cancellationSignal,
229             @NonNull Callback<? super byte[]> callback) {
230         if (cancellationSignal != null && cancellationSignal.isCanceled()) {
231             return;
232         }
233         final Object lock = new Object();
234         final FileDescriptor queryfd;
235         try {
236             queryfd = resNetworkQuery((network != null)
237                     ? network.getNetIdForResolv() : NETID_UNSET, domain, nsClass, nsType, flags);
238         } catch (ErrnoException e) {
239             executor.execute(() -> callback.onError(new DnsException(ERROR_SYSTEM, e)));
240             return;
241         }
242         synchronized (lock)  {
243             registerFDListener(executor, queryfd, callback, cancellationSignal, lock);
244             if (cancellationSignal == null) return;
245             addCancellationSignal(cancellationSignal, queryfd, lock);
246         }
247     }
248 
249     private class InetAddressAnswerAccumulator implements Callback<byte[]> {
250         private final List<InetAddress> mAllAnswers;
251         private final Network mNetwork;
252         private int mRcode;
253         private DnsException mDnsException;
254         private final Callback<? super List<InetAddress>> mUserCallback;
255         private final int mTargetAnswerCount;
256         private int mReceivedAnswerCount = 0;
257 
InetAddressAnswerAccumulator(@onNull Network network, int size, @NonNull Callback<? super List<InetAddress>> callback)258         InetAddressAnswerAccumulator(@NonNull Network network, int size,
259                 @NonNull Callback<? super List<InetAddress>> callback) {
260             mNetwork = network;
261             mTargetAnswerCount = size;
262             mAllAnswers = new ArrayList<>();
263             mUserCallback = callback;
264         }
265 
maybeReportError()266         private boolean maybeReportError() {
267             if (mRcode != 0) {
268                 mUserCallback.onAnswer(mAllAnswers, mRcode);
269                 return true;
270             }
271             if (mDnsException != null) {
272                 mUserCallback.onError(mDnsException);
273                 return true;
274             }
275             return false;
276         }
277 
maybeReportAnswer()278         private void maybeReportAnswer() {
279             if (++mReceivedAnswerCount != mTargetAnswerCount) return;
280             if (mAllAnswers.isEmpty() && maybeReportError()) return;
281             mUserCallback.onAnswer(rfc6724Sort(mNetwork, mAllAnswers), mRcode);
282         }
283 
284         @Override
onAnswer(@onNull byte[] answer, int rcode)285         public void onAnswer(@NonNull byte[] answer, int rcode) {
286             // If at least one query succeeded, return an rcode of 0.
287             // Otherwise, arbitrarily return the first rcode received.
288             if (mReceivedAnswerCount == 0 || rcode == 0) {
289                 mRcode = rcode;
290             }
291             try {
292                 mAllAnswers.addAll(new DnsAddressAnswer(answer).getAddresses());
293             } catch (ParseException e) {
294                 mDnsException = new DnsException(ERROR_PARSE, e);
295             }
296             maybeReportAnswer();
297         }
298 
299         @Override
onError(@onNull DnsException error)300         public void onError(@NonNull DnsException error) {
301             mDnsException = error;
302             maybeReportAnswer();
303         }
304     }
305 
306     /**
307      * Send a DNS query with the specified name on a network with both IPv4 and IPv6,
308      * get back a set of InetAddresses with rfc6724 sorting style asynchronously.
309      *
310      * This method will examine the connection ability on given network, and query IPv4
311      * and IPv6 if connection is available.
312      *
313      * If at least one query succeeded with valid answer, rcode will be 0
314      *
315      * The answer will be provided asynchronously through the provided {@link Callback}.
316      *
317      * @param network {@link Network} specifying which network to query on.
318      *         {@code null} for query on default network.
319      * @param domain domain name to query
320      * @param flags flags as a combination of the FLAGS_* constants
321      * @param executor The {@link Executor} that the callback should be executed on.
322      * @param cancellationSignal used by the caller to signal if the query should be
323      *    cancelled. May be {@code null}.
324      * @param callback a {@link Callback} which will be called to notify the
325      *    caller of the result of dns query.
326      */
query(@ullable Network network, @NonNull String domain, @QueryFlag int flags, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull Callback<? super List<InetAddress>> callback)327     public void query(@Nullable Network network, @NonNull String domain, @QueryFlag int flags,
328             @NonNull @CallbackExecutor Executor executor,
329             @Nullable CancellationSignal cancellationSignal,
330             @NonNull Callback<? super List<InetAddress>> callback) {
331         if (cancellationSignal != null && cancellationSignal.isCanceled()) {
332             return;
333         }
334         final Object lock = new Object();
335         final Network queryNetwork;
336         try {
337             queryNetwork = (network != null) ? network : getDnsNetwork();
338         } catch (ErrnoException e) {
339             executor.execute(() -> callback.onError(new DnsException(ERROR_SYSTEM, e)));
340             return;
341         }
342         final boolean queryIpv6 = haveIpv6(queryNetwork);
343         final boolean queryIpv4 = haveIpv4(queryNetwork);
344 
345         // This can only happen if queryIpv4 and queryIpv6 are both false.
346         // This almost certainly means that queryNetwork does not exist or no longer exists.
347         if (!queryIpv6 && !queryIpv4) {
348             executor.execute(() -> callback.onError(
349                     new DnsException(ERROR_SYSTEM, new ErrnoException("resNetworkQuery", ENONET))));
350             return;
351         }
352 
353         final FileDescriptor v4fd;
354         final FileDescriptor v6fd;
355 
356         int queryCount = 0;
357 
358         if (queryIpv6) {
359             try {
360                 v6fd = resNetworkQuery(queryNetwork.getNetIdForResolv(), domain, CLASS_IN,
361                         TYPE_AAAA, flags);
362             } catch (ErrnoException e) {
363                 executor.execute(() -> callback.onError(new DnsException(ERROR_SYSTEM, e)));
364                 return;
365             }
366             queryCount++;
367         } else v6fd = null;
368 
369         // Avoiding gateways drop packets if queries are sent too close together
370         try {
371             Thread.sleep(SLEEP_TIME_MS);
372         } catch (InterruptedException ex) {
373             Thread.currentThread().interrupt();
374         }
375 
376         if (queryIpv4) {
377             try {
378                 v4fd = resNetworkQuery(queryNetwork.getNetIdForResolv(), domain, CLASS_IN, TYPE_A,
379                         flags);
380             } catch (ErrnoException e) {
381                 if (queryIpv6) resNetworkCancel(v6fd);  // Closes fd, marks it invalid.
382                 executor.execute(() -> callback.onError(new DnsException(ERROR_SYSTEM, e)));
383                 return;
384             }
385             queryCount++;
386         } else v4fd = null;
387 
388         final InetAddressAnswerAccumulator accumulator =
389                 new InetAddressAnswerAccumulator(queryNetwork, queryCount, callback);
390 
391         synchronized (lock)  {
392             if (queryIpv6) {
393                 registerFDListener(executor, v6fd, accumulator, cancellationSignal, lock);
394             }
395             if (queryIpv4) {
396                 registerFDListener(executor, v4fd, accumulator, cancellationSignal, lock);
397             }
398             if (cancellationSignal == null) return;
399             cancellationSignal.setOnCancelListener(() -> {
400                 synchronized (lock)  {
401                     if (queryIpv4) cancelQuery(v4fd);
402                     if (queryIpv6) cancelQuery(v6fd);
403                 }
404             });
405         }
406     }
407 
408     /**
409      * Send a DNS query with the specified name and query type, get back a set of
410      * InetAddresses with rfc6724 sorting style asynchronously.
411      *
412      * The answer will be provided asynchronously through the provided {@link Callback}.
413      *
414      * @param network {@link Network} specifying which network to query on.
415      *         {@code null} for query on default network.
416      * @param domain domain name to query
417      * @param nsType dns resource record (RR) type as one of the TYPE_* constants
418      * @param flags flags as a combination of the FLAGS_* constants
419      * @param executor The {@link Executor} that the callback should be executed on.
420      * @param cancellationSignal used by the caller to signal if the query should be
421      *    cancelled. May be {@code null}.
422      * @param callback a {@link Callback} which will be called to notify the caller
423      *    of the result of dns query.
424      */
query(@ullable Network network, @NonNull String domain, @QueryType int nsType, @QueryFlag int flags, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull Callback<? super List<InetAddress>> callback)425     public void query(@Nullable Network network, @NonNull String domain,
426             @QueryType int nsType, @QueryFlag int flags,
427             @NonNull @CallbackExecutor Executor executor,
428             @Nullable CancellationSignal cancellationSignal,
429             @NonNull Callback<? super List<InetAddress>> callback) {
430         if (cancellationSignal != null && cancellationSignal.isCanceled()) {
431             return;
432         }
433         final Object lock = new Object();
434         final FileDescriptor queryfd;
435         final Network queryNetwork;
436         try {
437             queryNetwork = (network != null) ? network : getDnsNetwork();
438             queryfd = resNetworkQuery(queryNetwork.getNetIdForResolv(), domain, CLASS_IN, nsType,
439                     flags);
440         } catch (ErrnoException e) {
441             executor.execute(() -> callback.onError(new DnsException(ERROR_SYSTEM, e)));
442             return;
443         }
444         final InetAddressAnswerAccumulator accumulator =
445                 new InetAddressAnswerAccumulator(queryNetwork, 1, callback);
446         synchronized (lock)  {
447             registerFDListener(executor, queryfd, accumulator, cancellationSignal, lock);
448             if (cancellationSignal == null) return;
449             addCancellationSignal(cancellationSignal, queryfd, lock);
450         }
451     }
452 
453     /**
454      * Class to retrieve DNS response
455      *
456      * @hide
457      */
458     public static final class DnsResponse {
459         public final @NonNull byte[] answerbuf;
460         public final int rcode;
DnsResponse(@onNull byte[] answerbuf, int rcode)461         public DnsResponse(@NonNull byte[] answerbuf, int rcode) {
462             this.answerbuf = answerbuf;
463             this.rcode = rcode;
464         }
465     }
466 
registerFDListener(@onNull Executor executor, @NonNull FileDescriptor queryfd, @NonNull Callback<? super byte[]> answerCallback, @Nullable CancellationSignal cancellationSignal, @NonNull Object lock)467     private void registerFDListener(@NonNull Executor executor,
468             @NonNull FileDescriptor queryfd, @NonNull Callback<? super byte[]> answerCallback,
469             @Nullable CancellationSignal cancellationSignal, @NonNull Object lock) {
470         final MessageQueue mainThreadMessageQueue = Looper.getMainLooper().getQueue();
471         mainThreadMessageQueue.addOnFileDescriptorEventListener(
472                 queryfd,
473                 FD_EVENTS,
474                 (fd, events) -> {
475                     // b/134310704
476                     // Unregister fd event listener before resNetworkResult is called to prevent
477                     // race condition caused by fd reused.
478                     // For example when querying v4 and v6, it's possible that the first query ends
479                     // and the fd is closed before the second request starts, which might return
480                     // the same fd for the second request. By that time, the looper must have
481                     // unregistered the fd, otherwise another event listener can't be registered.
482                     mainThreadMessageQueue.removeOnFileDescriptorEventListener(fd);
483 
484                     executor.execute(() -> {
485                         DnsResponse resp = null;
486                         ErrnoException exception = null;
487                         synchronized (lock) {
488                             if (cancellationSignal != null && cancellationSignal.isCanceled()) {
489                                 return;
490                             }
491                             try {
492                                 resp = resNetworkResult(fd);  // Closes fd, marks it invalid.
493                             } catch (ErrnoException e) {
494                                 Log.e(TAG, "resNetworkResult:" + e.toString());
495                                 exception = e;
496                             }
497                         }
498                         if (exception != null) {
499                             answerCallback.onError(new DnsException(ERROR_SYSTEM, exception));
500                             return;
501                         }
502                         answerCallback.onAnswer(resp.answerbuf, resp.rcode);
503                     });
504 
505                     // The file descriptor has already been unregistered, so it does not really
506                     // matter what is returned here. In spirit 0 (meaning "unregister this FD")
507                     // is still the closest to what the looper needs to do. When returning 0,
508                     // Looper knows to ignore the fd if it has already been unregistered.
509                     return 0;
510                 });
511     }
512 
cancelQuery(@onNull FileDescriptor queryfd)513     private void cancelQuery(@NonNull FileDescriptor queryfd) {
514         if (!queryfd.valid()) return;
515         Looper.getMainLooper().getQueue().removeOnFileDescriptorEventListener(queryfd);
516         resNetworkCancel(queryfd);  // Closes fd, marks it invalid.
517     }
518 
addCancellationSignal(@onNull CancellationSignal cancellationSignal, @NonNull FileDescriptor queryfd, @NonNull Object lock)519     private void addCancellationSignal(@NonNull CancellationSignal cancellationSignal,
520             @NonNull FileDescriptor queryfd, @NonNull Object lock) {
521         cancellationSignal.setOnCancelListener(() -> {
522             synchronized (lock)  {
523                 cancelQuery(queryfd);
524             }
525         });
526     }
527 
528     private static class DnsAddressAnswer extends DnsPacket {
529         private static final String TAG = "DnsResolver.DnsAddressAnswer";
530         private static final boolean DBG = false;
531 
532         private final int mQueryType;
533 
DnsAddressAnswer(@onNull byte[] data)534         DnsAddressAnswer(@NonNull byte[] data) throws ParseException {
535             super(data);
536             if ((mHeader.flags & (1 << 15)) == 0) {
537                 throw new ParseException("Not an answer packet");
538             }
539             if (mHeader.getRecordCount(QDSECTION) == 0) {
540                 throw new ParseException("No question found");
541             }
542             // Expect only one question in question section.
543             mQueryType = mRecords[QDSECTION].get(0).nsType;
544         }
545 
getAddresses()546         public @NonNull List<InetAddress> getAddresses() {
547             final List<InetAddress> results = new ArrayList<InetAddress>();
548             if (mHeader.getRecordCount(ANSECTION) == 0) return results;
549 
550             for (final DnsRecord ansSec : mRecords[ANSECTION]) {
551                 // Only support A and AAAA, also ignore answers if query type != answer type.
552                 int nsType = ansSec.nsType;
553                 if (nsType != mQueryType || (nsType != TYPE_A && nsType != TYPE_AAAA)) {
554                     continue;
555                 }
556                 try {
557                     results.add(InetAddress.getByAddress(ansSec.getRR()));
558                 } catch (UnknownHostException e) {
559                     if (DBG) {
560                         Log.w(TAG, "rr to address fail");
561                     }
562                 }
563             }
564             return results;
565         }
566     }
567 
568 }
569