1 #include <stdio.h>
2 #include <string.h>
3 #include <core.h>
4 #include "pxe.h"
5
/* DNS CLASS codes this resolver uses (query class / answer class filter) */
#define CLASS_IN    1

/* DNS TYPE codes this resolver uses (query type / answer type dispatch) */
#define TYPE_A      1
#define TYPE_CNAME  5
12
/*
 * Header found at the start of every DNS message handled here.
 * The resolver stores and reads the count fields through htons()/ntohs(),
 * i.e. they are kept in network byte order.
 */
struct dnshdr {
    uint16_t id;        /* query ID; a reply is accepted only if it matches */
    uint16_t flags;     /* header flag bits */
    uint16_t qdcount;   /* entries in the question section */
    uint16_t ancount;   /* resource records in the answer section */
    uint16_t nscount;   /* name server records in the authority section */
    uint16_t arcount;   /* resource records in the additional section */
} __attribute__ ((packed));
28
/*
 * Fixed-size trailer of a question entry; it is written directly
 * after the encoded name.
 */
struct dnsquery {
    uint16_t qtype;     /* requested record type (TYPE_*) */
    uint16_t qclass;    /* requested record class (CLASS_*) */
} __attribute__ ((packed));
36
/*
 * Fixed part of a DNS resource record; it follows the record's
 * owner name and is itself followed by rdlength bytes of data.
 */
struct dnsrr {
    uint16_t type;      /* record type (TYPE_*) */
    uint16_t class;     /* record class (CLASS_*) */
    uint32_t ttl;       /* time to live; not examined by this resolver */
    uint16_t rdlength;  /* length of rdata[] in bytes */
    char rdata[];       /* record payload */
} __attribute__ ((packed));
47
48
/* DNS UDP port, already converted with htons() for the PXE UDP calls */
#define DNS_PORT        htons(53)
/* Number of slots in dns_server[] */
#define DNS_MAX_SERVERS 4

/*
 * Name server addresses.  dns_resolv() refuses to run when slot 0 is
 * zero and treats a zero slot as the end of the list.
 */
uint32_t dns_server[DNS_MAX_SERVERS] = { 0 };
53
54
/*
 * Encode the host name at _p_ as a DNS label set written at *dst.
 *
 * Input is consumed up to (not including) the first NUL, ':' or '/'.
 * Each '.' starts a new label; every label is emitted as a length byte
 * followed by its characters.  A terminating zero byte is appended
 * unless the last label is empty (e.g. the name ended in '.'), in which
 * case that label's zero length byte already serves as terminator.
 *
 * On return *dst points just past the last byte written.  The return
 * value is the number of dots seen.
 *
 * Neither the length of a label nor the space available at *dst is
 * checked here.
 */
int dns_mangle(char **dst, const char *p)
{
    char *out = *dst;
    char *len_byte = out++;     /* length byte of the label being built */
    int ndots = 0;

    *len_byte = 0;

    for (;; p++) {
        char ch = *p;

        if (ch == 0 || ch == ':' || ch == '/')
            break;

        if (ch == '.') {
            /* Close the current label and open an empty one */
            ndots++;
            len_byte = out++;
            *len_byte = 0;
        } else {
            *len_byte += 1;
            *out++ = ch;
        }
    }

    /* A non-empty final label still needs the terminating zero */
    if (*len_byte)
        *out++ = 0;

    *dst = out;
    return ndots;
}
91
92
/*
 * Compare two DNS label sets.
 *
 * s1  - uncompressed label set (length-prefixed labels, zero terminated)
 * s2  - label set that may contain compression pointers
 * buf - start of the packet that the pointers in s2 are relative to
 *
 * Returns true when both encode the same sequence of labels (bytewise,
 * so the comparison is case sensitive), false otherwise.
 *
 * A reply is untrusted input: a pointer that refers to itself, or a
 * cycle of pointers, would otherwise keep this loop spinning forever.
 * The number of pointers followed is therefore capped and a name that
 * exceeds the cap is reported as a mismatch.
 */
