fixup code a bunch

master
Adam Veldhousen 4 years ago
parent bb75c421a6
commit 8c3dd8d8e3
Signed by: adam
GPG Key ID: 6DB29003C6DD1E4B

@ -5,14 +5,17 @@ WORKDIR /opt/build
COPY . /opt/build
RUN go get -v && go build -v -o gopherhole ./main.go
RUN go get -v && go build -v -o gopherhole .
FROM ubuntu
COPY --from=build /opt/build/gopherhole /opt/gopherhole/gopherhole
COPY ./config.json /opt/gopherhole/config.json
ENV GOPHERHOLE_BIND_ADDRESS 127.0.0.1
WORKDIR /opt/gopherhole
EXPOSE 53
EXPOSE 53/udp 53/tcp 80/tcp
ENTRYPOINT /opt/gopherhole/gopherhole

@ -3,23 +3,79 @@ package main
import (
"bufio"
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
)
//BlocklistManager manages block lists and determines if a domain should be permitted to be resolved normally
type BlocklistManager interface {
IsBlacklisted(string) (string, bool)
}
func NewDomainBlacklist(blocklistURLS []string) (BlocklistManager, error) {
l := log.New(os.Stdout, "[Blocklist Manager] ", log.LUTC|log.Lshortfile)
bc := make(chan Blocklist, len(blocklistURLS))
var wg sync.WaitGroup
for _, bl := range blocklistURLS {
wg.Add(1)
go func(blocklistURL string) {
defer wg.Done()
l.Printf("loading block list: '%s'", blocklistURL)
blocklist, err := CreateBlockList(blocklistURL)
if err != nil {
l.Printf("failed to load block list '%s': %v", blocklistURL, err)
return
}
bc <- blocklist
}(bl)
}
wg.Wait()
close(bc)
blocklists := []Blocklist{}
for bl := range bc {
blocklists = append(blocklists, bl)
}
l.Printf("successfully loaded '%d' block lists", len(blocklists))
return &memoryBlocklistManager{blocklists: blocklists, Logger: l}, nil
}
type memoryBlocklistManager struct {
blocklists []Blocklist
*log.Logger
}
func (mdb *memoryBlocklistManager) IsBlacklisted(domain string) (string, bool) {
for _, blocklist := range mdb.blocklists {
if _, ok := blocklist.Domains[domain]; ok {
return blocklist.Source, true
}
}
return "", false
}
// Blocklist is a list of domains that should be blocked
type Blocklist struct {
Source string
Domains map[string]string
}
func FetchBlockList(source string) (Blocklist, error) {
// CreateBlockList creates a block list from the source URL.
func CreateBlockList(sourceURL string) (Blocklist, error) {
blocklist := Blocklist{
Source: source,
Source: sourceURL,
Domains: map[string]string{},
}
response, err := http.Get(source)
response, err := http.Get(sourceURL)
if err != nil {
return blocklist, err
}

@ -3,6 +3,8 @@ package main
import (
"encoding/json"
"io/ioutil"
"net"
"os"
)
func LoadConfig(path string) (*Configuration, error) {
@ -16,13 +18,21 @@ func LoadConfig(path string) (*Configuration, error) {
return nil, err
}
bindAddress := os.Getenv("GOPHERHOLE_BIND_ADDRESS")
if ip := net.ParseIP(bindAddress); ip != nil {
c.BindAddress = ip
} else {
c.BindAddress = net.ParseIP("127.0.0.1")
}
return &c, nil
}
type Configuration struct {
Upstream []string
Blocklists []string
Records map[string]ConfigRecord
BindAddress net.IP
Upstream []string
Blocklists []string
Records map[string]ConfigRecord
}
type ConfigRecord struct {

@ -5,12 +5,12 @@
],
"records": {
"vdhsn.com": {
"Type": "A",
"Record": "192.168.0.4"
"Type": "CNAME",
"Record": "internal.veldhousen.ninja"
},
"riffraff.vdhsn.com": {
"Type": "A",
"Record": "192.168.0.4"
"Type": "CNAME",
"Record": "internal.veldhousen.ninja"
}
},
"blocklists": [

@ -0,0 +1,97 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: gopherhole
labels:
app: gopherhole
spec:
selector:
matchLabels:
app: gopherhole
revisionHistoryLimit: 2
replicas: 1
strategy:
rollingUpdate:
maxSurge: 100%
template:
metadata:
labels:
app: gopherhole
spec:
restartPolicy: Always
containers:
- image: vdhsn/gopherhole:latest
imagePullPolicy: Always
name: gopherhole
env:
- name: GOPHERHOLE_BIND_ADDRESS
value: 192.168.0.98
stdin: true
tty: true
resources:
limits:
cpu: "512m"
memory: "512M"
requests:
cpu: "512m"
memory: "512m"
ports:
- containerPort: 53
protocol: UDP
name: dns-udp
- containerPort: 53
protocol: TCP
name: dns-tcp
- containerPort: 80
name: http-sinkhole
securityContext:
capabilities:
add: ["NET_ADMIN", "NET_RAW", "NET_BIND_SERVICE"]
---
apiVersion: v1
kind: Service
metadata:
name: gopherhole-http
spec:
type: LoadBalancer
externalTrafficPolicy: Local
externalIPs:
- 192.168.0.98
ports:
- port: 80
name: web
protocol: TCP
targetPort: 80
selector:
app: gopherhole
---
apiVersion: v1
kind: Service
metadata:
name: gopherhole-tcp-dns
spec:
type: LoadBalancer
externalTrafficPolicy: Local
externalIPs:
- 192.168.0.98
ports:
- port: 53
name: tcp-dns
protocol: TCP
selector:
app: gopherhole
---
apiVersion: v1
kind: Service
metadata:
name: gopherhole-udp-dns
spec:
type: LoadBalancer
externalIPs:
- 192.168.0.98
ports:
- port: 53
name: udp-dns
protocol: UDP
selector:
app: gopherhole

@ -0,0 +1,42 @@
apiVersion: traefik.containo.us/v1alpha1
kind: IngressRoute
metadata:
name: gopherhole-ingress-tls
namespace: default
spec:
entryPoints:
- websecure
routes:
- match: Host(`.*`)
priority: 999
kind: Rule
services:
- name: gopherhole-http
port: 80
middlewares:
- name: internal-only
tls:
certResolver: default
---
apiVersion: traefik.containo.us/v1alpha1
kind: IngressRoute
metadata:
name: gopherhole-admin-ingress-tls
namespace: default
spec:
entryPoints:
- websecure
routes:
- match: Host(`gopherhole.veldhousen.ninja`,`gopherhole.vdhsn.com`,`gopherhole.veldhousen.com`,`gopherhole.veldhousen.net`)
kind: Rule
priority: 1
services:
- name: gopherhole-http
port: 80
middlewares:
- name: internal-only
- name: ssl-redirect-header
- name: gzip
tls:
certResolver: default

170
dns.go

@ -0,0 +1,170 @@
package main
import (
"errors"
"fmt"
"log"
"net"
"os"
"sync"
"time"
"github.com/miekg/dns"
)
type dnsHandler struct {
Config Configuration
logger *log.Logger
Resolver DNSResolver
Blocklist BlocklistManager
Cache DNSCacher
}
//NewDNSHandler creates a new DNS server handler
func NewDNSHandler(cache DNSCacher, blocklist BlocklistManager, r DNSResolver, cfg Configuration) (dns.Handler, error) {
l := log.New(os.Stdout, "[DNS Server] ", log.LUTC|log.Lmicroseconds|log.Lshortfile)
return &dnsHandler{
logger: l,
Resolver: r,
Config: cfg,
Blocklist: blocklist,
Cache: cache,
}, nil
}
func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
var msg dns.Msg
msg.Authoritative = true
msg.RecursionAvailable = true
msg.SetReply(r)
for _, question := range r.Question {
rawDomain := question.Name
domain := rawDomain[:len(rawDomain)-1]
header := dns.RR_Header{
Name: domain,
Rrtype: question.Qtype,
Class: dns.ClassINET,
Ttl: 60,
}
if address, ok := h.Config.Records[domain]; ok {
switch address.Type {
case "A":
msg.Answer = append(msg.Answer, &dns.A{Hdr: header, A: net.ParseIP(address.Record)})
case "CNAME":
answer, _, err := h.Resolver.Resolve(fmt.Sprintf("%s.", address.Record))
if err != nil {
h.logger.Printf("could not resolve custom CNAME record: %v", err)
continue
}
h.logger.Printf("%v", answer)
msg.Answer = append(msg.Answer, answer...)
default:
h.logger.Printf("type '%s' not supported for custom records.", address.Type)
}
} else if list, block := h.Blocklist.IsBlacklisted(domain); block {
h.logger.Printf("blocked dns query for '%s' from list '%s'", domain, list)
msg.Answer = append(msg.Answer, &dns.A{Hdr: header, A: net.ParseIP("127.0.0.1")})
} else if cacheAnswers, ok := h.Cache.Get(domain); ok && len(cacheAnswers) > 0 {
msg.Answer = append(msg.Answer, cacheAnswers...)
} else {
msg.Authoritative = false
answers, ttl, err := h.Resolver.Resolve(rawDomain)
if err != nil {
h.logger.Printf("error resolving '%s': %v", domain, err)
continue
}
msg.Answer = append(msg.Answer, answers...)
h.Cache.Set(domain, ttl, answers)
}
}
w.WriteMsg(&msg)
}
func shouldBlock(bls []Blocklist, domain string) (string, bool) {
for _, b := range bls {
if _, ok := b.Domains[domain[:len(domain)-1]]; ok {
return b.Source, true
}
}
return "", false
}
//DNSResolver resolves dns recursively from index 0 to n
type DNSResolver []net.IP
//Resolve resolves the domain specified
func (resolver DNSResolver) Resolve(domain string) ([]dns.RR, time.Duration, error) {
c := dns.Client{
DialTimeout: time.Second,
ReadTimeout: time.Second,
}
for _, server := range resolver {
m := dns.Msg{}
m.SetQuestion(domain, dns.TypeA)
m.Compress = true
r, t, err := c.Exchange(&m, fmt.Sprintf("%s:53", server))
if err != nil || len(r.Answer) == 0 {
// try another server
log.Println("GOT AN ERR RESOLVING", err)
continue
}
log.Printf("Took %v", t)
return r.Answer, time.Second * 60, nil
}
return nil, 0, errors.New("no record found")
}
//DNSCacher is a cache for dns records
type DNSCacher interface {
Get(domain string) ([]dns.RR, bool)
Set(domain string, ttl time.Duration, answers []dns.RR)
}
type dnsCacheRecord struct {
Expiration time.Time
Records []dns.RR
}
type memoryDNSCacher struct {
TTL time.Duration
cache map[string]dnsCacheRecord
sync.Mutex
}
func (d *memoryDNSCacher) Get(domain string) ([]dns.RR, bool) {
d.Lock()
defer d.Unlock()
if d.cache == nil {
d.cache = map[string]dnsCacheRecord{}
}
if record, ok := d.cache[domain]; ok && time.Now().UTC().Before(record.Expiration) {
return record.Records, true
}
return nil, false
}
func (d *memoryDNSCacher) Set(domain string, ttl time.Duration, records []dns.RR) {
d.Lock()
defer d.Unlock()
if d.cache == nil {
d.cache = map[string]dnsCacheRecord{}
}
if records != nil && len(records) > 0 && ttl > 0 {
log.Printf("setting %s to cache", domain)
d.cache[domain] = dnsCacheRecord{
Expiration: time.Now().UTC().Add(ttl),
Records: records,
}
} else {
delete(d.cache, domain)
}
}

@ -1,172 +1,57 @@
package main
import (
"errors"
"fmt"
"flag"
"log"
"net"
"net/http"
"time"
"github.com/miekg/dns"
)
var (
configFilePath = flag.String("config", "./config.json", "config file path")
httpAddress = flag.String("bind-http", "127.0.0.1", "interface to bind the HTTP server to")
dnsAddress = flag.String("bind-dns", "127.0.0.1", "interface to bind the DNS server to")
verbose = flag.Bool("verbose", false, "enable verbose logging")
)
func main() {
fmt.Println("hello")
flag.Parse()
cfg, err := LoadConfig("./config.json")
cfg, err := LoadConfig(*configFilePath)
if err != nil {
log.Fatal(err)
}
blockLists := []Blocklist{}
for _, bl := range cfg.Blocklists {
blocklist, err := FetchBlockList(bl)
if err != nil {
continue
go func() {
log.Println("HTTP server listening")
if err := http.ListenAndServe(":80", http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
log.Printf("got request for: %s", req.URL.Hostname())
})); err != nil {
log.Fatal(err)
}
}()
log.Printf("block list '%s' -> %d", bl, len(blocklist.Domains))
blockLists = append(blockLists, blocklist)
}
srv := &dns.Server{
Addr: ":53",
Net: "udp",
Handler: &handler{
Config: *cfg,
Cache: DNSCache(map[string]DNSCacheRecord{}),
BlockLists: blockLists,
},
}
if err := srv.ListenAndServe(); err != nil {
domainBlacklist, err := NewDomainBlacklist(cfg.Blocklists)
if err != nil {
log.Fatal(err)
}
}
type handler struct {
Config Configuration
BlockLists []Blocklist
Cache DNSCache
}
type DNSCacheRecord struct {
Expiration time.Time
IP net.IP
}
type DNSCache map[string]DNSCacheRecord
func (d DNSCache) Get(domain string, msg *dns.Msg) bool {
if record, ok := d[domain]; ok && time.Now().UTC().Before(record.Expiration) {
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: record.IP,
})
return true
}
return false
}
func (d DNSCache) Set(domain string, IP net.IP, ttl time.Duration) {
if IP != nil && ttl > 0 {
log.Printf("setting %s to cache", domain)
d[domain] = DNSCacheRecord{
Expiration: time.Now().UTC().Add(ttl),
IP: IP,
}
} else {
delete(d, domain)
ips := []net.IP{}
for _, strIP := range cfg.Upstream {
ips = append(ips, net.ParseIP(strIP))
}
}
func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg := dns.Msg{}
msg.SetReply(r)
switch r.Question[0].Qtype {
case dns.TypeA:
msg.Authoritative = true
domain := msg.Question[0].Name
fmt.Println("got request", domain)
if address, ok := h.Config.Records[domain[:len(domain)-1]]; ok && address.Type == "A" {
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(address.Record),
})
} else if list, block := shouldBlock(h.BlockLists, domain); block {
log.Printf("blocked dns query for '%s' from '%s'", domain, list)
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("127.0.0.1"),
})
} else if !h.Cache.Get(domain, &msg) {
msg, ip, ttl, err := recursiveResolve(domain, h.Config.Upstream...)
if err != nil {
log.Printf("got an error trying to resolve '%s': %v", domain, err)
break
}
h.Cache.Set(domain, ip, ttl)
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}
w.WriteMsg(&msg)
}
func shouldBlock(bls []Blocklist, domain string) (string, bool) {
for _, b := range bls {
if _, ok := b.Domains[domain[:len(domain)-1]]; ok {
return b.Source, true
}
}
return "", false
}
func recursiveResolve(domain string, servers ...string) (*dns.Msg, net.IP, time.Duration, error) {
c := dns.Client{
DialTimeout: time.Second,
ReadTimeout: time.Second,
handler, err := NewDNSHandler(&memoryDNSCacher{TTL: time.Minute}, domainBlacklist, DNSResolver(ips), *cfg)
if err != nil {
log.Fatal(err)
}
for _, server := range servers {
m := dns.Msg{}
m.SetQuestion(domain, dns.TypeA)
m.Compress = true
r, t, err := c.Exchange(&m, server+":53")
if err != nil {
return nil, nil, 0, err
}
log.Printf("Took %v", t)
if len(r.Answer) == 0 {
continue
}
res := dns.Msg{}
for _, ans := range r.Answer {
switch ans.(type) {
case *dns.A:
Arecord := ans.(*dns.A)
res.Answer = append(res.Answer, &dns.A{
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: Arecord.A,
})
return &res, Arecord.A, time.Second * time.Duration(Arecord.Hdr.Ttl), nil
case *dns.CNAME:
Crecord := ans.(*dns.CNAME)
res.Answer = append(res.Answer, &dns.CNAME{
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 60},
Target: Crecord.Target,
})
return &res, nil, 0, nil
}
}
log.Println("DNS server listening")
srv := &dns.Server{Addr: ":53", Net: "udp", Handler: handler}
if err := srv.ListenAndServe(); err != nil {
log.Fatal(err)
}
return nil, nil, 0, errors.New("no record found")
}

@ -4,6 +4,17 @@ BIN := .bin
dev: clean $(BIN)/config.json $(BIN)/$(APP)
sudo $(BIN)/$(APP)
ci: clean package publish deploy
deploy:
@kubectl apply -f ./deployment
package:
@docker build -t vdhsn/gopherhole .
publish:
@docker push vdhsn/gopherhole
$(BIN)/$(APP):
go build -v -o $(BIN)/$(APP) .
@ -14,6 +25,6 @@ $(BIN):
mkdir -p $@
clean:
@rm -rf $(BIN)
@rm -rf $(BIN) gohperhole debug.test _debug_bin
.PHONY: dev clean
.PHONY: ci clean deploy dev package publish
Loading…
Cancel
Save