fix tests added http module

master
Adam Veldhousen 4 years ago
parent 2d3d6bea05
commit ec230678ed
Signed by: adam
GPG Key ID: 6DB29003C6DD1E4B

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"fmt"
"log"
"net"
@ -14,21 +15,38 @@ import (
//BlocklistManager manages block lists and determines if a domain should be permitted to be resolved normally
type BlocklistManager interface {
IsBlacklisted(string) (string, bool)
Reload(context.Context) error
}
func NewDomainBlacklist(blocklistURLS []string) (BlocklistManager, error) {
//NewDomainBlacklist creates a new BlocklistManager
func NewDomainBlacklist(blocklistURLs []string) (BlocklistManager, error) {
l := log.New(os.Stdout, "[Blocklist Manager] ", log.LUTC|log.Lshortfile)
bc := make(chan Blocklist, len(blocklistURLS))
bm := &memoryBlocklistManager{blocklistURLs: blocklistURLs, Logger: l}
if err := bm.Reload(context.Background()); err != nil {
return nil, err
}
return bm, nil
}
type memoryBlocklistManager struct {
blocklistURLs []string
blocklists []Blocklist
*log.Logger
}
func (mdb *memoryBlocklistManager) Reload(ctx context.Context) error {
var wg sync.WaitGroup
bc := make(chan Blocklist, len(mdb.blocklistURLs))
for _, bl := range blocklistURLS {
for _, bl := range mdb.blocklistURLs {
wg.Add(1)
go func(blocklistURL string) {
defer wg.Done()
l.Printf("loading block list: '%s'", blocklistURL)
blocklist, err := CreateBlockList(blocklistURL)
mdb.Printf("loading block list: '%s'", blocklistURL)
blocklist, err := CreateBlockList(ctx, blocklistURL)
if err != nil {
l.Printf("failed to load block list '%s': %v", blocklistURL, err)
mdb.Printf("failed to load block list '%s': %v", blocklistURL, err)
return
}
bc <- blocklist
@ -38,18 +56,13 @@ func NewDomainBlacklist(blocklistURLS []string) (BlocklistManager, error) {
wg.Wait()
close(bc)
blocklists := []Blocklist{}
mdb.blocklists = []Blocklist{}
for bl := range bc {
blocklists = append(blocklists, bl)
mdb.blocklists = append(mdb.blocklists, bl)
}
l.Printf("successfully loaded '%d' block lists", len(blocklists))
return &memoryBlocklistManager{blocklists: blocklists, Logger: l}, nil
}
type memoryBlocklistManager struct {
blocklists []Blocklist
*log.Logger
mdb.Printf("successfully loaded '%d' block lists", len(mdb.blocklists))
return nil
}
func (mdb *memoryBlocklistManager) IsBlacklisted(domain string) (string, bool) {
@ -69,13 +82,14 @@ type Blocklist struct {
}
// CreateBlockList creates a block list from the source URL.
func CreateBlockList(sourceURL string) (Blocklist, error) {
func CreateBlockList(ctx context.Context, sourceURL string) (Blocklist, error) {
blocklist := Blocklist{
Source: sourceURL,
Domains: map[string]string{},
}
response, err := http.Get(sourceURL)
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
response, err := http.DefaultClient.Do(req)
if err != nil {
return blocklist, err
}
@ -87,24 +101,29 @@ func CreateBlockList(sourceURL string) (Blocklist, error) {
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if line == "" || line[0] == '#' {
continue
}
select {
case <-ctx.Done():
return blocklist, context.Canceled
default:
line := scanner.Text()
if line == "" || line[0] == '#' {
continue
}
domain := line
if frags := strings.Split(line, " "); len(frags) > 1 {
if net.ParseIP(frags[1]) != nil {
domain := line
if frags := strings.Split(line, " "); len(frags) > 1 {
if net.ParseIP(frags[1]) != nil {
continue
}
domain = frags[1]
}
if strings.Contains(domain, "localhost") || strings.Contains(domain, "loopback") || domain == "broadcasthost" {
continue
}
domain = frags[1]
}
if strings.Contains(domain, "localhost") || strings.Contains(domain, "loopback") || domain == "broadcasthost" {
continue
blocklist.Domains[domain] = ""
}
blocklist.Domains[domain] = "enabled"
}
return blocklist, nil

@ -1,6 +1,7 @@
package main
import (
"context"
"reflect"
"testing"
)
@ -20,23 +21,23 @@ func TestFetchBlockList(t *testing.T) {
source: "https://gist.githubusercontent.com/adamveld12/7d7a236344c48d6e00cc4bd0d88079bf/raw/93b837ff443d1e4c4c9a381275733fd28d16b5f7/blacklist.txt",
want: Blocklist{
Source: "https://gist.githubusercontent.com/adamveld12/7d7a236344c48d6e00cc4bd0d88079bf/raw/93b837ff443d1e4c4c9a381275733fd28d16b5f7/blacklist.txt",
Domains: []string{
"local",
"ip6-localnet",
"ip6-mcastprefix",
"ip6-allnodes",
"ip6-allrouters",
"ip6-allhosts",
"01mspmd5yalky8.com",
"0byv9mgbn0.com",
"analytics.247sports.com",
"www.analytics.247sports.com",
"2no.co",
"www.2no.co",
"logitechlogitechglobal.112.2o7.net",
"www.logitechlogitechglobal.112.2o7.net",
"30-day-change.com",
"www.30-day-change.com",
Domains: map[string]string{
"local": "",
"ip6-localnet": "",
"ip6-mcastprefix": "",
"ip6-allnodes": "",
"ip6-allrouters": "",
"ip6-allhosts": "",
"01mspmd5yalky8.com": "",
"0byv9mgbn0.com": "",
"analytics.247sports.com": "",
"www.analytics.247sports.com": "",
"2no.co": "",
"www.2no.co": "",
"logitechlogitechglobal.112.2o7.net": "",
"www.logitechlogitechglobal.112.2o7.net": "",
"30-day-change.com": "",
"www.30-day-change.com": "",
},
},
},
@ -45,48 +46,48 @@ func TestFetchBlockList(t *testing.T) {
source: "https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt",
want: Blocklist{
Source: "https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt",
Domains: []string{
"adjust.io",
"airbrake.io",
"appboy.com",
"appsflyer.com",
"apsalar.com",
"bango.combango.org",
"bango.net",
"basic-check.disconnect.me",
"bkrtx.com",
"bluekai.com",
"bugsense.com",
"burstly.com",
"chartboost.com",
"count.ly",
"crashlytics.com",
"crittercism.com",
"custom-blacklisted-tracking-example.com",
"do-not-tracker.org",
"eviltracker.net",
"flurry.com",
"getexceptional.com",
"inmobi.com",
"jumptap.com",
"localytics.com",
"mixpanel.com",
"mobile-collector.newrelic.com",
"mobileapptracking.com",
"playtomic.com",
"stathat.com",
"supercell.net",
"tapjoy.com",
"trackersimulator.org",
"usergrid.com",
"vungle.com",
Domains: map[string]string{
"adjust.io": "",
"airbrake.io": "",
"appboy.com": "",
"appsflyer.com": "",
"apsalar.com": "",
"bango.combango.org": "",
"bango.net": "",
"basic-check.disconnect.me": "",
"bkrtx.com": "",
"bluekai.com": "",
"bugsense.com": "",
"burstly.com": "",
"chartboost.com": "",
"count.ly": "",
"crashlytics.com": "",
"crittercism.com": "",
"custom-blacklisted-tracking-example.com": "",
"do-not-tracker.org": "",
"eviltracker.net": "",
"flurry.com": "",
"getexceptional.com": "",
"inmobi.com": "",
"jumptap.com": "",
"localytics.com": "",
"mixpanel.com": "",
"mobile-collector.newrelic.com": "",
"mobileapptracking.com": "",
"playtomic.com": "",
"stathat.com": "",
"supercell.net": "",
"tapjoy.com": "",
"trackersimulator.org": "",
"usergrid.com": "",
"vungle.com": "",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CreateBlockList(tt.source)
got, err := CreateBlockList(context.Background(), tt.source)
if (err != nil) != tt.wantErr {
t.Errorf("FetchBlockList() error = %v, wantErr %v", err, tt.wantErr)
return

@ -0,0 +1,47 @@
package main
import (
"log"
"net/http"
"os"
"strings"
)
//NewHTTPHandler creates a new http handler
func NewHTTPHandler(bm BlocklistManager) http.Handler {
return newAPIHandler(nil, bm)
}
func newAPIHandler(inner http.Handler, bm BlocklistManager) http.Handler {
l := log.New(os.Stdout, "[HTTP API] ", log.LUTC|log.Lshortfile)
apiHandler := http.NewServeMux()
/*
1. update block lists
*/
apiHandler.HandleFunc("/api/gopherhole/blocklists/reload", func(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
http.Error(w, "method not allowed, must be a POST", http.StatusMethodNotAllowed)
return
}
if err := bm.Reload(req.Context()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
l.Println("reloaded blocklists successfully")
})
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
if strings.HasPrefix(req.Host, "/api/gopherhole") {
apiHandler.ServeHTTP(res, req)
return
}
handleDefaultRequests(res, req)
})
}
func handleDefaultRequests(res http.ResponseWriter, req *http.Request) {
log.Printf("got request for: %s", req.URL.Hostname())
}

@ -28,21 +28,21 @@ func main() {
log.Fatal(err)
}
go func() {
httpAddr := fmt.Sprintf("%s:80", *httpAddress)
log.Printf("HTTP server listening @ %s", httpAddr)
if err := http.ListenAndServe(httpAddr, http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
log.Printf("got request for: %s", req.URL.Hostname())
})); err != nil {
log.Fatal(err)
}
}()
domainBlacklist, err := NewDomainBlacklist(cfg.Blocklists)
if err != nil {
log.Fatal(err)
}
go func(bm BlocklistManager) {
httpAddr := fmt.Sprintf("%s:80", *httpAddress)
log.Printf("HTTP server listening @ %s", httpAddr)
handler := NewHTTPHandler(bm)
if err := http.ListenAndServe(httpAddr, handler); err != nil {
log.Fatal(err)
}
}(domainBlacklist)
ips := []net.IP{}
for _, strIP := range cfg.Upstream {
ips = append(ips, net.ParseIP(strIP))
@ -55,7 +55,12 @@ func main() {
dnsAddr := fmt.Sprintf("%s:53", *dnsAddress)
log.Printf("DNS server listening @ %s", dnsAddr)
srv := &dns.Server{Addr: dnsAddr, Net: "udp", Handler: handler}
srv := &dns.Server{
Addr: dnsAddr,
Net: "udp",
Handler: handler,
ReadTimeout: time.Second * 3,
}
if err := srv.ListenAndServe(); err != nil {
log.Fatal(err)
}

@ -15,6 +15,10 @@ package:
publish:
@docker push vdhsn/gopherhole
test:
go test -v -cover ./...
$(BIN)/$(APP):
go build -v -o $(BIN)/$(APP) .
@ -27,4 +31,4 @@ $(BIN):
clean:
@rm -rf $(BIN) gohperhole debug.test _debug_bin
.PHONY: ci clean deploy dev package publish
.PHONY: ci clean deploy dev package publish test
Loading…
Cancel
Save