added log and aggregate http endpoints

pull/1/head
Adam Veldhousen 3 years ago
parent ee6b8def4a
commit 492d522d24
Signed by: adam
GPG Key ID: 6DB29003C6DD1E4B

3
.gitignore vendored

@ -1 +1,2 @@
.bin
.bin
.idea

@ -3,6 +3,9 @@ module github.com/adamveld12/gopherhole
go 1.15
require (
github.com/go-chi/chi v1.5.4
github.com/go-chi/chi/v5 v5.0.2
github.com/go-chi/cors v1.2.0
github.com/mattn/go-sqlite3 v1.14.7
github.com/miekg/dns v1.1.41
)

@ -1,3 +1,9 @@
github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs=
github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg=
github.com/go-chi/chi/v5 v5.0.2 h1:4xKeALZdMEsuI5s05PU2Bm89Uc5iM04qFubUCl5LfAQ=
github.com/go-chi/chi/v5 v5.0.2/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/cors v1.2.0 h1:tV1g1XENQ8ku4Bq3K9ub2AtgG+p16SmzeMSGTwrOKdE=
github.com/go-chi/cors v1.2.0/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=

@ -0,0 +1,8 @@
SELECT
domain,
ROUND(AVG(totalTimeMs), 3) as averageTotalTime,
COUNT(*) as requests,
strftime('%s', started)/(5*60) as "timeWindow"
FROM log
GROUP BY domain, strftime('%s', started) / (5 * 60)
ORDER BY started desc;

