diff --git a/dns.c b/dns.c index 184b82f..218f1e5 100644 --- a/dns.c +++ b/dns.c @@ -1,3 +1,4 @@ +#include #include #include #include @@ -8,11 +9,35 @@ #include #include #include +#include #include "dns.h" #define BUFLEN 512 +#if defined IP_RECVDSTADDR +# define DSTADDR_SOCKOPT IP_RECVDSTADDR +# define DSTADDR_DATASIZE (CMSG_SPACE(sizeof(struct in_addr))) +# define dstaddr(x) (CMSG_DATA(x)) +#elif defined IP_PKTINFO +struct in_pktinfo { + unsigned int ipi_ifindex; /* Interface index */ + struct in_addr ipi_spec_dst; /* Local address */ + struct in_addr ipi_addr; /* Header Destination address */ +}; + +# define DSTADDR_SOCKOPT IP_PKTINFO +# define DSTADDR_DATASIZE (CMSG_SPACE(sizeof(struct in_pktinfo))) +# define dstaddr(x) (&(((struct in_pktinfo *)(CMSG_DATA(x)))->ipi_addr)) +#else +# error "can't determine socket option" +#endif + +union control_data { + struct cmsghdr cmsg; + unsigned char data[DSTADDR_DATASIZE]; +}; + typedef enum { CLASS_IN = 1, QCLASS_ANY = 255 @@ -348,12 +373,21 @@ int dnsserver(dns_opt_t *opt) { if (senderSocket == -1) return -3; + int replySocket; if (listenSocket == -1) { struct sockaddr_in si_me; if ((listenSocket=socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP))==-1) { listenSocket = -1; return -1; } + replySocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (replySocket == -1) + { + close(listenSocket); + return -1; + } + int sockopt = 1; + setsockopt(listenSocket, IPPROTO_IP, DSTADDR_SOCKOPT, &sockopt, sizeof sockopt); memset((char *) &si_me, 0, sizeof(si_me)); si_me.sin_family = AF_INET; si_me.sin_port = htons(opt->port); @@ -361,18 +395,48 @@ int dnsserver(dns_opt_t *opt) { if (bind(listenSocket, (struct sockaddr*)&si_me, sizeof(si_me))==-1) return -2; } + unsigned char inbuf[BUFLEN], outbuf[BUFLEN]; - do { - socklen_t si_other_len = sizeof(si_other); - ssize_t insize = recvfrom(listenSocket, inbuf, BUFLEN, 0, (struct sockaddr*)&si_other, &si_other_len); + struct iovec iov[1] = { + { + .iov_base = inbuf, + .iov_len = sizeof(inbuf), + }, + }; + union control_data cmsg; + struct msghdr msg = { + .msg_name = &si_other, + .msg_namelen = sizeof(si_other), + .msg_iov = iov, + .msg_iovlen = 1, + .msg_control = &cmsg, + .msg_controllen = sizeof(cmsg), + }; + for (; 1; ++(opt->nRequests)) + { + ssize_t insize = recvmsg(listenSocket, &msg, 0); unsigned char *addr = (unsigned char*)&si_other.sin_addr.s_addr; // printf("DNS: Request %llu from %i.%i.%i.%i:%i of %i bytes\n", (unsigned long long)(opt->nRequests), addr[0], addr[1], addr[2], addr[3], ntohs(si_other.sin_port), (int)insize); - opt->nRequests++; - if (insize > 0) { - ssize_t ret = dnshandle(opt, inbuf, insize, outbuf); - if (ret > 0) - sendto(listenSocket, outbuf, ret, 0, (struct sockaddr*)&si_other, sizeof(si_other)); + if (insize <= 0) + continue; + + ssize_t ret = dnshandle(opt, inbuf, insize, outbuf); + if (ret <= 0) + continue; + + bool handled = false; + for (struct cmsghdr*hdr = CMSG_FIRSTHDR(&msg); hdr; hdr = CMSG_NXTHDR(&msg, hdr)) + { + if (hdr->cmsg_level == IPPROTO_IP && hdr->cmsg_type == DSTADDR_SOCKOPT) + { + msg.msg_iov[0].iov_base = outbuf; + sendmsg(listenSocket, &msg, 0); + msg.msg_iov[0].iov_base = inbuf; + handled = true; + } } - } while(1); + if (!handled) + sendto(listenSocket, outbuf, ret, 0, (struct sockaddr*)&si_other, sizeof(si_other)); + } return 0; }