You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
3.9 KiB

package main
import (
	"errors"
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"github.com/miekg/dns"
)
// main loads ./config.json, fetches every configured block list and then
// serves DNS over UDP on port 53 until the server stops with an error.
//
// A missing or invalid configuration is fatal. A block list that cannot be
// fetched is skipped so that the server still comes up.
func main() {
	fmt.Println("hello")

	cfg, err := LoadConfig("./config.json")
	if err != nil {
		log.Fatalf("loading config: %v", err)
	}

	blockLists := make([]Blocklist, 0, len(cfg.Blocklists))
	for _, bl := range cfg.Blocklists {
		blocklist, err := FetchBlockList(bl)
		if err != nil {
			// Skipping keeps the server usable, but it must not be silent:
			// without this line nobody notices that filtering is incomplete.
			log.Printf("skipping block list '%v': %v", bl, err)
			continue
		}
		log.Printf("block list '%s' -> %d", bl, len(blocklist.Domains))
		blockLists = append(blockLists, blocklist)
	}

	srv := &dns.Server{
		Addr: ":53",
		Net:  "udp",
		Handler: &handler{
			Config:     *cfg,
			Cache:      make(DNSCache),
			BlockLists: blockLists,
		},
	}
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("dns server: %v", err)
	}
}
// handler answers DNS queries from the static configuration, the block lists
// and a cache of addresses obtained from the upstream servers.
type handler struct {
	// Config is the configuration the server was started with; its Records
	// and Upstream fields are read while answering queries.
	Config Configuration
	// BlockLists are the lists checked before a name is resolved.
	BlockLists []Blocklist
	// Cache holds addresses returned by the upstream servers.
	Cache DNSCache
}
// DNSCacheRecord is one cache entry: a resolved address together with the
// moment at which it stops being valid.
type DNSCacheRecord struct {
	// Expiration is the instant from which the entry is considered stale.
	Expiration time.Time
	// IP is the cached address.
	IP net.IP
}
// dnsCacheMu guards the contents of every DNSCache.
//
// The cache is used from ServeDNS, and miekg/dns serves each request on its
// own goroutine, so lookups and updates can overlap. Unsynchronised map access
// is a data race that the Go runtime may turn into a fatal error. DNSCache is
// a plain map type, so the lock lives beside it rather than inside it.
var dnsCacheMu sync.RWMutex

// DNSCache maps a domain name, exactly as it appears in the question, to the
// last address resolved for it. Get and Set are safe for concurrent use.
type DNSCache map[string]DNSCacheRecord

// Get appends an A record for domain to msg.Answer when the cache holds an
// entry for domain that has not expired yet.
//
// It reports whether an answer was appended. The record is always emitted
// with a TTL of 60 seconds, regardless of how long the entry remains valid.
func (d DNSCache) Get(domain string, msg *dns.Msg) bool {
	dnsCacheMu.RLock()
	record, ok := d[domain]
	dnsCacheMu.RUnlock()

	if !ok || !time.Now().UTC().Before(record.Expiration) {
		return false
	}

	msg.Answer = append(msg.Answer, &dns.A{
		Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
		A:   record.IP,
	})
	return true
}

