You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
3.0 KiB
135 lines
3.0 KiB
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// BlocklistManager manages block lists and determines whether a domain should
// be permitted to be resolved normally.
type BlocklistManager interface {
	// IsBlacklisted reports whether domain is on a loaded block list. When it
	// is, source identifies the list that matched (the in-memory
	// implementation in this file returns the list's Source URL).
	IsBlacklisted(domain string) (source string, blocked bool)

	// Reload re-fetches the block lists, honouring ctx for cancellation.
	Reload(ctx context.Context) error
}
|
|
|
|
//NewDomainBlacklist creates a new BlocklistManager
|
|
func NewDomainBlacklist(blocklistURLs []string) (BlocklistManager, error) {
|
|
l := log.New(os.Stdout, "[Blocklist Manager] ", log.LUTC|log.Lshortfile)
|
|
bm := &memoryBlocklistManager{blocklistURLs: blocklistURLs, Logger: l}
|
|
go bm.Reload(context.Background())
|
|
|
|
// if err := bm.Reload(context.Background()); err != nil {
|
|
// return nil, err
|
|
// }
|
|
|
|
return bm, nil
|
|
}
|
|
|
|
type memoryBlocklistManager struct {
|
|
blocklistURLs []string
|
|
blocklists []Blocklist
|
|
*log.Logger
|
|
}
|
|
|
|
func (mdb *memoryBlocklistManager) Reload(ctx context.Context) error {
|
|
var wg sync.WaitGroup
|
|
bc := make(chan Blocklist, len(mdb.blocklistURLs))
|
|
|
|
for _, bl := range mdb.blocklistURLs {
|
|
wg.Add(1)
|
|
go func(blocklistURL string) {
|
|
defer wg.Done()
|
|
mdb.Printf("loading block list: '%s'", blocklistURL)
|
|
blocklist, err := CreateBlockList(ctx, blocklistURL)
|
|
if err != nil {
|
|
mdb.Printf("failed to load block list '%s': %v", blocklistURL, err)
|
|
return
|
|
}
|
|
bc <- blocklist
|
|
}(bl)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(bc)
|
|
|
|
mdb.blocklists = []Blocklist{}
|
|
count := 0
|
|
for bl := range bc {
|
|
mdb.blocklists = append(mdb.blocklists, bl)
|
|
count += len(bl.Domains)
|
|
}
|
|
|
|
mdb.Printf("successfully loaded %d block lists totaling %d domains", len(mdb.blocklists), count)
|
|
return nil
|
|
}
|
|
|
|
func (mdb *memoryBlocklistManager) IsBlacklisted(domain string) (string, bool) {
|
|
for _, blocklist := range mdb.blocklists {
|
|
if _, ok := blocklist.Domains[domain]; ok {
|
|
return blocklist.Source, true
|
|
}
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
// Blocklist is a set of domains that should be blocked, together with the
// location it was loaded from.
type Blocklist struct {
	// Source is the URL the list was downloaded from (see CreateBlockList).
	Source string
	// Domains holds the blocked domain names as keys. The values are unused
	// (CreateBlockList stores ""), so the map is effectively a set.
	Domains map[string]string
}

// CreateBlockList downloads the block list at sourceURL and parses it.
//
// The body is read line by line in hosts-file style:
//   - Blank lines and '#' comments (whole-line or trailing) are ignored.
//   - A line with a single field is taken as a bare domain.
//   - Otherwise the first field is the address and each later field is a
//     hostname to block; fields that parse as IP addresses are skipped.
//   - Names containing "localhost" or "loopback", and "broadcasthost", are
//     never added.
//
// It returns an error if the request cannot be built or sent, if the server
// answers with anything other than 200 OK, if reading the body fails, or if
// ctx is done before parsing finishes. On error the returned Blocklist is the
// zero value.
func CreateBlockList(ctx context.Context, sourceURL string) (Blocklist, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
	if err != nil {
		// A malformed URL yields a nil request, and Client.Do would panic on it.
		return Blocklist{}, err
	}

	response, err := http.DefaultClient.Do(req)
	if err != nil {
		return Blocklist{}, err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return Blocklist{}, fmt.Errorf("non 200 status: %d", response.StatusCode)
	}

	blocklist := Blocklist{
		Source:  sourceURL,
		Domains: map[string]string{},
	}

	scanner := bufio.NewScanner(response.Body)
	for scanner.Scan() {
		// Stop promptly on cancellation instead of parsing the rest of a large list.
		if err := ctx.Err(); err != nil {
			return Blocklist{}, err
		}
		for _, domain := range blocklistLineDomains(scanner.Text()) {
			blocklist.Domains[domain] = ""
		}
	}
	// Scan also stops on read errors, including a cancelled request body or an
	// over-long line. Without this check a truncated list would look complete.
	if err := scanner.Err(); err != nil {
		return Blocklist{}, fmt.Errorf("reading block list: %w", err)
	}

	return blocklist, nil
}

// blocklistLineDomains returns the domains to block from one line of a
// hosts-style block list, using the format described on CreateBlockList.
func blocklistLineDomains(line string) []string {
	// Drop trailing comments; a whole-line comment becomes empty here.
	if i := strings.IndexByte(line, '#'); i >= 0 {
		line = line[:i]
	}
	// Fields splits on any run of whitespace, so tabs and repeated spaces
	// never produce empty entries.
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return nil
	}

	// A lone field is a bare domain. Otherwise the first field is the address
	// the hostnames map to, and it is not itself blocked.
	hosts := fields
	if len(fields) > 1 {
		hosts = fields[1:]
	}

	var domains []string
	for _, host := range hosts {
		// Entries such as "0.0.0.0 0.0.0.0" map an address to itself; skip them.
		if len(fields) > 1 && net.ParseIP(host) != nil {
			continue
		}
		if strings.Contains(host, "localhost") || strings.Contains(host, "loopback") || host == "broadcasthost" {
			continue
		}
		domains = append(domains, host)
	}
	return domains
}
|