You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
173 lines
3.9 KiB
173 lines
3.9 KiB
package main
|
|
|
|
import (
	"errors"
	"fmt"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
)
|
|
|
|
func main() {
|
|
fmt.Println("hello")
|
|
|
|
cfg, err := LoadConfig("./config.json")
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
blockLists := []Blocklist{}
|
|
for _, bl := range cfg.Blocklists {
|
|
blocklist, err := FetchBlockList(bl)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
log.Printf("block list '%s' -> %d", bl, len(blocklist.Domains))
|
|
blockLists = append(blockLists, blocklist)
|
|
}
|
|
|
|
srv := &dns.Server{
|
|
Addr: ":53",
|
|
Net: "udp",
|
|
Handler: &handler{
|
|
Config: *cfg,
|
|
Cache: DNSCache(map[string]DNSCacheRecord{}),
|
|
BlockLists: blockLists,
|
|
},
|
|
}
|
|
|
|
if err := srv.ListenAndServe(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
type handler struct {
|
|
Config Configuration
|
|
BlockLists []Blocklist
|
|
Cache DNSCache
|
|
}
|
|
|
|
type DNSCacheRecord struct {
|
|
Expiration time.Time
|
|
IP net.IP
|
|
}
|
|
|
|
type DNSCache map[string]DNSCacheRecord
|
|
|
|
func (d DNSCache) Get(domain string, msg *dns.Msg) bool {
|
|
if record, ok := d[domain]; ok && time.Now().UTC().Before(record.Expiration) {
|
|
msg.Answer = append(msg.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
A: record.IP,
|
|
})
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (d DNSCache) Set(domain string, IP net.IP, ttl time.Duration) {
|
|
if IP != nil && ttl > 0 {
|
|
log.Printf("setting %s to cache", domain)
|
|
d[domain] = DNSCacheRecord{
|
|
Expiration: time.Now().UTC().Add(ttl),
|
|
IP: IP,
|
|
}
|
|
} else {
|
|
delete(d, domain)
|
|
}
|
|
}
|
|
|
|
func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
msg := dns.Msg{}
|
|
msg.SetReply(r)
|
|
|
|
switch r.Question[0].Qtype {
|
|
case dns.TypeA:
|
|
msg.Authoritative = true
|
|
domain := msg.Question[0].Name
|
|
fmt.Println("got request", domain)
|
|
|
|
if address, ok := h.Config.Records[domain[:len(domain)-1]]; ok && address.Type == "A" {
|
|
msg.Answer = append(msg.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
A: net.ParseIP(address.Record),
|
|
})
|
|
} else if list, block := shouldBlock(h.BlockLists, domain); block {
|
|
log.Printf("blocked dns query for '%s' from '%s'", domain, list)
|
|
msg.Answer = append(msg.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
A: net.ParseIP("127.0.0.1"),
|
|
})
|
|
} else if !h.Cache.Get(domain, &msg) {
|
|
msg, ip, ttl, err := recursiveResolve(domain, h.Config.Upstream...)
|
|
if err != nil {
|
|
log.Printf("got an error trying to resolve '%s': %v", domain, err)
|
|
break
|
|
}
|
|
h.Cache.Set(domain, ip, ttl)
|
|
msg.SetReply(r)
|
|
w.WriteMsg(msg)
|
|
return
|
|
}
|
|
}
|
|
|
|
w.WriteMsg(&msg)
|
|
}
|
|
|
|
func shouldBlock(bls []Blocklist, domain string) (string, bool) {
|
|
for _, b := range bls {
|
|
if _, ok := b.Domains[domain[:len(domain)-1]]; ok {
|
|
return b.Source, true
|
|
}
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
func recursiveResolve(domain string, servers ...string) (*dns.Msg, net.IP, time.Duration, error) {
|
|
c := dns.Client{
|
|
DialTimeout: time.Second,
|
|
ReadTimeout: time.Second,
|
|
}
|
|
|
|
for _, server := range servers {
|
|
m := dns.Msg{}
|
|
m.SetQuestion(domain, dns.TypeA)
|
|
m.Compress = true
|
|
r, t, err := c.Exchange(&m, server+":53")
|
|
if err != nil {
|
|
return nil, nil, 0, err
|
|
}
|
|
|
|
log.Printf("Took %v", t)
|
|
if len(r.Answer) == 0 {
|
|
continue
|
|
}
|
|
|
|
res := dns.Msg{}
|
|
for _, ans := range r.Answer {
|
|
switch ans.(type) {
|
|
case *dns.A:
|
|
Arecord := ans.(*dns.A)
|
|
res.Answer = append(res.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
A: Arecord.A,
|
|
})
|
|
return &res, Arecord.A, time.Second * time.Duration(Arecord.Hdr.Ttl), nil
|
|
case *dns.CNAME:
|
|
Crecord := ans.(*dns.CNAME)
|
|
res.Answer = append(res.Answer, &dns.CNAME{
|
|
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 60},
|
|
Target: Crecord.Target,
|
|
})
|
|
return &res, nil, 0, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, nil, 0, errors.New("no record found")
|
|
}
|