// Set stores IP for domain, valid for ttl from now.
//
// A nil IP or a non-positive ttl removes any existing entry for domain
// instead of storing one.
func (d DNSCache) Set(domain string, IP net.IP, ttl time.Duration) {
	if IP == nil || ttl <= 0 {
		dnsCacheMu.Lock()
		delete(d, domain)
		dnsCacheMu.Unlock()
		return
	}

	// Log before taking the lock so that slow log output never blocks readers.
	log.Printf("setting %s to cache", domain)

	entry := DNSCacheRecord{
		Expiration: time.Now().UTC().Add(ttl),
		IP:         IP,
	}
	dnsCacheMu.Lock()
	d[domain] = entry
	dnsCacheMu.Unlock()
}
// ServeDNS builds a reply to the request r and writes it to w.
//
// Only A questions are answered. The sources are tried in this order:
//  1. static records from the configuration (entries whose Type is "A"),
//  2. the block lists: a blocked name is answered with 127.0.0.1,
//  3. the cache,
//  4. the configured upstream servers; their answer is cached and returned.
//
// Any other question type, and an A question that the upstream servers could
// not resolve, is answered with an empty reply. A request without a question
// is answered with a format error.
func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	reply := new(dns.Msg)
	reply.SetReply(r)

	switch {
	case len(r.Question) == 0:
		// Indexing Question[0] on a message without a question section would
		// panic. Whether such a message can reach the handler depends on the
		// server's accept policy, so reject it here instead of relying on that.
		reply.Rcode = dns.RcodeFormatError

	case r.Question[0].Qtype == dns.TypeA:
		reply.Authoritative = true
		domain := r.Question[0].Name
		fmt.Println("got request", domain)

		// Configured records are keyed without the trailing root dot.
		name := domain
		if n := len(name); n > 0 && name[n-1] == '.' {
			name = name[:n-1]
		}

		if address, ok := h.Config.Records[name]; ok && address.Type == "A" {
			reply.Answer = append(reply.Answer, &dns.A{
				Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
				A:   net.ParseIP(address.Record),
			})
		} else if list, block := shouldBlock(h.BlockLists, domain); block {
			log.Printf("blocked dns query for '%s' from '%s'", domain, list)
			reply.Answer = append(reply.Answer, &dns.A{
				Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
				A:   net.ParseIP("127.0.0.1"),
			})
		} else if !h.Cache.Get(domain, reply) {
			resolved, ip, ttl, err := recursiveResolve(domain, h.Config.Upstream...)
			if err != nil {
				// Fall through and send the empty reply built above.
				log.Printf("got an error trying to resolve '%s': %v", domain, err)
			} else {
				h.Cache.Set(domain, ip, ttl)
				// The upstream answer replaces the local reply; SetReply
				// copies the request id and question into it.
				resolved.SetReply(r)
				reply = resolved
			}
		}
	}

	if err := w.WriteMsg(reply); err != nil {
		log.Printf("failed to write dns response: %v", err)
	}
}
// shouldBlock reports whether domain is listed in any of bls.
//
// domain is normally a fully qualified name with a trailing dot. That dot is
// removed before the lookup because the lists are keyed without it. A name
// without a trailing dot is looked up unchanged, and an empty name never
// matches.
//
// It returns the Source of the first list that contains the name and true, or
// "" and false when no list contains it.
func shouldBlock(bls []Blocklist, domain string) (string, bool) {
	// Trim once, and only a real trailing dot: slicing off the last byte
	// unconditionally panics on "" and corrupts names that have no dot.
	name := domain
	if n := len(name); n > 0 && name[n-1] == '.' {
		name = name[:n-1]
	}

	for _, b := range bls {
		if _, ok := b.Domains[name]; ok {
			return b.Source, true
		}
	}
	return "", false
}
// recursiveResolve asks the upstream servers, in order, for the A record of
// domain and returns the first usable answer.
//
// servers are addresses without a port; port 53 is appended. A server that
// cannot be reached or that returns no usable record is skipped in favour of
// the next one.
//
// On success the returned message holds exactly one record, owned by domain
// and carrying a TTL of 60:
//   - if the answer contains an A record, that address is used, even when it
//     was reached through a CNAME; the address and the TTL reported by the
//     upstream server are returned alongside so the caller can cache them;
//   - if the answer contains a CNAME but no A record, the CNAME is returned
//     with a nil IP and a zero duration.
//
// When no server produces a usable record, the error of the last failed
// exchange is returned, or "no record found" if every exchange succeeded.
func recursiveResolve(domain string, servers ...string) (*dns.Msg, net.IP, time.Duration, error) {
	c := dns.Client{
		DialTimeout: time.Second,
		ReadTimeout: time.Second,
	}

	var lastErr error
	for _, server := range servers {
		// A bare IPv6 literal must be bracketed before a port is attached;
		// anything that is not an IP literal keeps the plain concatenation.
		addr := server + ":53"
		if net.ParseIP(server) != nil {
			addr = net.JoinHostPort(server, "53")
		}

		m := dns.Msg{}
		m.SetQuestion(domain, dns.TypeA)
		m.Compress = true

		r, t, err := c.Exchange(&m, addr)
		if err != nil {
			// One unreachable upstream must not hide the others.
			lastErr = fmt.Errorf("querying upstream %s: %w", server, err)
			continue
		}
		log.Printf("Took %v", t)

		// Answers to a name behind a CNAME list the CNAME first and the
		// address after it, so remember the alias and keep looking for an A.
		var alias *dns.CNAME
		for _, ans := range r.Answer {
			switch rec := ans.(type) {
			case *dns.A:
				res := &dns.Msg{}
				res.Answer = append(res.Answer, &dns.A{
					Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
					A:   rec.A,
				})
				return res, rec.A, time.Second * time.Duration(rec.Hdr.Ttl), nil
			case *dns.CNAME:
				if alias == nil {
					alias = rec
				}
			}
		}

		if alias != nil {
			res := &dns.Msg{}
			res.Answer = append(res.Answer, &dns.CNAME{
				Hdr:    dns.RR_Header{Name: domain, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 60},
				Target: alias.Target,
			})
			return res, nil, 0, nil
		}
	}

	if lastErr != nil {
		return nil, nil, 0, lastErr
	}
	return nil, nil, 0, errors.New("no record found")
}