You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

431 lines
9.4 KiB

package internal
import (
"database/sql"
"fmt"
"io"
"net"
"time"
_ "github.com/mattn/go-sqlite3"
)
// ISO8601 is the layout used for the timestamps this package writes to and
// reads back from SQLite text columns: date and time separated by a space,
// with an optional fractional part of up to three digits (".999" drops
// trailing zeros when formatting).
const (
	ISO8601 = "2006-01-02 15:04:05.999"
)
// Storage abstracts persistence of upstream recursors, DNS rules and the query
// log. Sqlite provides an implementation.
type Storage interface {
	// Lifecycle: Open prepares the backend, Close releases it.
	Open() error
	io.Closer

	// Upstream recursors.
	AddRecursors(net.IP, int, int, int) error
	GetRecursors() ([]RecursorRow, error)
	DeleteRecursors(int) error

	// Rules. An UpdateRule(RuleRow) error method is not part of the
	// interface yet.
	AddRule(RuleRow) error
	GetRule(int) (RuleRow, error)
	GetRules() ([]RuleRow, error)
	DeleteRule(int) error

	// Query log.
	Log(QueryLog) error
	GetLog(GetLogInput) ([]QueryLog, error)
	GetLogAggregate(LogAggregateInput) ([]LogAggregateDataPoint, error)
}
// Sqlite is the SQLite-backed implementation of Storage. The embedded *sql.DB
// is assigned by Open and supplies Query, QueryRow, Exec and Close.
type Sqlite struct {
// Path is the database location handed to the sqlite3 driver; Open appends
// "?cache=shared" to it.
Path string
*sql.DB
}
// GetRecursors returns every configured upstream recursor ordered by ascending
// weight. An empty table yields a nil slice and a nil error.
func (ss *Sqlite) GetRecursors() ([]RecursorRow, error) {
	query := `
SELECT id, ipAddress, timeoutMs, weight FROM recursors ORDER BY weight ASC;
`
	rows, err := ss.Query(query)
	if err != nil {
		return nil, fmt.Errorf("could not execute select for recursors: %w", err)
	}
	defer rows.Close()

	var results []RecursorRow
	for rows.Next() {
		var row RecursorRow
		if err := rows.Scan(&row.ID, &row.IpAddress, &row.TimeoutMs, &row.Weight); err != nil {
			return nil, fmt.Errorf("could not read row: %w", err)
		}
		results = append(results, row)
	}
	// rows.Next also returns false when iteration fails; report that instead
	// of handing back a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("could not iterate recursors: %w", err)
	}
	return results, nil
}
// DeleteRecursors removes the recursor with the given id. Deleting an id that
// does not exist is not reported as an error.
func (ss *Sqlite) DeleteRecursors(id int) error {
	query := `DELETE FROM recursors WHERE id = ?;`
	if _, err := ss.Exec(query, id); err != nil {
		// Lower-case error strings compose cleanly when wrapped by callers.
		return fmt.Errorf("could not delete recursor: %w", err)
	}
	return nil
}
// RecursorRow is one row of the recursors table, as returned by GetRecursors.
type RecursorRow struct {
ID int `json:"id"`
// IpAddress is the host:port string AddRecursors builds from its ip and port
// arguments, not a bare IP address.
IpAddress string `json:"ipAddress"`
// TimeoutMs is stored in the timeoutMs column; presumably milliseconds —
// confirm with the code that consumes it.
TimeoutMs int `json:"timeoutMs"`
// Weight orders recursors; GetRecursors sorts ascending by it.
Weight int `json:"weight"`
}
// AddRecursors stores a new upstream recursor.
//
// ip and port are joined into one host:port string with net.JoinHostPort, so
// IPv6 addresses are bracketed ("[::1]:53") and remain unambiguous to split.
// That string is saved in the ipAddress column together with timeout (the
// timeoutMs column) and weight. A nil or malformed ip is rejected instead of
// being stored as "<nil>:port".
func (ss *Sqlite) AddRecursors(ip net.IP, port, timeout, weight int) error {
	if ip.To16() == nil {
		return fmt.Errorf("could not insert recursor: invalid ip address %q", ip.String())
	}
	address := net.JoinHostPort(ip.String(), fmt.Sprint(port))

	query := `INSERT INTO recursors (ipAddress, timeoutMs, weight) VALUES (?, ?, ?);`
	if _, err := ss.Exec(query, address, timeout, weight); err != nil {
		return fmt.Errorf("could not insert recursor: %w", err)
	}
	return nil
}
// GetLogInput selects which query-log entries GetLog returns.
type GetLogInput struct {
// Start and End bound the started timestamp of returned entries (both
// exclusive); GetLog substitutes defaults when either is the zero time.
Start time.Time
End time.Time
// DomainFilter is carried on the input but is not read by GetLog.
// NOTE(review): confirm whether domain filtering is meant to be applied.
DomainFilter string
// Limit is the page size; GetLog uses 100 when it is <= 0.
Limit int
// Page selects which page of Limit entries to return; 0 is the first page.
Page int
}
// RuleRow is a stored rule: the embedded Rule plus the bookkeeping columns of
// the rules table. The embedded Rule's Name, Value, Answer.Type, Answer.Value
// and TTL map to the name, expression, answerType, answerValue and ttl columns.
type RuleRow struct {
ID int `json:"id"`
// Weight orders rules; GetRules sorts ascending by it.
Weight int `json:"weight"`
// Enabled mirrors the enabled column. AddRule ignores it and always stores 1.
Enabled bool `json:"enabled"`
// Created is parsed from the created column with the ISO8601 layout. AddRule
// ignores it and stores the current UTC time instead.
Created time.Time `json:"created"`
Rule
}
// AddRule inserts rr as a new rule.
//
// The following fields of rr are ignored:
//   - rr.ID: the INTEGER PRIMARY KEY column assigns the id.
//   - rr.Enabled: the rule is always stored enabled (1).
//   - rr.Created: the creation time is the current UTC time.
//
// The rules table has unique indexes on name and expression, so inserting a
// duplicate of either fails.
func (ss *Sqlite) AddRule(rr RuleRow) error {
	query := `INSERT INTO rules (name, expression, answerType, answerValue, ttl, weight, enabled, created)
VALUES (?, ?, ?, ?, ?, ?, 1, ?);`
	created := time.Now().UTC().Format(ISO8601)
	if _, err := ss.Exec(query, rr.Name, rr.Value, rr.Answer.Type, rr.Answer.Value, rr.TTL, rr.Weight, created); err != nil {
		return fmt.Errorf("could not insert rule: %w", err)
	}
	return nil
}
// GetRule loads the rule with the given id.
//
// On any error the returned RuleRow is the zero value. The scan error is
// returned unwrapped, so a caller can still compare it directly against
// sql.ErrNoRows when no rule has that id.
func (ss *Sqlite) GetRule(ruleId int) (RuleRow, error) {
	const query = `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules WHERE id = ?;`

	var rr RuleRow
	var createdTime string
	err := ss.QueryRow(query, ruleId).Scan(
		&rr.ID,
		&rr.Name,
		&rr.Value,
		&rr.Answer.Type,
		&rr.Answer.Value,
		&rr.TTL,
		&rr.Weight,
		&rr.Enabled,
		&createdTime,
	)
	if err != nil {
		return RuleRow{}, err
	}

	created, err := time.Parse(ISO8601, createdTime)
	if err != nil {
		return RuleRow{}, fmt.Errorf("could not parse time: %w", err)
	}
	rr.Created = created
	return rr, nil
}
// DeleteRule removes the rule whose id is ruleId. An id with no matching row
// is not treated as an error.
func (ss *Sqlite) DeleteRule(ruleId int) error {
	const stmt = `DELETE FROM rules WHERE id = ?;`
	_, err := ss.Exec(stmt, ruleId)
	if err == nil {
		return nil
	}
	return fmt.Errorf("could not delete rule: %w", err)
}
// GetRules returns every stored rule ordered by ascending weight. An empty
// table yields a nil slice and a nil error.
func (ss *Sqlite) GetRules() ([]RuleRow, error) {
	query := `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules ORDER BY weight ASC;`
	rows, err := ss.Query(query)
	if err != nil {
		return nil, fmt.Errorf("could not execute select for rules: %w", err)
	}
	defer rows.Close()

	var results []RuleRow
	for rows.Next() {
		var rule RuleRow
		var createdTime string
		if err := rows.Scan(
			&rule.ID,
			&rule.Name,
			&rule.Value,
			&rule.Answer.Type,
			&rule.Answer.Value,
			&rule.TTL,
			&rule.Weight,
			&rule.Enabled,
			&createdTime,
		); err != nil {
			return nil, fmt.Errorf("could not read from db: %w", err)
		}
		if rule.Created, err = time.Parse(ISO8601, createdTime); err != nil {
			return nil, fmt.Errorf("could not parse time %q: %w", createdTime, err)
		}
		results = append(results, rule)
	}
	// Distinguish "no more rows" from an iteration failure.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("could not iterate rules: %w", err)
	}
	return results, nil
}
// GetLog returns one page of query-log entries whose started timestamp lies
// strictly between in.Start and in.End, newest first.
//
// Defaults:
//   - in.Limit <= 0 becomes 100.
//   - in.Page < 0 becomes 0 (the first page).
//   - A zero in.Start becomes 24 hours before now (UTC).
//   - A zero in.End becomes now (UTC).
//
// Pages are taken with LIMIT/OFFSET over a stable order (started DESC, id
// DESC). Paging by id cannot work once rows are sorted newest-first.
//
// NOTE(review): in.DomainFilter is not applied here — confirm whether callers
// expect domain filtering. Start/End are formatted in their own location and
// compared as text; confirm callers pass UTC to match the UTC defaults.
func (ss *Sqlite) GetLog(in GetLogInput) ([]QueryLog, error) {
	if in.Limit <= 0 {
		in.Limit = 100
	}
	if in.Page < 0 {
		in.Page = 0
	}
	now := time.Now().UTC()
	if in.Start.IsZero() {
		// Default window is one day (86400 seconds).
		in.Start = now.Add(-24 * time.Hour)
	}
	if in.End.IsZero() {
		in.End = now
	}

	query := `
SELECT
started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status
FROM
log
WHERE
started > ? AND started < ?
ORDER BY started DESC, id DESC
LIMIT ? OFFSET ?;
`
	rows, err := ss.DB.Query(query,
		in.Start.Format(ISO8601),
		in.End.Format(ISO8601),
		in.Limit,
		in.Page*in.Limit,
	)
	if err != nil {
		return nil, fmt.Errorf("could not execute select for log: %w", err)
	}
	defer rows.Close()

	var ql []QueryLog
	for rows.Next() {
		var q QueryLog
		var started string
		if err := rows.Scan(
			&started,
			&q.ClientIP,
			&q.Protocol,
			&q.Domain,
			&q.TotalTimeMs,
			&q.Error,
			&q.RecurseRoundTripTimeMs,
			&q.RecurseUpstreamIP,
			&q.Status,
		); err != nil {
			return nil, fmt.Errorf("could not read log row: %w", err)
		}
		if q.Started, err = time.Parse(ISO8601, started); err != nil {
			return nil, fmt.Errorf("could not parse time %q: %w", started, err)
		}
		ql = append(ql, q)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("could not iterate log rows: %w", err)
	}
	return ql, nil
}
// LogAggregateColumn names a log-table column that GetLogAggregate may group by.
type LogAggregateColumn string

