1 #include <stdio.h>
2 #include <string.h>
3 #include <core.h>
4 #include "pxe.h"
5 
/* DNS CLASS values we care about (numeric values from RFC 1035) */
#define CLASS_IN	1	/* the Internet */

/* DNS TYPE values we care about (numeric values from RFC 1035) */
#define TYPE_A		1	/* IPv4 host address */
#define TYPE_CNAME	5	/* canonical name for an alias */
12 
/*
 * The DNS message header as it appears on the wire.
 * The struct is packed so it can be overlaid directly on a packet buffer;
 * dns_resolv() stores and reads every field through htons()/ntohs(), i.e.
 * all fields are kept in network byte order.
 */
struct dnshdr {
    /* query identifier; a reply is matched to its query by this value */
    uint16_t id;
    /* QR / opcode / AA / TC / RD / RA / RCODE flag bits */
    uint16_t flags;
    /* number of entries in the question section */
    uint16_t qdcount;
    /* number of resource records in the answer section */
    uint16_t ancount;
    /* number of name server resource records in the authority records section*/
    uint16_t nscount;
    /* number of resource records in the additional records section */
    uint16_t arcount;
} __attribute__ ((packed));
28 
/*
 * The fixed trailer of a DNS question entry.  It immediately follows the
 * variable-length query name; fields are in network byte order.
 */
struct dnsquery {
    uint16_t qtype;     /* requested record TYPE, e.g. TYPE_A */
    uint16_t qclass;    /* requested CLASS, e.g. CLASS_IN */
} __attribute__ ((packed));
36 
/*
 * The fixed part of a DNS resource record, which follows the record's
 * (variable-length, possibly compressed) owner name.  Fields are in
 * network byte order; rdata holds rdlength bytes of record data.
 */
struct dnsrr {
    uint16_t type;       /* record TYPE, e.g. TYPE_A or TYPE_CNAME */
    uint16_t class;      /* record CLASS, e.g. CLASS_IN */
    uint32_t ttl;        /* time to live */
    uint16_t rdlength;   /* The length of this rr data */
    char     rdata[];    /* rdlength bytes of record data */
} __attribute__ ((packed));
47 
48 
#define DNS_PORT	htons(53)               /* Default DNS port, network byte order */
#define DNS_MAX_SERVERS 4		/* Max no of DNS servers */

/*
 * Configured DNS server addresses, filled in outside this section.
 * dns_resolv() refuses to query when dns_server[0] is zero and treats a
 * zero entry as the end of the list.  The values are handed to the PXE UDP
 * calls unchanged, so presumably they are in network byte order -- confirm
 * against the code that fills them.
 */
uint32_t dns_server[DNS_MAX_SERVERS] = {0, };
53 
54 
/*
 * Encode the host name _p_ as an uncompressed DNS label set at *_dst_.
 *
 * Encoding stops at NUL, ':' or '/', so a "host:port" or "host/path"
 * string can be passed directly.  Each '.' starts a new length-prefixed
 * label.  The set is NUL-terminated unless the name already ended in '.'
 * (whose empty label doubles as the terminator).
 *
 * Returns the number of dots seen.  On return, *_dst_ points just past the
 * last byte written.
 */
int dns_mangle(char **dst, const char *p)
{
    char *out = *dst;
    char *len_byte = out;   /* length prefix of the label being built */
    int ndots = 0;

    *out++ = 0;

    for (;; p++) {
        char ch = *p;

        if (ch == '\0' || ch == ':' || ch == '/')
            break;

        if (ch != '.') {
            (*len_byte)++;
            *out++ = ch;
        } else {
            /* close the current label and open a new, empty one */
            ndots++;
            len_byte = out;
            *out++ = 0;
        }
    }

    /* A non-empty last label still needs the root terminator */
    if (*len_byte != 0)
        *out++ = 0;

    *dst = out;
    return ndots;
}
91 
92 
/* ASCII-only lower-casing used for DNS name comparison */
static inline uint8_t dns_ascii_tolower(uint8_t c)
{
    return (c >= 'A' && c <= 'Z') ? (uint8_t)(c - 'A' + 'a') : c;
}

