package main import ( "bufio" "context" "fmt" "log" "net" "net/http" "os" "strings" "sync" ) //BlocklistManager manages block lists and determines if a domain should be permitted to be resolved normally type BlocklistManager interface { IsBlacklisted(string) (string, bool) Reload(context.Context) error } //NewDomainBlacklist creates a new BlocklistManager func NewDomainBlacklist(blocklistURLs []string) (BlocklistManager, error) { l := log.New(os.Stdout, "[Blocklist Manager] ", log.LUTC|log.Lshortfile) bm := &memoryBlocklistManager{blocklistURLs: blocklistURLs, Logger: l} go bm.Reload(context.Background()) // if err := bm.Reload(context.Background()); err != nil { // return nil, err // } return bm, nil } type memoryBlocklistManager struct { blocklistURLs []string blocklists []Blocklist *log.Logger } func (mdb *memoryBlocklistManager) Reload(ctx context.Context) error { var wg sync.WaitGroup bc := make(chan Blocklist, len(mdb.blocklistURLs)) for _, bl := range mdb.blocklistURLs { wg.Add(1) go func(blocklistURL string) { defer wg.Done() mdb.Printf("loading block list: '%s'", blocklistURL) blocklist, err := CreateBlockList(ctx, blocklistURL) if err != nil { mdb.Printf("failed to load block list '%s': %v", blocklistURL, err) return } bc <- blocklist }(bl) } wg.Wait() close(bc) mdb.blocklists = []Blocklist{} count := 0 for bl := range bc { mdb.blocklists = append(mdb.blocklists, bl) count += len(bl.Domains) } mdb.Printf("successfully loaded %d block lists totaling %d domains", len(mdb.blocklists), count) return nil } func (mdb *memoryBlocklistManager) IsBlacklisted(domain string) (string, bool) { for _, blocklist := range mdb.blocklists { if _, ok := blocklist.Domains[domain]; ok { return blocklist.Source, true } } return "", false } // Blocklist is a list of domains that should be blocked type Blocklist struct { Source string Domains map[string]string } // CreateBlockList creates a block list from the source URL. func CreateBlockList(ctx context.Context, sourceURL string) (Blocklist, error) { blocklist := Blocklist{ Source: sourceURL, Domains: map[string]string{}, } req, _ := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil) response, err := http.DefaultClient.Do(req) if err != nil { return blocklist, err } defer response.Body.Close() if response.StatusCode != http.StatusOK { return blocklist, fmt.Errorf("non 200 status: %d", response.StatusCode) } scanner := bufio.NewScanner(response.Body) for scanner.Scan() { select { case <-ctx.Done(): return blocklist, context.Canceled default: line := scanner.Text() if line == "" || line[0] == '#' { continue } domain := line if frags := strings.Split(line, " "); len(frags) > 1 { if net.ParseIP(frags[1]) != nil { continue } domain = frags[1] } if strings.Contains(domain, "localhost") || strings.Contains(domain, "loopback") || domain == "broadcasthost" { continue } blocklist.Domains[domain] = "" } } return blocklist, nil }