You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
4.6 KiB
186 lines
4.6 KiB
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
type dnsHandler struct {
|
|
Config Configuration
|
|
blockForwardIP net.IP
|
|
logger *log.Logger
|
|
Resolver DNSResolver
|
|
Blocklist BlocklistManager
|
|
Cache DNSCacher
|
|
}
|
|
|
|
//NewDNSHandler creates a new DNS server handler
|
|
func NewDNSHandler(blockForwardIP string, cache DNSCacher, blocklist BlocklistManager, r DNSResolver, cfg Configuration) (dns.Handler, error) {
|
|
l := log.New(os.Stdout, "[DNS Server] ", log.LUTC|log.Lmicroseconds|log.Lshortfile)
|
|
|
|
if cfg.UseHosts {
|
|
hostsConf, _ := dns.ClientConfigFromFile("/etc/resolv.conf")
|
|
hostsServers := make([]net.IP, len(hostsConf.Servers))
|
|
for idx, hip := range hostsConf.Servers {
|
|
hostsServers[idx] = net.ParseIP(hip)
|
|
}
|
|
r = append(r, hostsServers...)
|
|
}
|
|
|
|
return &dnsHandler{
|
|
logger: l,
|
|
blockForwardIP: net.ParseIP(blockForwardIP),
|
|
Resolver: r,
|
|
Config: cfg,
|
|
Blocklist: blocklist,
|
|
Cache: cache,
|
|
}, nil
|
|
}
|
|
|
|
func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
var msg dns.Msg
|
|
msg.Authoritative = true
|
|
msg.RecursionAvailable = true
|
|
msg.Answer = []dns.RR{}
|
|
msg.SetReply(r)
|
|
|
|
for _, question := range r.Question {
|
|
rawDomain := question.Name
|
|
domain := rawDomain[:len(rawDomain)-1]
|
|
header := dns.RR_Header{
|
|
Name: rawDomain,
|
|
Rrtype: question.Qtype,
|
|
Class: dns.ClassINET,
|
|
Ttl: 60,
|
|
}
|
|
|
|
if address, ok := h.Config.Records[domain]; ok {
|
|
switch address.Type {
|
|
case "A":
|
|
msg.Answer = append(msg.Answer, &dns.A{Hdr: header, A: net.ParseIP(address.Record)})
|
|
case "CNAME":
|
|
answer, _, err := h.Resolver.Resolve(fmt.Sprintf("%s.", address.Record))
|
|
if err != nil {
|
|
h.logger.Printf("could not resolve custom CNAME record: %v", err)
|
|
continue
|
|
}
|
|
h.logger.Printf("%v", answer)
|
|
msg.Answer = append(msg.Answer, answer...)
|
|
default:
|
|
h.logger.Printf("type '%s' not supported for custom records.", address.Type)
|
|
}
|
|
} else if list, block := h.Blocklist.IsBlacklisted(domain); block {
|
|
h.logger.Printf("blocked dns query for '%s' from list '%s'", domain, list)
|
|
msg.Answer = append(msg.Answer, &dns.A{Hdr: header, A: h.blockForwardIP})
|
|
} else if cacheAnswers, ok := h.Cache.Get(domain); ok && len(cacheAnswers) > 0 {
|
|
msg.Answer = append(msg.Answer, cacheAnswers...)
|
|
} else {
|
|
msg.Authoritative = false
|
|
answers, ttl, err := h.Resolver.Resolve(rawDomain)
|
|
if err != nil {
|
|
h.logger.Printf("error resolving '%s': %v", domain, err)
|
|
continue
|
|
}
|
|
|
|
msg.Answer = append(msg.Answer, answers...)
|
|
h.Cache.Set(domain, ttl, answers)
|
|
}
|
|
}
|
|
|
|
if err := w.WriteMsg(&msg); err != nil {
|
|
h.logger.Printf("could not write msg: %v", err)
|
|
}
|
|
}
|
|
|
|
func shouldBlock(bls []Blocklist, domain string) (string, bool) {
|
|
for _, b := range bls {
|
|
if _, ok := b.Domains[domain[:len(domain)-1]]; ok {
|
|
return b.Source, true
|
|
}
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
//DNSResolver resolves dns recursively from index 0 to n
|
|
type DNSResolver []net.IP
|
|
|
|
//Resolve resolves the domain specified
|
|
func (resolver DNSResolver) Resolve(domain string) ([]dns.RR, time.Duration, error) {
|
|
c := dns.Client{
|
|
DialTimeout: time.Second,
|
|
ReadTimeout: time.Second,
|
|
}
|
|
|
|
for _, server := range resolver {
|
|
m := dns.Msg{}
|
|
m.SetQuestion(domain, dns.TypeA)
|
|
m.Compress = true
|
|
r, t, err := c.Exchange(&m, fmt.Sprintf("%s:53", server))
|
|
if err != nil || len(r.Answer) == 0 {
|
|
// try another server
|
|
log.Println("GOT AN ERR RESOLVING", err)
|
|
continue
|
|
}
|
|
|
|
log.Printf("Took %v", t)
|
|
return r.Answer, time.Second * 60, nil
|
|
}
|
|
|
|
return nil, 0, errors.New("no record found")
|
|
}
|
|
|
|
//DNSCacher is a cache for dns records
|
|
type DNSCacher interface {
|
|
Get(domain string) ([]dns.RR, bool)
|
|
Set(domain string, ttl time.Duration, answers []dns.RR)
|
|
}
|
|
|
|
type dnsCacheRecord struct {
|
|
Expiration time.Time
|
|
Records []dns.RR
|
|
}
|
|
|
|
type memoryDNSCacher struct {
|
|
TTL time.Duration
|
|
cache map[string]dnsCacheRecord
|
|
sync.Mutex
|
|
}
|
|
|
|
func (d *memoryDNSCacher) Get(domain string) ([]dns.RR, bool) {
|
|
d.Lock()
|
|
defer d.Unlock()
|
|
if d.cache == nil {
|
|
d.cache = map[string]dnsCacheRecord{}
|
|
}
|
|
if record, ok := d.cache[domain]; ok && time.Now().UTC().Before(record.Expiration) {
|
|
return record.Records, true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func (d *memoryDNSCacher) Set(domain string, ttl time.Duration, records []dns.RR) {
|
|
d.Lock()
|
|
defer d.Unlock()
|
|
if d.cache == nil {
|
|
d.cache = map[string]dnsCacheRecord{}
|
|
}
|
|
|
|
if records != nil && len(records) > 0 && ttl > 0 {
|
|
log.Printf("setting %s to cache", domain)
|
|
d.cache[domain] = dnsCacheRecord{
|
|
Expiration: time.Now().UTC().Add(ttl),
|
|
Records: records,
|
|
}
|
|
} else {
|
|
delete(d.cache, domain)
|
|
}
|
|
}
|