/*
 * Compare two sets of DNS labels, in _s1_ and _s2_.
 *
 * _s1_ must be an uncompressed label set.  _s2_ may contain compression
 * pointers, which are resolved as offsets into the packet starting at
 * _buf_.  Length bytes must match exactly.  Label text is compared ASCII
 * case-insensitively, because DNS names compare that way (RFC 1035
 * section 2.3.3).
 *
 * Returns true if both names are equal.  Following more than MAX_PTR_HOPS
 * pointers (e.g. a pointer loop in a malformed or hostile packet) counts
 * as a mismatch rather than looping forever.
 */
static bool dns_compare(const void *s1, const void *s2, const void *buf)
{
    enum { MAX_PTR_HOPS = 128 };
    const uint8_t *q = s1;
    const uint8_t *p = s2;
    unsigned int hops = 0;

    while (1) {
        unsigned int c0 = p[0];

        if (c0 >= 0xc0) {
            /* Follow pointer: 14-bit offset from the start of the packet */
            if (++hops > MAX_PTR_HOPS)
                return false;
            p = (const uint8_t *)buf + (((c0 - 0xc0) << 8) | p[1]);
        } else if (c0) {
            unsigned int i;

            if (q[0] != c0)
                return false;       /* label lengths differ */
            for (i = 1; i <= c0; i++) {
                if (dns_ascii_tolower(q[i]) != dns_ascii_tolower(p[i]))
                    return false;
            }
            q += c0 + 1;            /* length byte + label text */
            p += c0 + 1;
        } else {
            /* End of _s2_; equal only if _s1_ ends here too */
            return q[0] == 0;
        }
    }
}
121 
/*
 * Copy a DNS label set from _src_ into _dst_ as an uncompressed label set,
 * following compression pointers relative to the packet at _buf_.
 *
 * Returns a pointer to the first free byte *after* the terminal null.
 *
 * The output is always null-terminated and never exceeds MAX_NAME_LEN
 * bytes (the DNS limit on an encoded name, terminator included).
 * - A name that would exceed that limit is truncated at a label boundary.
 * - After more than MAX_PTR_HOPS pointers (e.g. a pointer loop in a hostile
 *   packet), copying stops.
 * Either way the buffer cannot be overrun by a malformed reply.
 */
static void *dns_copylabel(void *dst, const void *src, const void *buf)
{
    enum { MAX_PTR_HOPS = 128, MAX_NAME_LEN = 255 };
    uint8_t *const start = dst;
    uint8_t *q = dst;
    const uint8_t *p = src;
    unsigned int hops = 0;

    while (1) {
        unsigned int c0 = p[0];

        if (c0 >= 0xc0) {
            /* Follow pointer: 14-bit offset from the start of the packet */
            if (++hops > MAX_PTR_HOPS)
                break;
            p = (const uint8_t *)buf + (((c0 - 0xc0) << 8) | p[1]);
        } else if (c0) {
            c0++;                   /* Include the length byte */
            /* keep room for the terminating null */
            if ((size_t)(q - start) + c0 + 1 > MAX_NAME_LEN)
                break;
            memcpy(q, p, c0);
            p += c0;
            q += c0;
        } else {
            break;                  /* root label: end of name */
        }
    }

    *q++ = 0;
    return q;
}
150 
/*
 * Skip over one DNS name (label set) starting at _label_ and return a
 * pointer to the byte just after it.  The name ends either at a zero
 * length byte (one byte) or at a compression pointer, i.e. a length byte
 * >= 0xc0 (two bytes).  Pointers are not followed.
 */
static char *dns_skiplabel(char *label)
{
    for (;;) {
        uint8_t len = (uint8_t)*label;

        if (len == 0)
            return label + 1;       /* terminating root label */
        if (len >= 0xc0)
            return label + 2;       /* compression pointer is two bytes */
        label += 1 + len;           /* length byte plus label text */
    }
}
167 
/*
 * Retransmission timeouts used by dns_resolv(); each entry is compared
 * against a jiffies() delta.  dns_resolv() stops at a zero entry, so the
 * table is assumed to be zero-terminated -- TODO confirm against its
 * definition.
 */
