You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
4.6 KiB

package main
import (
"errors"
"fmt"
"log"
"net"
"os"
"sync"
"time"
"github.com/miekg/dns"
)
// dnsHandler is the dns.Handler built by NewDNSHandler. ServeDNS answers each
// question from Config.Records, Blocklist, Cache or Resolver, in that order.
type dnsHandler struct {
// Config provides custom records (Records) and the UseHosts flag read by NewDNSHandler.
Config Configuration
// blockForwardIP is the address put in answers for blocklisted domains.
blockForwardIP net.IP
// logger receives the handler's diagnostic messages.
logger *log.Logger
// Resolver lists the upstream servers queried when no local answer exists.
Resolver DNSResolver
// Blocklist reports whether a domain (looked up without its trailing dot) is blocked.
Blocklist BlocklistManager
// Cache holds upstream answers keyed by domain without the trailing dot.
Cache DNSCacher
}
// NewDNSHandler creates a new DNS server handler.
//
// blockForwardIP is the address returned for blocklisted domains and must be
// a valid IP literal; otherwise an error is returned. When cfg.UseHosts is
// set, the nameservers listed in /etc/resolv.conf are appended to r and are
// tried after the servers already in r. Nameserver entries that cannot be
// parsed as an IP are logged and skipped. An error is returned if
// /etc/resolv.conf cannot be read.
func NewDNSHandler(blockForwardIP string, cache DNSCacher, blocklist BlocklistManager, r DNSResolver, cfg Configuration) (dns.Handler, error) {
	blockIP := net.ParseIP(blockForwardIP)
	if blockIP == nil {
		// A nil IP would otherwise produce empty, malformed answers for every
		// blocked query; fail fast at construction instead.
		return nil, fmt.Errorf("invalid block forward IP %q", blockForwardIP)
	}

	l := log.New(os.Stdout, "[DNS Server] ", log.LUTC|log.Lmicroseconds|log.Lshortfile)

	if cfg.UseHosts {
		hostsConf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			// Previously ignored, which dereferenced a nil config and panicked.
			return nil, fmt.Errorf("reading system resolvers from /etc/resolv.conf: %w", err)
		}
		// Build a fresh slice so appending never writes into the caller's
		// backing array for r.
		servers := make(DNSResolver, 0, len(r)+len(hostsConf.Servers))
		servers = append(servers, r...)
		for _, hip := range hostsConf.Servers {
			ip := net.ParseIP(hip)
			if ip == nil {
				// e.g. zoned IPv6 literals such as "fe80::1%eth0".
				l.Printf("ignoring unparsable nameserver %q from /etc/resolv.conf", hip)
				continue
			}
			servers = append(servers, ip)
		}
		r = servers
	}

	return &dnsHandler{
		logger:         l,
		blockForwardIP: blockIP,
		Resolver:       r,
		Config:         cfg,
		Blocklist:      blocklist,
		Cache:          cache,
	}, nil
}
// ServeDNS implements dns.Handler. It builds one reply for r and answers each
// question, in priority order, from:
//
//  1. a custom record in h.Config.Records ("A" or "CNAME"),
//  2. the blocklist, answering with h.blockForwardIP,
//  3. the cache,
//  4. the upstream resolvers; the reply is then marked non-authoritative and
//     the answers are stored in the cache.
//
// A question that cannot be answered is logged and skipped. The reply is
// written regardless, carrying whatever answers were collected.
func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	// customTTL is the TTL, in seconds, of answers synthesized locally
	// (custom and blocked records).
	const customTTL = 60

	var msg dns.Msg
	msg.Authoritative = true
	msg.RecursionAvailable = true
	msg.Answer = []dns.RR{}
	msg.SetReply(r)

	for _, question := range r.Question {
		rawDomain := question.Name
		// Question names arrive fully qualified ("example.com."). Config,
		// blocklist and cache lookups use the form without the root dot. The
		// guard avoids a slice-bounds panic on an empty name.
		domain := rawDomain
		if n := len(domain); n > 0 && domain[n-1] == '.' {
			domain = domain[:n-1]
		}

		if address, ok := h.Config.Records[domain]; ok {
			switch address.Type {
			case "A":
				rr := newAddressRR(rawDomain, customTTL, net.ParseIP(address.Record))
				if rr == nil {
					h.logger.Printf("custom A record for '%s' has invalid address '%s'", domain, address.Record)
					continue
				}
				msg.Answer = append(msg.Answer, rr)
			case "CNAME":
				target := dns.Fqdn(address.Record)
				answers, _, err := h.Resolver.Resolve(target)
				if err != nil {
					h.logger.Printf("could not resolve custom CNAME record: %v", err)
					continue
				}
				// The upstream answers are owned by target, not the queried
				// name. The CNAME record links the two so clients accept them.
				cname := &dns.CNAME{
					Hdr: dns.RR_Header{
						Name:   rawDomain,
						Rrtype: dns.TypeCNAME,
						Class:  dns.ClassINET,
						Ttl:    customTTL,
					},
					Target: target,
				}
				msg.Answer = append(msg.Answer, cname)
				msg.Answer = append(msg.Answer, answers...)
			default:
				h.logger.Printf("type '%s' not supported for custom records.", address.Type)
			}
			continue
		}

		if list, block := h.Blocklist.IsBlacklisted(domain); block {
			h.logger.Printf("blocked dns query for '%s' from list '%s'", domain, list)
			if rr := newAddressRR(rawDomain, customTTL, h.blockForwardIP); rr != nil {
				msg.Answer = append(msg.Answer, rr)
			}
			continue
		}

		if cacheAnswers, ok := h.Cache.Get(domain); ok && len(cacheAnswers) > 0 {
			msg.Answer = append(msg.Answer, cacheAnswers...)
			continue
		}

		msg.Authoritative = false
		answers, ttl, err := h.Resolver.Resolve(rawDomain)
		if err != nil {
			h.logger.Printf("error resolving '%s': %v", domain, err)
			continue
		}
		msg.Answer = append(msg.Answer, answers...)
		h.Cache.Set(domain, ttl, answers)
	}

	if err := w.WriteMsg(&msg); err != nil {
		h.logger.Printf("could not write msg: %v", err)
	}
}

