package internal import ( "database/sql" "fmt" "io" "net" "strconv" "strings" "time" _ "github.com/mattn/go-sqlite3" ) const ISO8601 = "2006-01-02 15:04:05.999" type Storage interface { io.Closer Open() error AddRecursors(net.IP, int, int, int) error GetRecursors() ([]RecursorRow, error) UpdateRecursor(int, RecursorRow) error DeleteRecursors(int) error AddRule(RuleRow) error GetRule(int) (RuleRow, error) GetRules() ([]RuleRow, error) UpdateRule(int, RuleRow) error DeleteRule(int) error Log(QueryLog) error GetLog(GetLogInput) (GetLogResult, error) GetLogAggregate(LogAggregateInput) ([]LogAggregateDataPoint, error) } type Sqlite struct { Path string *sql.DB } func (ss *Sqlite) GetRecursors() ([]RecursorRow, error) { sql := ` SELECT id, ipAddress, timeoutMs, weight FROM recursors ORDER BY weight ASC; ` rows, err := ss.Query(sql) if err != nil { return nil, fmt.Errorf("could not execute select for recursors: %w", err) } defer rows.Close() results := []RecursorRow{} for rows.Next() { var row RecursorRow if err := rows.Scan(&row.ID, &row.IpAddress, &row.TimeoutMs, &row.Weight); err != nil { return nil, fmt.Errorf("could not read row: %w", err) } results = append(results, row) } return results, nil } func (ss *Sqlite) DeleteRecursors(id int) error { sql := `DELETE FROM recursors WHERE id = ?;` if _, err := ss.Exec(sql, id); err != nil { return fmt.Errorf("Could not delete recursor: %w", err) } return nil } type RecursorRow struct { ID int `json:"id"` IpAddress string `json:"ipAddress"` TimeoutMs int `json:"timeoutMs"` Weight int `json:"weight"` } func (rr RecursorRow) ValidIp() (net.IP, int, bool) { ipAddrFrags := strings.Split(rr.IpAddress, ":") if len(ipAddrFrags) == 0 || len(ipAddrFrags) > 2 { return nil, -1, false } var parsedIp net.IP parsedPort := 53 if parsedIp = net.ParseIP(ipAddrFrags[0]); parsedIp == nil { return nil, -1, false } if len(ipAddrFrags) > 1 { parsedPort, _ = strconv.Atoi(ipAddrFrags[1]) } return parsedIp, parsedPort, true } func (ss *Sqlite) UpdateRecursor(id int, in RecursorRow) error { sql := `UPDATE recursors SET ipAddress = ?, timeoutMs = ?, weight = ? WHERE id = ?;` if _, err := ss.Exec(sql, in.IpAddress, in.TimeoutMs, in.Weight); err != nil { return fmt.Errorf("could not update recursor: %w", err) } return nil } func (ss *Sqlite) AddRecursors(ip net.IP, port, timeout, weight int) error { sql := `INSERT INTO recursors (ipAddress, timeoutMs, weight) VALUES (?, ?, ?);` if _, err := ss.Exec(sql, fmt.Sprintf("%s:%d", ip.String(), port), timeout, weight); err != nil { return fmt.Errorf("could not insert recursor: %w", err) } return nil } type GetLogInput struct { Start time.Time `json:"start"` End time.Time `json:"end"` DomainFilter string `json:"rawfilter"` Limit int `json:"pageSize"` Page int `json:"page"` } type RuleRow struct { ID int `json:"id"` Weight int `json:"weight"` Enabled bool `json:"enabled"` Created time.Time `json:"created"` Rule } func (ss *Sqlite) UpdateRule(id int, in RuleRow) error { sql := `UPDATE rules SET name = ?, expression = ?, answerType = ?, answerValue = ?, ttl = ?, weight = ?, enabled = ? WHERE id = ?;` if _, err := ss.Exec(sql, in.Name, in.Value, in.Answer.Type, in.Answer.Value, in.TTL, in.Weight, in.Enabled); err != nil { return fmt.Errorf("could not update rule with id: %v", id) } return nil } func (ss *Sqlite) AddRule(rr RuleRow) error { sql := `INSERT INTO rules (name, expression, answerType, answerValue, ttl, weight, enabled, created) VALUES (?, ? , ?, ?, ?, ?, 1, ?);` if _, err := ss.Exec(sql, rr.Name, rr.Value, rr.Answer.Type, rr.Answer.Value, rr.TTL, rr.Weight, time.Now().UTC().Format(ISO8601)); err != nil { return fmt.Errorf("could not delete rule: %w", err) } return nil } func (ss *Sqlite) GetRule(ruleId int) (RuleRow, error) { sql := `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules WHERE id = ?;` var rr RuleRow row := ss.QueryRow(sql, ruleId) var createdTime string if err := row.Scan(&rr.ID, &rr.Name, &rr.Value, &rr.Answer.Type, &rr.Answer.Value, &rr.TTL, &rr.Weight, &rr.Enabled, &createdTime); err != nil { return rr, err } var err error rr.Created, err = time.Parse(ISO8601, createdTime) if err != nil { return rr, fmt.Errorf("could not parse time: %w", err) } return rr, nil } func (ss *Sqlite) DeleteRule(ruleId int) error { if _, err := ss.Exec(`DELETE FROM rules WHERE id = ?;`, ruleId); err != nil { return fmt.Errorf("could not delete rule: %w", err) } return nil } func (ss *Sqlite) GetRules() ([]RuleRow, error) { sql := `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules ORDER BY weight ASC;` rows, err := ss.Query(sql) if err != nil { return nil, err } defer rows.Close() if rerr := rows.Err(); rerr != nil { return nil, err } results := []RuleRow{} for rows.Next() { var rule RuleRow var createdTime string if err := rows.Scan( &rule.ID, &rule.Name, &rule.Value, &rule.Answer.Type, &rule.Answer.Value, &rule.TTL, &rule.Weight, &rule.Enabled, &createdTime, ); err != nil { return nil, fmt.Errorf("could not read from db: %w", err) } rule.Created, err = time.Parse(ISO8601, createdTime) if err != nil { return nil, err } results = append(results, rule) } return results, nil } type GetLogResult struct { GetLogInput TotalResults int `json:"total"` PageCount int `json:"pageCount"` Logs []QueryLog `json:"logs"` } func (ss *Sqlite) GetLog(in GetLogInput) (GetLogResult, error) { if in.Limit <= 0 { in.Limit = 100 } if in.Start.IsZero() { in.Start = time.Now().Add(time.Hour * -86400) } if in.End.IsZero() { in.End = time.Now() } glr := GetLogResult{ GetLogInput: in, Logs: []QueryLog{}, } sql := ` SELECT started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status FROM log WHERE id > ? AND strftime('%s', started) > strftime('%s', ?) AND strftime('%s', started) < strftime('%s', ?) ORDER BY started DESC LIMIT ?; ` rows, err := ss.DB.Query(sql, in.Page*in.Limit, in.Start.UTC().Format(ISO8601), in.End.UTC().Format(ISO8601), in.Limit) if err != nil { return glr, fmt.Errorf("issue with GetLog sql query: %w", err) } defer rows.Close() if rerr := rows.Err(); rerr != nil { return glr, fmt.Errorf("issue with rows object: %w", rerr) } for rows.Next() { var q QueryLog var started string if err := rows.Scan( &started, &q.ClientIP, &q.Protocol, &q.Domain, &q.TotalTimeMs, &q.Error, &q.RecurseRoundTripTimeMs, &q.RecurseUpstreamIP, &q.Status, ); err != nil { return glr, fmt.Errorf("issues scanning rows: %w", err) } if q.Started, err = time.Parse(ISO8601, started); err != nil { return glr, fmt.Errorf("could not parse time '%s': %w", started, err) } glr.Logs = append(glr.Logs, q) } total, pageCount, err := ss.GetPagingInfo(in) if err != nil { return glr, err } glr.TotalResults = total glr.PageCount = pageCount return glr, nil } func (ss *Sqlite) GetPagingInfo(in GetLogInput) (totalItems, pageCount int, err error) { sql := ` SELECT COUNT(*) as totalLogsEntries, COUNT(*) / ? as pageCount FROM log WHERE strftime('%s', started) > strftime('%s', ?) AND strftime('%s', started) < strftime('%s', ?) ` row := ss.QueryRow(sql, in.Limit, in.Start.UTC().Format(ISO8601), in.End.UTC().Format(ISO8601)) if err = row.Scan(&totalItems, &pageCount); err != nil { return } return } type LogAggregateColumn string var ( Domain = LogAggregateColumn("domain") ClientIP = LogAggregateColumn("clientIp") RecurseIP = LogAggregateColumn("recurseUpStreamIP") Protocol = LogAggregateColumn("protocol") Status = LogAggregateColumn("status") AggregateKeys = map[string]LogAggregateColumn{ "domain": Domain, "clientIp": ClientIP, "recurseUpStreamIP": RecurseIP, "protocol": Protocol, "status": Status, } ) type LogAggregateInput struct { IntervalSeconds int Start time.Time End time.Time Column string } type LogAggregateDataPoint struct { Header string AverageTotalTime float64 Count int Time time.Time } func (ss *Sqlite) GetLogAggregate(la LogAggregateInput) ([]LogAggregateDataPoint, error) { timeWindow := int64(5 * 60) column := "domain" if lac, ok := AggregateKeys[la.Column]; ok { column = string(lac) } if la.IntervalSeconds > 0 { timeWindow = int64(la.IntervalSeconds) } sql := ` SELECT %s, ROUND(AVG(totalTimeMs), 3) as averageTotalTime, COUNT(*) as requests, strftime('%%s', started)/(%d) as "timeWindow" FROM log GROUP BY %s, strftime('%%s', started) / (%d) ORDER BY started ASC; ` sql = fmt.Sprintf(sql, column, timeWindow, column, timeWindow) rows, err := ss.Query(sql) if err != nil { return nil, err } defer rows.Close() if err := rows.Err(); err != nil { return nil, err } var results []LogAggregateDataPoint for rows.Next() { var ladp LogAggregateDataPoint var timeInterval int64 if err := rows.Scan( &ladp.Header, &ladp.AverageTotalTime, &ladp.Count, &timeInterval, ); err != nil { return nil, err } ladp.Time = time.Unix(timeInterval*timeWindow, 0) results = append(results, ladp) } return results, nil } func (ss *Sqlite) Log(ql QueryLog) error { sql := ` INSERT INTO log (started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); ` if _, err := ss.DB.Exec(sql, ql.Started.UTC().Format(ISO8601), ql.ClientIP, ql.Protocol, ql.Domain, ql.TotalTimeMs, ql.Error, ql.RecurseRoundTripTimeMs, ql.RecurseUpstreamIP, ql.Status, ); err != nil { return err } return nil } func (ss *Sqlite) Open() error { db, err := sql.Open("sqlite3", fmt.Sprintf("%s?cache=shared&_journal=WAL", ss.Path)) if err != nil { return fmt.Errorf("could not open db: %w", err) } db.SetMaxOpenConns(1) ss.DB = db if err := initTable(db); err != nil { return err } return nil } func initTable(db *sql.DB) error { sql := ` CREATE TABLE IF NOT EXISTS log ( id INTEGER PRIMARY KEY, started TEXT NOT NULL, clientIp TEXT NOT NULL, protocol TEXT NOT NULL, domain TEXT NOT NULL, totalTimeMs int NOT NULL, error TEXT, recurseRoundTripTimeMs INT, recurseUpStreamIP TEXT, status TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS rules ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, expression TEXT NOT NULL, answerType TEXT NOT NULL, answerValue TEXT NOT NULL, ttl INT NOT NULL, weight INT NOT NULL, enabled INT NOT NULL, created TEXT NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_name ON rules (name); CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_expression ON rules (expression); CREATE TABLE IF NOT EXISTS recursors ( id INTEGER PRIMARY KEY, ipAddress TEXT NOT NULL, timeoutMs INT NOT NULL, weight INT NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS idx_recursors_ipAddress ON recursors (ipAddress); CREATE TABLE IF NOT EXISTS ruleslist ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, url TEXT NOT NULL, lastLoadedTs TEXT NOT NULL, weight INT NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_name ON ruleslist (name); CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_url ON ruleslist (url); CREATE TABLE IF NOT EXISTS ruleslist_entry ( id INTEGER PRIMARY KEY, ruleslistId INTEGER NOT NULL, expression TEXT NOT NULL, ipAddress TEXT NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_entry_expression ON ruleslist_entry (expression); CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_entry_ruleslistId ON ruleslist_entry (id, ruleslistId); ` if _, err := db.Exec(sql); err != nil { return fmt.Errorf("could not initialize db: %w", err) } return nil }