extern const uint8_t TimeoutTable[];
/* Obtain a local UDP port to use as the query's source port. */
extern uint16_t get_port(void);
/* Return a port obtained from get_port() to the free pool. */
extern void free_port(uint16_t port);
171 
/*
 * Parse the dotted-quad IPv4 literal _ip_str_ ("a.b.c.d").
 *
 * Returns true only if the whole string was consumed and every component
 * is a non-empty decimal number in 0..255.  On success the address is
 * stored in *_res_ in network byte order; on failure *_res_ is left
 * untouched, so callers never see a half-parsed address.
 */
static bool parse_dotquad(const char *ip_str, uint32_t *res)
{
    const char *p = ip_str;
    uint32_t ip = 0;
    int i;

    for (i = 0; i < 4; i++) {
        unsigned int part = 0;
        int ndigits = 0;

        while (*p >= '0' && *p <= '9') {
            part = part * 10 + (unsigned int)(*p - '0');
            if (part > 255)
                return false;       /* component out of range */
            ndigits++;
            p++;
        }

        if (ndigits == 0)
            return false;           /* empty component, e.g. "1..2.3" */

        if (i < 3) {
            if (*p != '.')
                return false;
            p++;
        }

        ip = (ip << 8) | part;
    }

    if (*p != '\0')
        return false;               /* trailing garbage, e.g. "1.2.3.4x" */

    *res = htonl(ip);
    return true;
}
202 
203 /*
204  * Actual resolver function
205  * Points to a null-terminated or :-terminated string in _name_
206  * and returns the ip addr in _ip_ if it exists and can be found.
207  * If _ip_ = 0 on exit, the lookup failed. _name_ will be updated
208  *
209  * XXX: probably need some caching here.
210  */
dns_resolv(const char * name)211 __export uint32_t dns_resolv(const char *name)
212 {
213     static char __lowmem DNSSendBuf[PKTBUF_SIZE];
214     static char __lowmem DNSRecvBuf[PKTBUF_SIZE];
215     char *p;
216     int err;
217     int dots;
218     int same;
219     int rd_len;
220     int ques, reps;    /* number of questions and replies */
221     uint8_t timeout;
222     const uint8_t *timeout_ptr = TimeoutTable;
223     uint32_t oldtime;
224     uint32_t srv;
225     uint32_t *srv_ptr;
226     struct dnshdr *hd1 = (struct dnshdr *)DNSSendBuf;
227     struct dnshdr *hd2 = (struct dnshdr *)DNSRecvBuf;
228     struct dnsquery *query;
229     struct dnsrr *rr;
230     static __lowmem struct s_PXENV_UDP_WRITE udp_write;
231     static __lowmem struct s_PXENV_UDP_READ  udp_read;
232     uint16_t local_port;
233     uint32_t result = 0;
234 
235     /*
236      * Return failure on an empty input... this can happen during
237      * some types of URL parsing, and this is the easiest place to
238      * check for it.
239      */
240     if (!name || !*name)
241 	return 0;
242 
243     /* If it is a valid dot quad, just return that value */
244     if (parse_dotquad(name, &result))
245 	return result;
246 
247     /* Make sure we have at least one valid DNS server */
248     if (!dns_server[0])
249 	return 0;
250 
251     /* Get a local port number */
252     local_port = get_port();
253 
254     /* First, fill the DNS header struct */
255     hd1->id++;                      /* New query ID */
256     hd1->flags   = htons(0x0100);   /* Recursion requested */
257     hd1->qdcount = htons(1);        /* One question */
258     hd1->ancount = 0;               /* No answers */
259     hd1->nscount = 0;               /* No NS */
260     hd1->arcount = 0;               /* No AR */
261 
262     p = DNSSendBuf + sizeof(struct dnshdr);
263     dots = dns_mangle(&p, name);   /* store the CNAME */
264 
265     if (!dots) {
266         p--; /* Remove final null */
267         /* Uncompressed DNS label set so it ends in null */
268         p = stpcpy(p, LocalDomain);
269     }
270 
271     /* Fill the DNS query packet */
272     query = (struct dnsquery *)p;
273     query->qtype  = htons(TYPE_A);
274     query->qclass = htons(CLASS_IN);
275     p += sizeof(struct dnsquery);
276 
277     /* Now send it to name server */
278     timeout_ptr = TimeoutTable;
279     timeout = *timeout_ptr++;
280     srv_ptr = dns_server;
281     while (timeout) {
282 	srv = *srv_ptr++;
283 	if (!srv) {
284 	    srv_ptr = dns_server;
285 	    srv = *srv_ptr++;
286 	}
287 
288         udp_write.status      = 0;
289         udp_write.ip          = srv;
290         udp_write.gw          = gateway(srv);
291         udp_write.src_port    = local_port;
292         udp_write.dst_port    = DNS_PORT;
293         udp_write.buffer_size = p - DNSSendBuf;
294         udp_write.buffer      = FAR_PTR(DNSSendBuf);
295         err = pxe_call(PXENV_UDP_WRITE, &udp_write);
296         if (err || udp_write.status)
297             continue;
298 
299         oldtime = jiffies();
300 	do {
301 	    if (jiffies() - oldtime >= timeout)
302 		goto again;
303 
304             udp_read.status      = 0;
305             udp_read.src_ip      = srv;
306             udp_read.dest_ip     = IPInfo.myip;
307             udp_read.s_port      = DNS_PORT;
308             udp_read.d_port      = local_port;
309             udp_read.buffer_size = PKTBUF_SIZE;
310             udp_read.buffer      = FAR_PTR(DNSRecvBuf);
311             err = pxe_call(PXENV_UDP_READ, &udp_read);
312 	} while (err || udp_read.status || hd2->id != hd1->id);
313 
314         if ((hd2->flags ^ 0x80) & htons(0xf80f))
315             goto badness;
316 
317         ques = htons(hd2->qdcount);   /* Questions */
318         reps = htons(hd2->ancount);   /* Replies   */
319         p = DNSRecvBuf + sizeof(struct dnshdr);
320         while (ques--) {
321             p = dns_skiplabel(p); /* Skip name */
322             p += 4;               /* Skip question trailer */
323         }
324 
325         /* Parse the replies */
326         while (reps--) {
327             same = dns_compare(DNSSendBuf + sizeof(struct dnshdr),
328 			       p, DNSRecvBuf);
329             p = dns_skiplabel(p);
330             rr = (struct dnsrr *)p;
331             rd_len = ntohs(rr->rdlength);
332             if (same && ntohs(rr->class) == CLASS_IN) {
333 		switch (ntohs(rr->type)) {
334 		case TYPE_A:
335 		    if (rd_len == 4) {
336 			result = *(uint32_t *)rr->rdata;
337 			goto done;
338 		    }
339 		    break;
340 		case TYPE_CNAME:
341 		    dns_copylabel(DNSSendBuf + sizeof(struct dnshdr),
342 				  rr->rdata, DNSRecvBuf);
343 		    /*
344 		     * We should probably rescan the packet from the top
345 		     * here, and technically we might have to send a whole
346 		     * new request here...
347 		     */
348 		    break;
349 		default:
350 		    break;
351 		}
352 	    }
353 
354             /* not the one we want, try next */
355             p += sizeof(struct dnsrr) + rd_len;
356         }
357 
358     badness:
359         /*
360          *
361          ; We got back no data from this server.
362          ; Unfortunately, for a recursive, non-authoritative
363          ; query there is no such thing as an NXDOMAIN reply,
364          ; which technically means we can't draw any
365          ; conclusions.  However, in practice that means the
366          ; domain doesn't exist.  If this turns out to be a
367          ; problem, we may want to add code to go through all
368          ; the servers before giving up.
369 
370          ; If the DNS server wasn't capable of recursion, and
371          ; isn't capable of giving us an authoritative reply
372          ; (i.e. neither AA or RA set), then at least try a
373          ; different setver...
374         */
375         if (hd2->flags == htons(0x480))
376             continue;
377 
378         break; /* failed */
379 
380     again:
381 	continue;
382     }
383 
384 done:
385     free_port(local_port);	/* Return port number to the free pool */
386 
387     return result;
388 }
389