Compare commits

...

2 Commits

Author SHA1 Message Date
Adam Veldhousen cd553b4fad
use correct name in dns header, fixes blacklisting
4 years ago
Adam Veldhousen e40b3a7dba
fixup flags
4 years ago

@ -86,7 +86,7 @@ func TestFetchBlockList(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := FetchBlockList(tt.source)
got, err := CreateBlockList(tt.source)
if (err != nil) != tt.wantErr {
t.Errorf("FetchBlockList() error = %v, wantErr %v", err, tt.wantErr)
return

@ -13,22 +13,24 @@ import (
)
type dnsHandler struct {
Config Configuration
logger *log.Logger
Resolver DNSResolver
Blocklist BlocklistManager
Cache DNSCacher
Config Configuration
httpBindIP net.IP
logger *log.Logger
Resolver DNSResolver
Blocklist BlocklistManager
Cache DNSCacher
}
//NewDNSHandler creates a new DNS server handler
func NewDNSHandler(cache DNSCacher, blocklist BlocklistManager, r DNSResolver, cfg Configuration) (dns.Handler, error) {
func NewDNSHandler(httpBindIP string, cache DNSCacher, blocklist BlocklistManager, r DNSResolver, cfg Configuration) (dns.Handler, error) {
l := log.New(os.Stdout, "[DNS Server] ", log.LUTC|log.Lmicroseconds|log.Lshortfile)
return &dnsHandler{
logger: l,
Resolver: r,
Config: cfg,
Blocklist: blocklist,
Cache: cache,
logger: l,
httpBindIP: net.ParseIP(httpBindIP),
Resolver: r,
Config: cfg,
Blocklist: blocklist,
Cache: cache,
}, nil
}
@ -36,13 +38,14 @@ func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
var msg dns.Msg
msg.Authoritative = true
msg.RecursionAvailable = true
msg.Answer = []dns.RR{}
msg.SetReply(r)
for _, question := range r.Question {
rawDomain := question.Name
domain := rawDomain[:len(rawDomain)-1]
header := dns.RR_Header{
Name: domain,
Name: rawDomain,
Rrtype: question.Qtype,
Class: dns.ClassINET,
Ttl: 60,
@ -65,7 +68,7 @@ func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
} else if list, block := h.Blocklist.IsBlacklisted(domain); block {
h.logger.Printf("blocked dns query for '%s' from list '%s'", domain, list)
msg.Answer = append(msg.Answer, &dns.A{Hdr: header, A: net.ParseIP("127.0.0.1")})
msg.Answer = append(msg.Answer, &dns.A{Hdr: header, A: h.httpBindIP})
} else if cacheAnswers, ok := h.Cache.Get(domain); ok && len(cacheAnswers) > 0 {
msg.Answer = append(msg.Answer, cacheAnswers...)
} else {
@ -81,7 +84,9 @@ func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}
w.WriteMsg(&msg)
if err := w.WriteMsg(&msg); err != nil {
h.logger.Printf("could not write msg: %v", err)
}
}
func shouldBlock(bls []Blocklist, domain string) (string, bool) {

@ -2,6 +2,7 @@ package main
import (
"flag"
"fmt"
"log"
"net"
"net/http"
@ -12,12 +13,13 @@ import (
var (
configFilePath = flag.String("config", "./config.json", "config file path")
httpAddress = flag.String("bind-http", "127.0.0.1", "interface to bind the HTTP server to")
dnsAddress = flag.String("bind-dns", "127.0.0.1", "interface to bind the DNS server to")
verbose = flag.Bool("verbose", false, "enable verbose logging")
httpAddress = flag.String("bind-http", "127.0.0.1", "interface to bind the HTTP server to (0.0.0.0 for all)")
dnsAddress = flag.String("bind-dns", "127.0.0.1", "interface to bind the DNS server to (0.0.0.0 for all)")
)
func main() {
log.SetPrefix("[Entrypoint] ")
log.SetFlags(log.LUTC | log.Lshortfile)
flag.Parse()
cfg, err := LoadConfig(*configFilePath)
@ -26,8 +28,9 @@ func main() {
}
go func() {
log.Println("HTTP server listening")
if err := http.ListenAndServe(":80", http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
httpAddr := fmt.Sprintf("%s:80", *httpAddress)
log.Printf("HTTP server listening @ %s", httpAddr)
if err := http.ListenAndServe(httpAddr, http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
log.Printf("got request for: %s", req.URL.Hostname())
})); err != nil {
log.Fatal(err)
@ -44,13 +47,14 @@ func main() {
ips = append(ips, net.ParseIP(strIP))
}
handler, err := NewDNSHandler(&memoryDNSCacher{TTL: time.Minute}, domainBlacklist, DNSResolver(ips), *cfg)
handler, err := NewDNSHandler(*httpAddress, &memoryDNSCacher{TTL: time.Minute}, domainBlacklist, DNSResolver(ips), *cfg)
if err != nil {
log.Fatal(err)
}
log.Println("DNS server listening")
srv := &dns.Server{Addr: ":53", Net: "udp", Handler: handler}
dnsAddr := fmt.Sprintf("%s:53", *dnsAddress)
log.Printf("DNS server listening @ %s", dnsAddr)
srv := &dns.Server{Addr: dnsAddr, Net: "udp", Handler: handler}
if err := srv.ListenAndServe(); err != nil {
log.Fatal(err)
}

Loading…
Cancel
Save