// newAddressRR builds an address answer owned by name that points at ip.
// It returns an A record for IPv4 addresses and an AAAA record for IPv6 ones,
// with the header type set to match the data. Reusing the question type here
// would produce malformed packets for non-A queries. It returns nil when ip is nil.
func newAddressRR(name string, ttl uint32, ip net.IP) dns.RR {
	if ip == nil {
		return nil
	}
	hdr := dns.RR_Header{Name: name, Class: dns.ClassINET, Ttl: ttl}
	if v4 := ip.To4(); v4 != nil {
		hdr.Rrtype = dns.TypeA
		return &dns.A{Hdr: hdr, A: v4}
	}
	hdr.Rrtype = dns.TypeAAAA
	return &dns.AAAA{Hdr: hdr, AAAA: ip}
}
// shouldBlock reports whether domain appears in any of the blocklists bls,
// returning the Source of the first list that contains it.
//
// domain is expected in fully qualified form ("example.com."); a single
// trailing dot is removed before lookup. Without the guard, an empty domain
// panicked and a name without a trailing dot lost its last character. An
// empty name never matches.
func shouldBlock(bls []Blocklist, domain string) (string, bool) {
	if n := len(domain); n > 0 && domain[n-1] == '.' {
		domain = domain[:n-1]
	}
	if domain == "" {
		return "", false
	}
	for _, b := range bls {
		if _, ok := b.Domains[domain]; ok {
			return b.Source, true
		}
	}
	return "", false
}
// DNSResolver is an ordered list of upstream nameservers. Resolve queries them
// one after another, starting at index 0, and stops at the first useful answer.
type DNSResolver []net.IP