static bool dns_compare(const void *s1, const void *s2, const void *buf)
{
    /* A name is at most 255 octets, so it cannot need more hops than this */
    enum { MAX_POINTER_HOPS = 255 };
    const uint8_t *q = s1;
    const uint8_t *p = s2;
    unsigned int hops = 0;
    unsigned int c0, c1;

    while (1) {
        c0 = p[0];
        if (c0 >= 0xc0) {
            /* Compression pointer: 14-bit offset from the packet start */
            if (++hops > MAX_POINTER_HOPS)
                return false;
            c1 = p[1];
            p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
        } else if (c0) {
            c0++;               /* Include the length byte */
            if (memcmp(q, p, c0))
                return false;
            q += c0;
            p += c0;
        } else {
            /* End of s2; equal only if s1 ends here too */
            return *q == 0;
        }
    }
}
121
/*
 * Copy a DNS label set into a buffer, expanding compression pointers.
 *
 * dst - destination; receives an uncompressed, zero-terminated label set
 * src - label set to copy; may contain compression pointers
 * buf - start of the packet that the pointers in src are relative to
 *
 * Returns a pointer to the first free byte *after* the terminal null.
 *
 * The source comes from a received packet and is untrusted.  Without
 * limits a pointer cycle would loop forever, and a cycle that contains
 * labels would also write without bound.  The copy therefore follows a
 * limited number of pointers and never produces more than 255 bytes
 * (the maximum length of an encoded DNS name).  If either limit is hit
 * the name is treated as malformed: dst receives the empty name (a
 * single zero byte) and the pointer just past that byte is returned.
 */
static void *dns_copylabel(void *dst, const void *src, const void *buf)
{
    enum { MAX_POINTER_HOPS = 255, MAX_NAME_LEN = 255 };
    uint8_t *q = dst;
    const uint8_t *p = src;
    unsigned int hops = 0;
    unsigned int used = 0;      /* bytes written to dst so far */
    unsigned int c0, c1;

    while (1) {
        c0 = p[0];
        if (c0 >= 0xc0) {
            /* Compression pointer: 14-bit offset from the packet start */
            if (++hops > MAX_POINTER_HOPS)
                goto malformed;
            c1 = p[1];
            p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
        } else if (c0) {
            c0++;               /* Include the length byte */
            /* Leave room for the terminating zero */
            if (used + c0 + 1 > MAX_NAME_LEN)
                goto malformed;
            memcpy(q, p, c0);
            p += c0;
            q += c0;
            used += c0;
        } else {
            *q++ = 0;
            return q;
        }
    }

malformed:
    q = dst;
    *q++ = 0;
    return q;
}
150
/*
 * Step over one encoded DNS name starting at _label_ and return a
 * pointer to the byte that follows it.
 *
 * The name ends either at a zero length byte or at a compression
 * pointer (first byte >= 0xc0), which occupies two bytes and is not
 * followed.
 */
static char *dns_skiplabel(char *label)
{
    for (;;) {
        uint8_t len = *label;

        if (len >= 0xc0)
            return label + 2;   /* pointer is two bytes */
        if (len == 0)
            return label + 1;   /* step over the terminator */

        label += 1 + len;       /* length byte plus label text */
    }
}
167
/*
 * Defined outside this file.
 * TimeoutTable - per-attempt timeouts; dns_resolv() stops retrying when
 *                the value it fetched from the table is zero
 * get_port     - obtain a local UDP port number
 * free_port    - return a port number to the free pool
 */