@ -8,6 +8,8 @@ import (
)
type Cache interface {
Purge(name string)
PurgeAll()
LookupRecord(name string) []dns.RR
SaveAnswers(name string, answers []dns.RR)
}
@ -20,9 +22,10 @@ type Memory struct {
}
}
func (m *Memory) init() {
if m.cache == nil {
func (m *Memory) init(clear bool) {
if m.cache == nil || clear {
m.Lock()
m.cache = nil
m.cache = make(map[string]struct {
Answers []dns.RR
Expiration time.Time
@ -31,8 +34,16 @@ func (m *Memory) init() {
}
}
func (m *Memory) PurgeAll() { m.init(true) }
func (m *Memory) Purge(name string) {
m.Lock()
defer m.Unlock()
delete(m.cache, name)
}
func (m *Memory) LookupRecord(name string) []dns.RR {
m.init()
m.init(false)
if v, ok := m.cache[name]; ok && time.Until(v.Expiration) > 0 {
return v.Answers
@ -45,7 +56,7 @@ func (m *Memory) SaveAnswers(name string, answers []dns.RR) {
if answers == nil || name == "" {
return
}
m.init()
m.init(false)
m.Lock()
defer m.Unlock()

@ -2,6 +2,7 @@ package internal
import (
"log"
"strings"
"time"
"github.com/miekg/dns"
@ -18,48 +19,36 @@ func (dm *DomainManager) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// only grab first question: https://stackoverflow.com/questions/4082081/requesting-a-and-aaaa-records-in-single-dns-query/4083071#4083071
start := time.Now()
q := r.Question[0]
responseMessage := new(dns.Msg)
ql := QueryLog{
Started: start.UTC(),
Protocol: w.RemoteAddr().Network(),
ClientIP: w.RemoteAddr().String(),
ClientIP: strings.Split(w.RemoteAddr().String(), ":")[0],
Domain: q.Name,
Status: NoAnswer,
}
var responseMessage *dns.Msg
var err error
var resolved Resolved
// lookup in cache
if dest := dm.LookupRecord(q.Name); dest != nil {
responseMessage = new(dns.Msg)
responseMessage.Answer = dest
ql.Status = CacheHit
}
// lookup in rules engine
if rule, ok := dm.Evaluate(q.Name); ok {
} else if rule, ok := dm.Evaluate(q.Name); ok {
// evaluate domain in rules engine
responseMessage = rule.CreateAnswer(q.Name)
responseMessage.Authoritative = true
ql.Status = CustomRule
}
// recurse
if responseMessage == nil {
var resolved Resolved
if resolved, err = dm.Recursors.Resolve(r); err == nil {
dm.SaveAnswers(q.Name, resolved.Message.Answer)
ql.RecurseUpstreamIP = resolved.UpstreamUsed
ql.RecurseRoundTripTimeMs = int(resolved.RoundtripTime.Milliseconds())
ql.Status = RecursedUpstream
responseMessage = resolved.Message
} else {
ql.Status = NoAnswer
responseMessage = new(dns.Msg)
}
} else if resolved, ql.Error = dm.Recursors.Resolve(r); ql.Error == nil {
dm.SaveAnswers(q.Name, resolved.Message.Answer)
responseMessage = resolved.Message
ql.RecurseUpstreamIP = resolved.UpstreamUsed
ql.RecurseRoundTripTimeMs = int(resolved.RoundtripTime.Milliseconds())
ql.Status = RecursedUpstream
}
responseMessage.SetReply(r)
responseMessage.RecursionAvailable = true
responseMessage.Compress = true
ql.TotalTimeMs = int(time.Since(start).Milliseconds())

@ -0,0 +1,236 @@
package internal
import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
)
type adminHandler struct {
Cache
Storage
*RuleEngine
h http.Handler
}
func NewAdminHandler(c Cache, s Storage, re *RuleEngine) http.Handler {
handler := chi.NewRouter()
a := &adminHandler{
Cache: c,
Storage: s,
RuleEngine: re,
h: handler,
}
handler.Use(middleware.RequestID)
handler.Use(middleware.RealIP)
handler.Use(middleware.Logger)
handler.Use(middleware.AllowContentType("application/json; utf-8"))
handler.Use(middleware.Timeout(time.Second * 10))
handler.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"http://*", "https://*"},
AllowedMethods: []string{"GET", "PUT", "DELETE", "POST", "OPTIONS"},
AllowedHeaders: []string{"Content-Type", "Accept"},
AllowCredentials: false,
MaxAge: 300,
}))
handler.Use(middleware.Recoverer)
handler.Route("/api/v1", func(r chi.Router) {
r.Get("/metrics/log", RestHandler(a.getLog).ToHF())
r.Get("/metrics/stats", RestHandler(a.getStats).ToHF())
// r.Delete("/cache/purgeall", RestHandler(a.purgeAll).ToHF())
// r.Delete("/cache/purge", a.purgeKey)
// r.Get("/cache", a.getCacheContents)
// r.Put("/rules", a.createRule)
// r.Get("/rules", a.getRules)
// r.Delete("/rules/{id}", a.deleteRole)
// r.Put("/rules/lists", a.addRulelist)
// r.Get("/rules/lists", a.getRuleLists)
// r.Delete("/rules/lists/{id}", a.deleteRuleList)
// r.Post("/rules/lists/reload", a.reloadRuleLists)
})
return a
}
func (a *adminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
a.h.ServeHTTP(w, r)
}
func (a *adminHandler) getStats(r *http.Request) (*RestResponse, error) {
q := r.URL.Query()
startFilter := q.Get("start")
endFilter := q.Get("end")
key := q.Get("key")
intervalSecondsStr := q.Get("interval")
var err error
start := time.Now().UTC().Add(time.Hour * -86400)
end := time.Now().UTC()
if startFilter != "" {
if start, err = time.Parse(ISO8601, startFilter); err != nil {
return nil, err
}
}
if endFilter != "" {
if end, err = time.Parse(ISO8601, endFilter); err != nil {
return nil, err
}
}
lai := LogAggregateInput{
Start: start,
End: end,
Column: key,
}
if intervalSecondsStr != "" {
if lai.IntervalSeconds, err = strconv.Atoi(intervalSecondsStr); err != nil {
return nil, errors.New("interval query param must be a valid whole number")
}
}
la, err := a.Storage.GetLogAggregate(lai)
if err != nil {
return nil, err
}
return &RestResponse{
Status: http.StatusOK,
Payload: struct {
Success bool `json:"success"`
Payload interface{} `json:"payload"`
}{
Success: true,
Payload: la,
},
}, nil
}
func (a *adminHandler) getLog(r *http.Request) (*RestResponse, error) {
q := r.URL.Query()
startFilter := q.Get("start")
endFilter := q.Get("end")
// filter := q.Get("filter")
pageStr := q.Get("page")
var err error
var page int
start := time.Now().UTC().Add(time.Hour * -86400)
end := time.Now().UTC()
if startFilter != "" {
if start, err = time.Parse(ISO8601, startFilter); err != nil {
return nil, err
}
}
if endFilter != "" {
if end, err = time.Parse(ISO8601, endFilter); err != nil {
return nil, err
}
}
if pageStr != "" {
page, _ = strconv.Atoi(pageStr)
}
ql, err := a.Storage.GetLog(GetLogInput{
Start: start,
End: end,
Limit: 250,
Page: page,
})
if err != nil {
return nil, err
}
return &RestResponse{
Status: http.StatusOK,
Payload: struct {
Success bool `json:"success"`
Payload interface{} `json:"payload"`
}{
Success: true,
Payload: ql,
},
}, nil
}
type RestResponse struct {
Status int
Headers http.Header
Payload struct {
Success bool `json:"success"`
Payload interface{} `json:"payload"`
}
}
func (rr *RestResponse) Write(w http.ResponseWriter) error {
if rr.Status != 0 && rr.Status != 200 {
w.WriteHeader(rr.Status)
}
for k, v := range rr.Headers {
for _, ve := range v {
w.Header().Add(k, ve)
}
}
e := json.NewEncoder(w)
e.SetIndent("\n", "\t")
if err := e.Encode(rr.Payload); err != nil {
return fmt.Errorf("could not serialize struct for http response: %w", err)
}
return nil
}
type RestHandler func(request *http.Request) (*RestResponse, error)
func (rh RestHandler) ToHF() http.HandlerFunc {
return func(rw http.ResponseWriter, r *http.Request) { rh.ServeHTTP(rw, r) }
}
func (rh RestHandler) Error(e error) *RestResponse {
return &RestResponse{
Status: http.StatusInternalServerError,
Payload: struct {
Success bool `json:"success"`
Payload interface{} `json:"payload"`
}{
Success: false,
Payload: e.Error(),
},
}
}
func (r RestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
response, err := r(req)
if err != nil {
response = r.Error(err)
}
if err := response.Write(w); err != nil {
log.Printf("Error occurred handling rest response: %v", err)
}
}

@ -19,6 +19,7 @@ type Storage interface {
// AddRule(Rule) error
Log(QueryLog) error
GetLog(GetLogInput) ([]QueryLog, error)
GetLogAggregate(LogAggregateInput) ([]LogAggregateDataPoint, error)
}
type Sqlite struct {
@ -30,8 +31,8 @@ type GetLogInput struct {
Start time.Time
End time.Time
DomainFilter string
Limit uint
Page uint
Limit int
Page int
}
func (ss *Sqlite) GetLog(in GetLogInput) ([]QueryLog, error) {
@ -40,7 +41,7 @@ func (ss *Sqlite) GetLog(in GetLogInput) ([]QueryLog, error) {
}
if in.Start.IsZero() {
in.Start = time.Now().UTC().Add(time.Hour * -24)
in.Start = time.Now().UTC().Add(time.Hour * -86400)
}
if in.End.IsZero() {
@ -54,7 +55,7 @@ func (ss *Sqlite) GetLog(in GetLogInput) ([]QueryLog, error) {
log
WHERE
id > ? AND started > ? AND started < ?
ORDER BY started
ORDER BY started DESC
LIMIT ?;
`
@ -97,6 +98,89 @@ func (ss *Sqlite) GetLog(in GetLogInput) ([]QueryLog, error) {
return ql, nil
}
type LogAggregateColumn string
var (
Domain = LogAggregateColumn("domain")
ClientIP = LogAggregateColumn("clientIp")
RecurseIP = LogAggregateColumn("recurseUpStreamIP")
Protocol = LogAggregateColumn("protocol")
Status = LogAggregateColumn("status")
AggregateKeys = map[string]LogAggregateColumn{
"domain": Domain,
"clientIp": ClientIP,
"recurseUpStreamIP": RecurseIP,
"protocol": Protocol,
"status": Status,
}
)
type LogAggregateInput struct {
IntervalSeconds int
Start time.Time
End time.Time
Column string
}
type LogAggregateDataPoint struct {
Header string
AverageTotalTime float64
Count int
Time time.Time
}
func (ss *Sqlite) GetLogAggregate(la LogAggregateInput) ([]LogAggregateDataPoint, error) {
timeWindow := int64(5 * 60)
column := "domain"
if lac, ok := AggregateKeys[la.Column]; ok {
column = string(lac)
}
if la.IntervalSeconds > 0 {
timeWindow = int64(la.IntervalSeconds)
}
sql := `
SELECT
%s,
ROUND(AVG(totalTimeMs), 3) as averageTotalTime,
COUNT(*) as requests,
strftime('%%s', started)/(%d) as "timeWindow"
FROM log
GROUP BY %s, strftime('%%s', started) / (%d)
ORDER BY started desc;
`
sql = fmt.Sprintf(sql, column, timeWindow, column, timeWindow)
result, err := ss.Query(sql)
if err != nil {
return nil, err
}
var results []LogAggregateDataPoint
for result.Next() {
var ladp LogAggregateDataPoint
var timeInterval int64
if err := result.Scan(
&ladp.Header,
&ladp.AverageTotalTime,
&ladp.Count,
&timeInterval,
); err != nil {
return nil, err
}
ladp.Time = time.Unix(timeInterval*timeWindow, 0)
results = append(results, ladp)
}
return results, nil
}
func (ss *Sqlite) Log(ql QueryLog) error {
sql := `
INSERT INTO log

@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"log"
"net/http"
"regexp"
"strings"
"time"
@ -37,6 +38,8 @@ func main() {
Rules: conf.Rules,
}
cache := &internal.Memory{}
dnsClient := &dns.Client{
Net: "udp",
DialTimeout: time.Millisecond * 250,
@ -45,7 +48,7 @@ func main() {
}
dm := &internal.DomainManager{
Cache: &internal.Memory{},
Cache: cache,
Storage: store,
RuleEvaluator: re,
Recursors: internal.Recursor{
@ -54,9 +57,15 @@ func main() {
},
}
dnsSrv := &internal.DNSServer{Handler: dm}
go func() {
dnsSrv := &internal.DNSServer{Handler: dm}
if err := dnsSrv.ListenAndServe(context.Background(), conf.DNSAddr); err != nil {
log.Fatal(err)
}
}()
if err := dnsSrv.ListenAndServe(context.Background(), conf.DNSAddr); err != nil {
httpApi := internal.NewAdminHandler(cache, store, re)
if err := http.ListenAndServe(conf.HTTPAddr, httpApi); err != nil {
log.Fatal(err)
}
}

@ -2,9 +2,18 @@ dev: clean .bin/gopherhole .bin/config.json
cd .bin && ./gopherhole -config config.json
clean:
@rm -rf .bin/gopherhole
clobber:
@rm -rf .bin
.PHONY: clean dev
test:
dig -p 5353 twitter.com @localhost
dig -p 5353 google.com @localhost
dig -p 5353 loki.veldhousen.ninja @localhost
dig -p 5353 www.liveauctioneers.com @localhost
.PHONY: clean clobber dev
.bin:
mkdir -p .bin

Loading…
Cancel
Save