// Resolve looks up the A records for domain, which must be fully qualified
// (for example "example.com."), by querying each server in order on port 53.
//
// It returns the answer section from the first server that replies with at
// least one record, together with a fixed cache TTL of 60 seconds. Servers
// that are nil, fail, or return an empty answer are skipped. When no server
// produces an answer, the error is "no record found", wrapping the last
// transport error if one occurred.
func (resolver DNSResolver) Resolve(domain string) ([]dns.RR, time.Duration, error) {
	c := dns.Client{
		DialTimeout: time.Second,
		ReadTimeout: time.Second,
	}
	var lastErr error
	for _, server := range resolver {
		if server == nil {
			continue
		}
		// JoinHostPort brackets IPv6 literals ("[::1]:53"). A plain "%s:53"
		// yields an address the dialer cannot parse for them.
		addr := net.JoinHostPort(server.String(), "53")

		m := dns.Msg{}
		m.SetQuestion(domain, dns.TypeA)
		m.Compress = true

		r, _, err := c.Exchange(&m, addr)
		if err != nil {
			lastErr = err
			log.Printf("error resolving %s via %s: %v", domain, addr, err)
			continue
		}
		if len(r.Answer) == 0 {
			// Empty answer: try the next server.
			continue
		}
		return r.Answer, 60 * time.Second, nil
	}
	if lastErr != nil {
		return nil, 0, fmt.Errorf("no record found: %w", lastErr)
	}
	return nil, 0, errors.New("no record found")
}
// DNSCacher stores DNS answer records keyed by domain name.
type DNSCacher interface {
	// Set stores answers for domain for the given ttl.
	Set(domain string, ttl time.Duration, answers []dns.RR)
	// Get returns the answers cached for domain and whether an entry was found.
	Get(domain string) ([]dns.RR, bool)
}
// dnsCacheRecord is one entry in memoryDNSCacher's map.
type dnsCacheRecord struct {
// Expiration is the UTC instant (set by Set as now+ttl) after which Get no longer returns Records.
Expiration time.Time
// Records are the cached answer resource records.
Records []dns.RR
}
// memoryDNSCacher is an in-process DNSCacher backed by a map. The embedded
// mutex guards the map (and also exposes Lock/Unlock in the method set).
// The zero value is usable because the map is allocated lazily.
type memoryDNSCacher struct {
// TTL is not read by Get or Set. NOTE(review): presumably meant as a default TTL; confirm against other callers.
TTL time.Duration
// cache maps a domain (without trailing dot) to its cached records.
cache map[string]dnsCacheRecord
sync.Mutex
}
// Get returns the cached records for domain and true if an entry exists and
// has not yet expired. An expired entry is evicted during the lookup, so it
// does not linger in the map (for example when the follow-up upstream lookup
// fails and Set is never called). Reading a nil map is safe, so no
// allocation is needed here.
func (d *memoryDNSCacher) Get(domain string) ([]dns.RR, bool) {
	d.Lock()
	defer d.Unlock()

	record, ok := d.cache[domain]
	if !ok {
		return nil, false
	}
	if !time.Now().UTC().Before(record.Expiration) {
		delete(d.cache, domain)
		return nil, false
	}
	return record.Records, true
}
// Set caches records for domain until now+ttl. A call with no records or a
// non-positive ttl removes any existing entry instead, so uncacheable results
// never leave stale data behind.
func (d *memoryDNSCacher) Set(domain string, ttl time.Duration, records []dns.RR) {
	d.Lock()
	defer d.Unlock()

	// len of a nil slice is 0, so no separate nil check is needed.
	if len(records) == 0 || ttl <= 0 {
		// delete on a nil map is a no-op; there is no need to allocate one first.
		delete(d.cache, domain)
		return
	}

	if d.cache == nil {
		d.cache = make(map[string]dnsCacheRecord)
	}
	log.Printf("setting %s to cache", domain)
	d.cache[domain] = dnsCacheRecord{
		Expiration: time.Now().UTC().Add(ttl),
		Records:    records,
	}
}