You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
418 lines
8.9 KiB
418 lines
8.9 KiB
package internal
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// ISO8601 is the time layout used to write and parse the timestamp columns
// (log.started, rules.created). Despite the name it is not strict ISO 8601:
// date and time are separated by a space rather than "T", and no zone offset
// is recorded, so timestamps are stored as zone-less strings (time.Parse
// returns UTC for them). The ".999" suffix formats up to millisecond
// precision with trailing zeros omitted; when parsing, the fractional part
// is optional.
const ISO8601 = "2006-01-02 15:04:05.999"
|
|
|
|
type Storage interface {
|
|
io.Closer
|
|
Open() error
|
|
AddRecursors(net.IP, int, int, int) error
|
|
DeleteRecursors(int) error
|
|
GetRecursors() ([]RecursorRow, error)
|
|
// UpdateRule(RuleRow) error
|
|
AddRule(RuleRow) error
|
|
GetRule(int) (RuleRow, error)
|
|
GetRules() ([]RuleRow, error)
|
|
DeleteRule(int) error
|
|
Log(QueryLog) error
|
|
GetLog(GetLogInput) ([]QueryLog, error)
|
|
GetLogAggregate(LogAggregateInput) ([]LogAggregateDataPoint, error)
|
|
}
|
|
|
|
// Sqlite implements Storage on top of a SQLite database accessed through the
// "sqlite3" database/sql driver.
type Sqlite struct {
	// Path is the database location; Open passes it to the driver with
	// "?cache=shared" appended.
	Path string

	// DB is the connection pool assigned by Open. Embedding it exposes
	// Query, QueryRow, Exec and Close directly on Sqlite.
	*sql.DB
}
|
|
|
|
func (ss *Sqlite) GetRecursors() ([]RecursorRow, error) {
|
|
sql := `
|
|
SELECT id, ipAddress, timeoutMs, weight FROM recursors ORDER BY weight ASC;
|
|
`
|
|
|
|
rows, err := ss.Query(sql)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not execute select for recursors: %w", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
var results []RecursorRow
|
|
for rows.Next() {
|
|
var row RecursorRow
|
|
|
|
if err := rows.Scan(&row.ID, &row.IpAddress, &row.TimeoutMs, &row.Weight); err != nil {
|
|
return nil, fmt.Errorf("could not read row: %w", err)
|
|
}
|
|
|
|
results = append(results, row)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (ss *Sqlite) DeleteRecursors(id int) error {
|
|
sql := `DELETE FROM recursors WHERE id = ?;`
|
|
if _, err := ss.Exec(sql, id); err != nil {
|
|
return fmt.Errorf("Could not delete recursor: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RecursorRow is one row of the recursors table.
type RecursorRow struct {
	// ID is the table's integer primary key.
	ID int `json:"id"`
	// IpAddress is the address-and-port string written by AddRecursors.
	IpAddress string `json:"ipAddress"`
	// TimeoutMs is the per-recursor timeout (milliseconds per the name —
	// NOTE(review): confirm against the resolver that consumes it).
	TimeoutMs int `json:"timeoutMs"`
	// Weight orders recursors; GetRecursors returns them by ascending weight.
	Weight int `json:"weight"`
}
|
|
|
|
func (ss *Sqlite) AddRecursors(ip net.IP, port, timeout, weight int) error {
|
|
sql := `INSERT INTO recursors (ipAddress, timeoutMs, weight) VALUES (?, ?, ?);`
|
|
|
|
if _, err := ss.Exec(sql, fmt.Sprintf("%s:%d", ip.String(), port), timeout, weight); err != nil {
|
|
return fmt.Errorf("could not insert recursor: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetLogInput selects which query-log entries GetLog returns.
type GetLogInput struct {
	// Start and End bound the started timestamp of returned entries; zero
	// values are replaced with defaults by GetLog.
	Start, End time.Time

	// DomainFilter is carried in the input but is not read by Sqlite.GetLog
	// in this file.
	DomainFilter string

	// Limit is the page size (GetLog uses 100 when it is <= 0); Page is the
	// page number, which GetLog multiplies by Limit.
	Limit, Page int
}
|
|
|
|
// RuleRow is a Rule together with the bookkeeping columns of the rules table.
// The embedded Rule provides the fields AddRule persists: Name, Value (the
// expression column), Answer.Type, Answer.Value and TTL.
type RuleRow struct {
	// ID is the table's integer primary key.
	ID int `json:"id"`
	// Weight orders rules; GetRules returns them by ascending weight.
	Weight int `json:"weight"`
	// Enabled mirrors the enabled column; AddRule always stores new rules
	// with enabled = 1.
	Enabled bool `json:"enabled"`
	// Created is when AddRule inserted the row (UTC, stored as text in the
	// ISO8601 layout).
	Created time.Time `json:"created"`

	Rule
}
|
|
|
|
func (ss *Sqlite) AddRule(rr RuleRow) error {
|
|
sql := `INSERT INTO rules (name, expression, answerType, answerValue, ttl, weight, enabled, created)
|
|
VALUES (?, ? , ?, ?, ?, ?, 1, ?);`
|
|
|
|
if _, err := ss.Exec(sql, rr.Name, rr.Value, rr.Answer.Type, rr.Answer.Value, rr.TTL, rr.Weight, time.Now().UTC().Format(ISO8601)); err != nil {
|
|
return fmt.Errorf("could not delete rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ss *Sqlite) GetRule(ruleId int) (RuleRow, error) {
|
|
sql := `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules WHERE id = ?;`
|
|
|
|
var rr RuleRow
|
|
row := ss.QueryRow(sql, ruleId)
|
|
|
|
var createdTime string
|
|
if err := row.Scan(&rr.ID, &rr.Name, &rr.Value, &rr.Answer.Type, &rr.Answer.Value, &rr.TTL, &rr.Weight, &rr.Enabled, &createdTime); err != nil {
|
|
return rr, err
|
|
}
|
|
|
|
var err error
|
|
rr.Created, err = time.Parse(ISO8601, createdTime)
|
|
if err != nil {
|
|
return rr, fmt.Errorf("could not parse time: %w", err)
|
|
}
|
|
|
|
return rr, nil
|
|
}
|
|
|
|
func (ss *Sqlite) DeleteRule(ruleId int) error {
|
|
if _, err := ss.Exec(`DELETE FROM rules WHERE id = ?;`, ruleId); err != nil {
|
|
return fmt.Errorf("could not delete rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ss *Sqlite) GetRules() ([]RuleRow, error) {
|
|
sql := `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules ORDER BY weight ASC;`
|
|
|
|
rows, err := ss.Query(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []RuleRow
|
|
for rows.Next() {
|
|
var rule RuleRow
|
|
var createdTime string
|
|
|
|
if err := rows.Scan(
|
|
&rule.ID,
|
|
&rule.Name,
|
|
&rule.Value,
|
|
&rule.Answer.Type,
|
|
&rule.Answer.Value,
|
|
&rule.TTL,
|
|
&rule.Weight,
|
|
&rule.Enabled,
|
|
&createdTime,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("could not read from db: %w", err)
|
|
}
|
|
|
|
rule.Created, err = time.Parse(ISO8601, createdTime)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
results = append(results, rule)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (ss *Sqlite) GetLog(in GetLogInput) ([]QueryLog, error) {
|
|
if in.Limit <= 0 {
|
|
in.Limit = 100
|
|
}
|
|
|
|
if in.Start.IsZero() {
|
|
in.Start = time.Now().UTC().Add(time.Hour * -86400)
|
|
}
|
|
|
|
if in.End.IsZero() {
|
|
in.End = time.Now().UTC()
|
|
}
|
|
|
|
sql := `
|
|
SELECT
|
|
started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status
|
|
FROM
|
|
log
|
|
WHERE
|
|
id > ? AND started > ? AND started < ?
|
|
ORDER BY started DESC
|
|
LIMIT ?;
|
|
`
|
|
|
|
rows, err := ss.DB.Query(sql, in.Page*in.Limit, in.Start.Format(ISO8601), in.End.Format(ISO8601), in.Limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var ql []QueryLog
|
|
for rows.Next() {
|
|
var q QueryLog
|
|
var errStr string
|
|
var started string
|
|
|
|
rows.Scan(
|
|
&started,
|
|
&q.ClientIP,
|
|
&q.Protocol,
|
|
&q.Domain,
|
|
&q.TotalTimeMs,
|
|
&errStr,
|
|
&q.RecurseRoundTripTimeMs,
|
|
&q.RecurseUpstreamIP,
|
|
&q.Status,
|
|
)
|
|
|
|
if errStr != "" {
|
|
q.Error = errors.New(errStr)
|
|
}
|
|
|
|
if q.Started, err = time.Parse(ISO8601, started); err != nil {
|
|
return nil, fmt.Errorf("could not parse time '%s': %v", started, err)
|
|
}
|
|
|
|
ql = append(ql, q)
|
|
}
|
|
|
|
return ql, nil
|
|
}
|
|
|
|
// LogAggregateColumn names a column of the log table that GetLogAggregate
// may group by.
type LogAggregateColumn string

// Columns of the log table that can be used as an aggregation key.
var (
	Domain    LogAggregateColumn = "domain"
	ClientIP  LogAggregateColumn = "clientIp"
	RecurseIP LogAggregateColumn = "recurseUpStreamIP"
	Protocol  LogAggregateColumn = "protocol"
	Status    LogAggregateColumn = "status"
)

// AggregateKeys is the whitelist of accepted aggregation keys. Each key is
// the column's own name. GetLogAggregate interpolates only values from this
// map into its SQL.
var AggregateKeys = map[string]LogAggregateColumn{
	string(Domain):    Domain,
	string(ClientIP):  ClientIP,
	string(RecurseIP): RecurseIP,
	string(Protocol):  Protocol,
	string(Status):    Status,
}
|
|
|
|
// LogAggregateInput controls how GetLogAggregate buckets the query log.
type LogAggregateInput struct {
	// IntervalSeconds is the width of each time bucket; GetLogAggregate uses
	// 300 seconds when it is <= 0.
	IntervalSeconds int

	// Start and End are the optional time range for the aggregation; see
	// GetLogAggregate for how they are applied.
	Start, End time.Time

	// Column is a key of AggregateKeys; unknown values fall back to "domain".
	Column string
}
|
|
|
|
// LogAggregateDataPoint is one (column value, time window) bucket produced
// by GetLogAggregate.
type LogAggregateDataPoint struct {
	// Header is the value of the grouped column (e.g. a domain).
	Header string
	// AverageTotalTime is AVG(totalTimeMs) over the bucket, rounded to three
	// decimal places by the query.
	AverageTotalTime float64
	// Count is the number of log rows in the bucket.
	Count int
	// Time is the start of the bucket's time window (built with time.Unix,
	// so it is in the local time zone).
	Time time.Time
}
|
|
|
|
func (ss *Sqlite) GetLogAggregate(la LogAggregateInput) ([]LogAggregateDataPoint, error) {
|
|
timeWindow := int64(5 * 60)
|
|
column := "domain"
|
|
|
|
if lac, ok := AggregateKeys[la.Column]; ok {
|
|
column = string(lac)
|
|
}
|
|
|
|
if la.IntervalSeconds > 0 {
|
|
timeWindow = int64(la.IntervalSeconds)
|
|
}
|
|
|
|
sql := `
|
|
SELECT
|
|
%s,
|
|
ROUND(AVG(totalTimeMs), 3) as averageTotalTime,
|
|
COUNT(*) as requests,
|
|
strftime('%%s', started)/(%d) as "timeWindow"
|
|
FROM log
|
|
GROUP BY %s, strftime('%%s', started) / (%d)
|
|
ORDER BY started desc;
|
|
`
|
|
|
|
sql = fmt.Sprintf(sql, column, timeWindow, column, timeWindow)
|
|
|
|
rows, err := ss.Query(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []LogAggregateDataPoint
|
|
for rows.Next() {
|
|
var ladp LogAggregateDataPoint
|
|
var timeInterval int64
|
|
if err := rows.Scan(
|
|
&ladp.Header,
|
|
&ladp.AverageTotalTime,
|
|
&ladp.Count,
|
|
&timeInterval,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ladp.Time = time.Unix(timeInterval*timeWindow, 0)
|
|
|
|
results = append(results, ladp)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (ss *Sqlite) Log(ql QueryLog) error {
|
|
sql := `
|
|
INSERT INTO log
|
|
(started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status)
|
|
VALUES
|
|
(?, ?, ?, ?, ?, ?, ?, ?, ?);
|
|
`
|
|
var errStr string
|
|
if ql.Error != nil {
|
|
errStr = ql.Error.Error()
|
|
}
|
|
if _, err := ss.DB.Exec(sql,
|
|
ql.Started.Format(ISO8601),
|
|
ql.ClientIP,
|
|
ql.Protocol,
|
|
ql.Domain,
|
|
ql.TotalTimeMs,
|
|
errStr,
|
|
ql.RecurseRoundTripTimeMs,
|
|
ql.RecurseUpstreamIP,
|
|
ql.Status,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ss *Sqlite) Open() error {
|
|
db, err := sql.Open("sqlite3", fmt.Sprintf("%s?cache=shared", ss.Path))
|
|
if err != nil {
|
|
return fmt.Errorf("could not open db: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(1)
|
|
|
|
ss.DB = db
|
|
|
|
if err := initTable(db); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// initTable creates the log, rules and recursors tables and their indexes if
// they do not already exist, so it is safe to run on every Open. All
// statements are sent in a single Exec call.
//
// Timestamps (log.started, rules.created) are TEXT in the ISO8601 layout.
// GetLog filters and orders by log.started, and idx_log_started lets it do
// so without scanning and sorting the entire, ever-growing log table.
func initTable(db *sql.DB) error {
	schema := `
CREATE TABLE IF NOT EXISTS log (
	id INTEGER PRIMARY KEY,
	started TEXT NOT NULL,
	clientIp TEXT NOT NULL,
	protocol TEXT NOT NULL,
	domain TEXT NOT NULL,
	totalTimeMs int NOT NULL,
	error TEXT,
	recurseRoundTripTimeMs INT,
	recurseUpStreamIP TEXT,
	status TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_log_started ON log (started);

CREATE TABLE IF NOT EXISTS rules (
	id INTEGER PRIMARY KEY,
	name TEXT NOT NULL,
	expression TEXT NOT NULL,
	answerType TEXT NOT NULL,
	answerValue TEXT NOT NULL,
	ttl INT NOT NULL,
	weight INT NOT NULL,
	enabled INT NOT NULL,
	created TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_name ON rules (name);
CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_expression ON rules (expression);

CREATE TABLE IF NOT EXISTS recursors (
	id INTEGER PRIMARY KEY,
	ipAddress TEXT NOT NULL,
	timeoutMs INT NOT NULL,
	weight INT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_recursors_ipAddress ON recursors (ipAddress);
`

	if _, err := db.Exec(schema); err != nil {
		return fmt.Errorf("could not initialize db: %w", err)
	}

	return nil
}
|