You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gohperhole/blocklists.go

135 lines
3.0 KiB

package main
import (
"bufio"
"context"
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
)
// BlocklistManager decides whether a domain should be blocked instead of
// being resolved normally, and can refresh the block lists that decision is
// based on.
type BlocklistManager interface {
	// IsBlacklisted reports whether domain is present on one of the managed
	// block lists; when it is, source identifies the list that matched.
	IsBlacklisted(domain string) (source string, blocked bool)

	// Reload refreshes the managed block lists.
	Reload(ctx context.Context) error
}
// NewDomainBlacklist creates a BlocklistManager backed by the block lists at
// blocklistURLs.
//
// The lists are loaded by a background Reload, so the returned manager holds
// no lists until that load installs them. Progress and failures are logged to
// stdout with a "[Blocklist Manager] " prefix. The returned error is currently
// always nil; it is kept so a synchronous load can be reintroduced without
// changing the signature.
func NewDomainBlacklist(blocklistURLs []string) (BlocklistManager, error) {
	l := log.New(os.Stdout, "[Blocklist Manager] ", log.LUTC|log.Lshortfile)
	// Copy the URLs: the background goroutine reads them after we return, so
	// the manager must not share the caller's backing array.
	urls := append([]string(nil), blocklistURLs...)
	bm := &memoryBlocklistManager{blocklistURLs: urls, Logger: l}
	// Load asynchronously so startup is not blocked on slow or unreachable
	// list servers. Reload logs per-list failures itself; any error it
	// returns is logged here rather than silently dropped.
	go func() {
		if err := bm.Reload(context.Background()); err != nil {
			l.Printf("initial block list load failed: %v", err)
		}
	}()
	return bm, nil
}
// memoryBlocklistManager is a BlocklistManager that keeps every downloaded
// block list in memory.
//
// blocklists is replaced wholesale by Reload and read by IsBlacklisted. These
// may run on different goroutines; NewDomainBlacklist starts the first Reload
// in the background. All access to the field therefore goes through mu. A
// published slice, and the maps inside it, are never modified afterwards, so
// readers may keep using them after releasing the lock.
type memoryBlocklistManager struct {
	blocklistURLs []string
	mu            sync.RWMutex
	blocklists    []Blocklist
	*log.Logger
}

// Reload downloads every configured block list concurrently and then replaces
// the active set in one step with the lists that loaded successfully.
//
// A list that fails to load is logged and left out; the remaining lists are
// still installed. The installed lists keep the order of blocklistURLs, so a
// domain present on several lists is always attributed to the same source.
// If ctx is done by the time the downloads finish, the previously active
// lists are kept and ctx.Err() is returned, so an aborted reload never wipes
// out a working configuration.
func (mdb *memoryBlocklistManager) Reload(ctx context.Context) error {
	// One slot per URL. Each goroutine writes only its own index, so no
	// locking is needed here. A nil slot marks a list that failed to load.
	results := make([]*Blocklist, len(mdb.blocklistURLs))
	var wg sync.WaitGroup
	for i, blocklistURL := range mdb.blocklistURLs {
		wg.Add(1)
		go func(i int, blocklistURL string) {
			defer wg.Done()
			mdb.Printf("loading block list: '%s'", blocklistURL)
			blocklist, err := CreateBlockList(ctx, blocklistURL)
			if err != nil {
				mdb.Printf("failed to load block list '%s': %v", blocklistURL, err)
				return
			}
			results[i] = &blocklist
		}(i, blocklistURL)
	}
	wg.Wait()

	if err := ctx.Err(); err != nil {
		return err
	}

	loaded := make([]Blocklist, 0, len(results))
	count := 0
	for _, bl := range results {
		if bl == nil {
			continue
		}
		loaded = append(loaded, *bl)
		count += len(bl.Domains)
	}

	mdb.mu.Lock()
	mdb.blocklists = loaded
	mdb.mu.Unlock()

	mdb.Printf("successfully loaded %d block lists totaling %d domains", len(loaded), count)
	return nil
}

// IsBlacklisted reports whether domain is an exact key in any loaded block
// list, returning the Source of the first matching list (in blocklistURLs
// order). It is safe to call concurrently with Reload.
func (mdb *memoryBlocklistManager) IsBlacklisted(domain string) (string, bool) {
	mdb.mu.RLock()
	blocklists := mdb.blocklists
	mdb.mu.RUnlock()

	for _, blocklist := range blocklists {
		if _, ok := blocklist.Domains[domain]; ok {
			return blocklist.Source, true
		}
	}
	return "", false
}
// Blocklist is a list of domains that should be blocked.
type Blocklist struct {
	// Source is the URL the list was loaded from.
	Source string
	// Domains holds the blocked domain names as keys; CreateBlockList stores
	// "" for every value, so only key presence is meaningful.
	Domains map[string]string
}

// CreateBlockList downloads the block list at sourceURL and parses it.
//
// Two line formats are accepted:
//   - a bare domain per line, or
//   - hosts-file style "<ip> <domain>" lines.
//
// Fields may be separated by any run of spaces or tabs. Blank lines and '#'
// comments, whether whole-line or trailing, are ignored. Also skipped are lines
// whose second field is an IP address, and the usual hosts-file housekeeping
// names: anything containing "localhost" or "loopback", and "broadcasthost".
//
// An error is returned in any of these cases:
//   - the request cannot be built or sent;
//   - the server answers with a non-200 status;
//   - reading the body fails, including a truncated download;
//   - ctx is done while the body is being read.
//
// A partially parsed list is never reported as a success.
func CreateBlockList(ctx context.Context, sourceURL string) (Blocklist, error) {
	blocklist := Blocklist{
		Source:  sourceURL,
		Domains: map[string]string{},
	}
	// A malformed URL must surface as an error; ignoring it would hand a nil
	// request to Do, which panics.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
	if err != nil {
		return blocklist, err
	}
	response, err := http.DefaultClient.Do(req)
	if err != nil {
		return blocklist, err
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		return blocklist, fmt.Errorf("non 200 status: %d", response.StatusCode)
	}
	scanner := bufio.NewScanner(response.Body)
	for scanner.Scan() {
		if err := ctx.Err(); err != nil {
			return blocklist, err
		}
		if domain, ok := parseBlocklistLine(scanner.Text()); ok {
			blocklist.Domains[domain] = ""
		}
	}
	// Scan stops on read errors too, e.g. a dropped connection, a cancelled
	// context, or an over-long line. Without this check a truncated list
	// would be returned as if it were complete.
	if err := scanner.Err(); err != nil {
		return blocklist, fmt.Errorf("reading block list: %w", err)
	}
	return blocklist, nil
}

// parseBlocklistLine extracts the domain named by one block-list line. It
// reports false when the line holds no domain that should be blocked.
func parseBlocklistLine(line string) (string, bool) {
	// Drop comments, whether they fill the line or trail an entry.
	if i := strings.IndexByte(line, '#'); i >= 0 {
		line = line[:i]
	}
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return "", false
	}
	domain := fields[0]
	if len(fields) > 1 {
		// hosts-file style "<ip> <domain>": the domain is the second field.
		if net.ParseIP(fields[1]) != nil {
			return "", false
		}
		domain = fields[1]
	}
	if strings.Contains(domain, "localhost") || strings.Contains(domain, "loopback") || domain == "broadcasthost" {
		return "", false
	}
	return domain, true
}