extern const uint8_t TimeoutTable[];
uint16_t get_port(void);
void free_port(uint16_t port);
171
/*
 * Parse a dotted-quad IPv4 address ("a.b.c.d", decimal components).
 *
 * ip_str - NUL-terminated string to parse
 * res    - receives the address in network byte order on success
 *
 * Returns true only if the whole string is exactly four decimal
 * components separated by single dots, each with at least one digit
 * and a value of at most 255.  On failure *res is left untouched, so
 * a caller never sees a half-parsed address.
 */
static bool parse_dotquad(const char *ip_str, uint32_t *res)
{
    const char *p = ip_str;
    uint32_t ip = 0;
    int i;

    for (i = 0; i < 4; i++) {
        unsigned int part = 0;
        int ndigits = 0;

        while (*p >= '0' && *p <= '9') {
            part = part * 10 + (unsigned int)(*p - '0');
            /* Checked per digit so the accumulator can never wrap */
            if (part > 255)
                return false;
            ndigits++;
            p++;
        }

        /* Reject empty components such as "1..2.3" */
        if (!ndigits)
            return false;

        if (i != 3) {
            if (*p != '.')
                return false;
            p++;
        }

        ip = (ip << 8) | part;
    }

    /* Anything after the fourth component makes it "not a dotted quad" */
    if (*p != '\0')
        return false;

    *res = htonl(ip);
    return true;
}
202
203 /*
204 * Actual resolver function
205 * Points to a null-terminated or :-terminated string in _name_
206 * and returns the ip addr in _ip_ if it exists and can be found.
207 * If _ip_ = 0 on exit, the lookup failed. _name_ will be updated
208 *
209 * XXX: probably need some caching here.
210 */
dns_resolv(const char * name)211 __export uint32_t dns_resolv(const char *name)
212 {
213 static char __lowmem DNSSendBuf[PKTBUF_SIZE];
214 static char __lowmem DNSRecvBuf[PKTBUF_SIZE];
215 char *p;
216 int err;
217 int dots;
218 int same;
219 int rd_len;
220 int ques, reps; /* number of questions and replies */
221 uint8_t timeout;
222 const uint8_t *timeout_ptr = TimeoutTable;
223 uint32_t oldtime;
224 uint32_t srv;
225 uint32_t *srv_ptr;
226 struct dnshdr *hd1 = (struct dnshdr *)DNSSendBuf;
227 struct dnshdr *hd2 = (struct dnshdr *)DNSRecvBuf;
228 struct dnsquery *query;
229 struct dnsrr *rr;
230 static __lowmem struct s_PXENV_UDP_WRITE udp_write;
231 static __lowmem struct s_PXENV_UDP_READ udp_read;
232 uint16_t local_port;
233 uint32_t result = 0;
234
235 /*
236 * Return failure on an empty input... this can happen during
237 * some types of URL parsing, and this is the easiest place to
238 * check for it.
239 */
240 if (!name || !*name)
241 return 0;
242
243 /* If it is a valid dot quad, just return that value */
244 if (parse_dotquad(name, &result))
245 return result;
246
247 /* Make sure we have at least one valid DNS server */
248 if (!dns_server[0])
249 return 0;
250
251 /* Get a local port number */
252 local_port = get_port();
253
254 /* First, fill the DNS header struct */
255 hd1->id++; /* New query ID */
256 hd1->flags = htons(0x0100); /* Recursion requested */
257 hd1->qdcount = htons(1); /* One question */
258 hd1->ancount = 0; /* No answers */
259 hd1->nscount = 0; /* No NS */
260 hd1->arcount = 0; /* No AR */
261
262 p = DNSSendBuf + sizeof(struct dnshdr);
263 dots = dns_mangle(&p, name); /* store the CNAME */
264
265 if (!dots) {
266 p--; /* Remove final null */
267 /* Uncompressed DNS label set so it ends in null */
268 p = stpcpy(p, LocalDomain);
269 }
270
271 /* Fill the DNS query packet */
272 query = (struct dnsquery *)p;
273 query->qtype = htons(TYPE_A);
274 query->qclass = htons(CLASS_IN);
275 p += sizeof(struct dnsquery);
276
277 /* Now send it to name server */
278 timeout_ptr = TimeoutTable;
279 timeout = *timeout_ptr++;
280 srv_ptr = dns_server;
281 while (timeout) {
282 srv = *srv_ptr++;
283 if (!srv) {
284 srv_ptr = dns_server;
285 srv = *srv_ptr++;
286 }
287
288 udp_write.status = 0;
289 udp_write.ip = srv;
290 udp_write.gw = gateway(srv);
291 udp_write.src_port = local_port;
292 udp_write.dst_port = DNS_PORT;
293 udp_write.buffer_size = p - DNSSendBuf;
294 udp_write.buffer = FAR_PTR(DNSSendBuf);
295 err = pxe_call(PXENV_UDP_WRITE, &udp_write);
296 if (err || udp_write.status)
297 continue;
298
299 oldtime = jiffies();
300 do {
301 if (jiffies() - oldtime >= timeout)
302 goto again;
303
304 udp_read.status = 0;
305 udp_read.src_ip = srv;
306 udp_read.dest_ip = IPInfo.myip;
307 udp_read.s_port = DNS_PORT;
308 udp_read.d_port = local_port;
309 udp_read.buffer_size = PKTBUF_SIZE;
310 udp_read.buffer = FAR_PTR(DNSRecvBuf);
311 err = pxe_call(PXENV_UDP_READ, &udp_read);
312 } while (err || udp_read.status || hd2->id != hd1->id);
313
314 if ((hd2->flags ^ 0x80) & htons(0xf80f))
315 goto badness;
316
317 ques = htons(hd2->qdcount); /* Questions */
318 reps = htons(hd2->ancount); /* Replies */
319 p = DNSRecvBuf + sizeof(struct dnshdr);
320 while (ques--) {
321 p = dns_skiplabel(p); /* Skip name */
322 p += 4; /* Skip question trailer */
323 }
324
325 /* Parse the replies */
326 while (reps--) {
327 same = dns_compare(DNSSendBuf + sizeof(struct dnshdr),
328 p, DNSRecvBuf);
329 p = dns_skiplabel(p);
330 rr = (struct dnsrr *)p;
331 rd_len = ntohs(rr->rdlength);
332 if (same && ntohs(rr->class) == CLASS_IN) {
333 switch (ntohs(rr->type)) {
334 case TYPE_A:
335 if (rd_len == 4) {
336 result = *(uint32_t *)rr->rdata;
337 goto done;
338 }
339 break;
340 case TYPE_CNAME:
341 dns_copylabel(DNSSendBuf + sizeof(struct dnshdr),
342 rr->rdata, DNSRecvBuf);
343 /*
344 * We should probably rescan the packet from the top
345 * here, and technically we might have to send a whole
346 * new request here...
347 */
348 break;
349 default:
350 break;
351 }
352 }
353
354 /* not the one we want, try next */
355 p += sizeof(struct dnsrr) + rd_len;
356 }
357
358 badness:
359 /*
360 *
361 ; We got back no data from this server.
362 ; Unfortunately, for a recursive, non-authoritative
363 ; query there is no such thing as an NXDOMAIN reply,
364 ; which technically means we can't draw any
365 ; conclusions. However, in practice that means the
366 ; domain doesn't exist. If this turns out to be a
367 ; problem, we may want to add code to go through all
368 ; the servers before giving up.
369
370 ; If the DNS server wasn't capable of recursion, and
371 ; isn't capable of giving us an authoritative reply
372 ; (i.e. neither AA or RA set), then at least try a
373 ; different setver...
374 */
375 if (hd2->flags == htons(0x480))
376 continue;
377
378 break; /* failed */
379
380 again:
381 continue;
382 }
383
384 done:
385 free_port(local_port); /* Return port number to the free pool */
386
387 return result;
388 }
389