// Columns accepted for log aggregation; each value is the literal column name.
var (
	Domain    LogAggregateColumn = "domain"
	ClientIP  LogAggregateColumn = "clientIp"
	RecurseIP LogAggregateColumn = "recurseUpStreamIP"
	Protocol  LogAggregateColumn = "protocol"
	Status    LogAggregateColumn = "status"
)

// AggregateKeys maps caller-supplied column keys to their LogAggregateColumn.
// Every key equals the column name it maps to. GetLogAggregate formats only
// values found here into its SQL, so this map doubles as the identifier
// whitelist.
var AggregateKeys = map[string]LogAggregateColumn{
	string(Domain):    Domain,
	string(ClientIP):  ClientIP,
	string(RecurseIP): RecurseIP,
	string(Protocol):  Protocol,
	string(Status):    Status,
}
// LogAggregateInput parameterizes GetLogAggregate.
type LogAggregateInput struct {
// IntervalSeconds is the width of each time bucket; GetLogAggregate uses 300
// (five minutes) when it is <= 0.
IntervalSeconds int
// Start and End presumably describe the time range to aggregate over.
// NOTE(review): confirm GetLogAggregate actually honours them.
Start time.Time
End time.Time
// Column is a key of AggregateKeys selecting the grouping column; unknown
// keys fall back to "domain".
Column string
}
// LogAggregateDataPoint is one (column value, time bucket) group produced by
// GetLogAggregate.
type LogAggregateDataPoint struct {
// Header is the value of the grouped column for this group.
Header string
// AverageTotalTime is AVG(totalTimeMs) over the group, rounded to 3 decimals.
AverageTotalTime float64
// Count is the number of log rows in the group.
Count int
// Time is the start of the group's time bucket (bucket index * bucket width,
// in Unix seconds).
Time time.Time
}
// GetLogAggregate groups query-log rows by a whitelisted column and by
// fixed-width time buckets. For each group it returns the request count and
// the average totalTimeMs, ordered by bucket and then by column value.
//
// Parameters:
//   - la.Column must be a key of AggregateKeys; any other value groups by
//     "domain". Only these whitelisted names, the integer bucket width and a
//     fixed filter template are formatted into the SQL text. All caller data
//     is bound as parameters.
//   - la.IntervalSeconds is the bucket width; values <= 0 mean 300 seconds.
//   - A non-zero la.Start keeps rows started at or after it, and a non-zero
//     la.End keeps rows started before it. Both bounds use one-second
//     resolution via strftime('%s'), which reads the stored text as UTC.
//     Zero values leave that side of the range open.
//
// NULL values in the grouped column are reported with an empty Header instead
// of failing the scan. Each point's Time is its bucket start, in UTC.
func (ss *Sqlite) GetLogAggregate(la LogAggregateInput) ([]LogAggregateDataPoint, error) {
	timeWindow := int64(5 * 60)
	if la.IntervalSeconds > 0 {
		timeWindow = int64(la.IntervalSeconds)
	}
	column := "domain"
	if lac, ok := AggregateKeys[la.Column]; ok {
		column = string(lac)
	}

	// Optional time-range filter; the "1 = 1" seed lets each bound be
	// appended the same way.
	filter := "1 = 1"
	var args []interface{}
	if !la.Start.IsZero() {
		filter += " AND CAST(strftime('%s', started) AS INTEGER) >= ?"
		args = append(args, la.Start.Unix())
	}
	if !la.End.IsZero() {
		filter += " AND CAST(strftime('%s', started) AS INTEGER) < ?"
		args = append(args, la.End.Unix())
	}

	query := fmt.Sprintf(`
SELECT
COALESCE(%[1]s, '') AS header,
ROUND(AVG(totalTimeMs), 3) AS averageTotalTime,
COUNT(*) AS requests,
strftime('%%s', started) / (%[2]d) AS timeWindow
FROM log
WHERE %[3]s
GROUP BY COALESCE(%[1]s, ''), strftime('%%s', started) / (%[2]d)
ORDER BY timeWindow ASC, header ASC;
`, column, timeWindow, filter)

	rows, err := ss.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("could not execute log aggregate query: %w", err)
	}
	defer rows.Close()

	var results []LogAggregateDataPoint
	for rows.Next() {
		var ladp LogAggregateDataPoint
		var bucket int64
		if err := rows.Scan(
			&ladp.Header,
			&ladp.AverageTotalTime,
			&ladp.Count,
			&bucket,
		); err != nil {
			return nil, fmt.Errorf("could not read log aggregate row: %w", err)
		}
		// bucket is (unix seconds / timeWindow) with integer division, so
		// multiplying back gives the bucket's start.
		ladp.Time = time.Unix(bucket*timeWindow, 0).UTC()
		results = append(results, ladp)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("could not iterate log aggregate rows: %w", err)
	}
	return results, nil
}
// Log appends one entry to the query log. ql.Started is stored as text using
// the ISO8601 layout, formatted in whatever location ql.Started carries.
//
// NOTE(review): GetLogAggregate's strftime('%s', started) interprets stored
// text as UTC; confirm callers set Started in UTC.
func (ss *Sqlite) Log(ql QueryLog) error {
	query := `
INSERT INTO log
(started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?);
`
	if _, err := ss.DB.Exec(query,
		ql.Started.Format(ISO8601),
		ql.ClientIP,
		ql.Protocol,
		ql.Domain,
		ql.TotalTimeMs,
		ql.Error,
		ql.RecurseRoundTripTimeMs,
		ql.RecurseUpstreamIP,
		ql.Status,
	); err != nil {
		// Wrap with %w so callers keep errors.Is/As while gaining context,
		// matching the other insert methods.
		return fmt.Errorf("could not insert log entry: %w", err)
	}
	return nil
}
// Open connects ss to the SQLite database at ss.Path (with "?cache=shared"
// appended), limits the pool to one open connection, and creates the schema
// via initTable.
//
// ss.DB is assigned as soon as the handle exists. If schema creation fails,
// the handle is closed before the error is returned. sql.DB.Close is
// idempotent, so a caller that still calls ss.Close() afterwards is unaffected.
func (ss *Sqlite) Open() error {
	db, err := sql.Open("sqlite3", fmt.Sprintf("%s?cache=shared", ss.Path))
	if err != nil {
		return fmt.Errorf("could not open db: %w", err)
	}
	// NOTE(review): presumably one connection serializes access because
	// SQLite allows a single writer at a time — confirm.
	db.SetMaxOpenConns(1)
	ss.DB = db
	if err := initTable(db); err != nil {
		// Release the pool rather than leaking it on a failed Open; the
		// Close error is secondary to the init error being returned.
		_ = db.Close()
		return err
	}
	return nil
}
// initTable creates the log, rules, recursors, ruleslist and ruleslist_entry
// tables and their indexes. Every statement uses IF NOT EXISTS, so running it
// against an already-initialized database is harmless. For the same reason it
// never alters an existing table or index, so schema edits made here do not
// reach existing database files without a migration.
//
// NOTE(review): all statements are sent in a single Exec call, which relies on
// the sqlite3 driver executing several ;-separated statements per Exec.
//
// NOTE(review): idx_ruleslist_entry_ruleslistId is UNIQUE over
// (id, ruleslistId). Since id is already the INTEGER PRIMARY KEY, that
// uniqueness can never be violated. The index also leads with id, so it likely
// does not help lookups by ruleslistId alone; a plain index on ruleslistId was
// probably intended. Note too that idx_ruleslist_entry_expression makes
// expression unique across all lists, not per list.
func initTable(db *sql.DB) error {
// The local sql string shadows the database/sql package inside this function.
sql := `
CREATE TABLE IF NOT EXISTS log (
id INTEGER PRIMARY KEY,
started TEXT NOT NULL,
clientIp TEXT NOT NULL,
protocol TEXT NOT NULL,
domain TEXT NOT NULL,
totalTimeMs int NOT NULL,
error TEXT,
recurseRoundTripTimeMs INT,
recurseUpStreamIP TEXT,
status TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS rules (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
expression TEXT NOT NULL,
answerType TEXT NOT NULL,
answerValue TEXT NOT NULL,
ttl INT NOT NULL,
weight INT NOT NULL,
enabled INT NOT NULL,
created TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_name ON rules (name);
CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_expression ON rules (expression);
CREATE TABLE IF NOT EXISTS recursors (
id INTEGER PRIMARY KEY,
ipAddress TEXT NOT NULL,
timeoutMs INT NOT NULL,
weight INT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_recursors_ipAddress ON recursors (ipAddress);
CREATE TABLE IF NOT EXISTS ruleslist (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
url TEXT NOT NULL,
lastLoadedTs TEXT NOT NULL,
weight INT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_name ON ruleslist (name);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_url ON ruleslist (url);
CREATE TABLE IF NOT EXISTS ruleslist_entry (
id INTEGER PRIMARY KEY,
ruleslistId INTEGER NOT NULL,
expression TEXT NOT NULL,
ipAddress TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_entry_expression ON ruleslist_entry (expression);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_entry_ruleslistId ON ruleslist_entry (id, ruleslistId);
`
if _, err := db.Exec(sql); err != nil {
return fmt.Errorf("could not initialize db: %w", err)
}
return nil
}