You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

546 lines
12 KiB

package internal
import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"net"
	"sort"
	"strconv"
	"strings"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
// defaultSamples is the number of plot points GetLogAggregate aims for when
// the caller does not request an interval.
const defaultSamples = 64

// maxSamples is the upper bound on plot points GetLogAggregate will produce.
const maxSamples = 128
// Sqlite is the SQLite-backed store. Open creates the embedded *sql.DB handle,
// and the store's methods call its Query, QueryRow and Exec directly.
type Sqlite struct {
// Path is the directory containing the db.sqlite file opened by Open.
Path string
*sql.DB
}
// Open opens the SQLite database file db.sqlite inside ss.Path, stores the
// handle in ss.DB and makes sure the schema exists via initTable.
//
// The DSN requests a shared cache and WAL journaling. The pool is limited to
// a single open connection. If schema initialization fails, the handle is
// closed before the error is returned, so the pool is not leaked.
func (ss *Sqlite) Open() error {
	dsn := fmt.Sprintf("%s/db.sqlite?cache=shared&_journal=WAL", ss.Path)
	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		return fmt.Errorf("could not open db: %w", err)
	}
	db.SetMaxOpenConns(1)
	ss.DB = db
	if err := initTable(db); err != nil {
		// Best effort: the init error is the one worth reporting.
		_ = db.Close()
		return err
	}
	return nil
}
// Close closes the underlying database handle and returns any error from
// doing so. It is a no-op returning nil when the store was never opened.
func (ss *Sqlite) Close() error {
	if ss.DB == nil {
		return nil
	}
	return ss.DB.Close()
}
// GetLogAggregate groups the query-log entries between la.Start and la.End by
// la.Column and buckets them into fixed-width time intervals for plotting.
//
// Defaults applied to la before querying:
//   - End: now (UTC) when zero or in the future.
//   - Start: 12 hours before End when zero.
//   - IntervalSeconds: timespan/defaultSamples when <= 0. It is never less
//     than one second, and it is widened (rounded up) so that at most
//     maxSamples intervals result.
//   - Column: Domain unless it is one of Domain, Status, ClientIP or Protocol.
//
// Labels holds one formatted timestamp per interval. Each dataset holds one
// more point than there are labels. That final point collects entries that
// fall after the last labelled interval. Within a point, Count is the number
// of entries and Value is the sum of their TotalTimeMs. Datasets are sorted
// by label so the order is stable between calls.
//
// It returns an error when Start is after End, when the log query fails, or
// when the window holds more entries than a single 10000-row GetLog page.
func (ss *Sqlite) GetLogAggregate(la LogAggregateInput) (LogAggregate, error) {
	now := time.Now().UTC()
	if la.End.IsZero() || la.End.After(now) {
		la.End = now
	}
	if la.Start.IsZero() {
		// Anchor the default window to End so an explicit End in the past
		// still produces a non-negative window.
		la.Start = la.End.Add(-12 * time.Hour)
	}
	if la.Start.After(la.End) {
		return LogAggregate{}, errors.New("start time cannot be after end time")
	}

	timespanSecs := int(la.End.Sub(la.Start) / time.Second)

	// how many data points to show on the line plot
	if la.IntervalSeconds <= 0 {
		la.IntervalSeconds = timespanSecs / defaultSamples
	}
	// Windows shorter than defaultSamples seconds would otherwise give a zero
	// interval and a division by zero below.
	if la.IntervalSeconds < 1 {
		la.IntervalSeconds = 1
	}
	sampleCount := timespanSecs / la.IntervalSeconds
	// Cap to prevent performance issues. Round the interval up so the
	// recomputed sample count really fits within maxSamples; rounding down
	// can leave entries indexing past the end of a bucket slice.
	if sampleCount > maxSamples {
		la.IntervalSeconds = (timespanSecs + maxSamples - 1) / maxSamples
		sampleCount = timespanSecs / la.IntervalSeconds
	}
	log.Printf("%+v - samples: %v - timespan (seconds): %v", la, sampleCount, timespanSecs)

	switch la.Column {
	case string(Domain), string(Status), string(ClientIP), string(Protocol):
		// supported column; keep it
	default:
		la.Column = string(Domain)
	}

	logs, err := ss.GetLog(GetLogInput{
		Start: la.Start,
		End:   la.End,
		Limit: 10000,
		Page:  0,
	})
	if err != nil {
		return LogAggregate{}, err
	}
	if logs.PageCount > 1 {
		return LogAggregate{}, fmt.Errorf("more than one page available: %v", logs.PageCount)
	}

	// intervalStart is the timestamp at which interval i begins.
	intervalStart := func(i int) time.Time {
		return la.Start.Add(time.Duration(i*la.IntervalSeconds) * time.Second)
	}

	buckets := map[string][]StatsDataPoint{}
	for _, entry := range logs.Logs {
		key := GetAggregateColumnHeader(entry, LogAggregateColumn(la.Column))
		dataset, ok := buckets[key]
		if !ok {
			dataset = make([]StatsDataPoint, sampleCount+1)
			buckets[key] = dataset
		}
		timeIndex := int(entry.Started.Sub(la.Start)/time.Second) / la.IntervalSeconds
		// GetLog filters with strftime('%s', ...) (whole seconds), so an entry
		// can sit slightly outside [Start, End]. Clamp it into the nearest
		// bucket rather than indexing out of range.
		if timeIndex < 0 {
			timeIndex = 0
		} else if timeIndex > sampleCount {
			timeIndex = sampleCount
		}
		point := &dataset[timeIndex]
		point.Header = key
		point.Time = intervalStart(timeIndex)
		point.Count++
		point.Value += float64(entry.TotalTimeMs)
	}

	labels := make([]string, sampleCount)
	for i := range labels {
		labels[i] = intervalStart(i).Format("01-02 15:04:05")
	}

	// Map iteration order is random; sort the keys so datasets keep a stable
	// order across calls.
	keys := make([]string, 0, len(buckets))
	for k := range buckets {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	datasets := make([]LogAggregateDataset, 0, len(keys))
	for _, k := range keys {
		points := buckets[k]
		// Give empty intervals their timestamp so every point can be placed
		// on a time axis.
		for i := range points {
			if points[i].Time.IsZero() {
				points[i].Time = intervalStart(i)
			}
		}
		datasets = append(datasets, LogAggregateDataset{Label: k, Dataset: points})
	}

	return LogAggregate{Labels: labels, Datasets: datasets}, nil
}
// LogAggregate is the result returned by GetLogAggregate.
type LogAggregate struct {
// Labels holds one timestamp, formatted "01-02 15:04:05", per time interval.
Labels []string `json:"labels"`
// Datasets holds one series per bucket key (the GetAggregateColumnHeader
// value of each log entry).
Datasets []LogAggregateDataset `json:"datasets"`
}
// LogAggregateDataset is one plotted series of a LogAggregate.
type LogAggregateDataset struct {
// Label is the bucket key the series was grouped by.
Label string `json:"label"`
// Dataset holds the per-interval data points; it is serialized as "data".
Dataset []StatsDataPoint `json:"data"`
}
// Log appends a single query-log entry to the log table. The start time is
// stored in UTC and formatted with ISO8601. Any database error is returned
// unchanged.
func (ss *Sqlite) Log(ql QueryLog) error {
	const insertLog = `
INSERT INTO log
(started, clientIp, protocol, domain, totalTimeMs, error, recurseRoundTripTimeMs, recurseUpstreamIp, status)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?);
`
	startedAt := ql.Started.UTC().Format(ISO8601)
	_, err := ss.DB.Exec(insertLog,
		startedAt, ql.ClientIP, ql.Protocol, ql.Domain, ql.TotalTimeMs,
		ql.Error, ql.RecurseRoundTripTimeMs, ql.RecurseUpstreamIP, ql.Status,
	)
	return err
}
// GetRecursors returns every configured recursor, ordered by ascending weight.
// When there are none it returns an empty, non-nil slice.
func (ss *Sqlite) GetRecursors() ([]RecursorRow, error) {
	query := `
SELECT id, ipAddress, timeoutMs, weight FROM recursors ORDER BY weight ASC;
`
	rows, err := ss.Query(query)
	if err != nil {
		return nil, fmt.Errorf("could not execute select for recursors: %w", err)
	}
	defer rows.Close()
	results := []RecursorRow{}
	for rows.Next() {
		var row RecursorRow
		if err := rows.Scan(&row.ID, &row.IpAddress, &row.TimeoutMs, &row.Weight); err != nil {
			return nil, fmt.Errorf("could not read row: %w", err)
		}
		results = append(results, row)
	}
	// rows.Err reports failures hit while iterating, so it must be checked
	// after the loop; before the loop it reports nothing useful.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("could not iterate recursors: %w", err)
	}
	return results, nil
}
// DeleteRecursors removes the recursor with the given id. Deleting an id that
// does not exist is not reported as an error.
func (ss *Sqlite) DeleteRecursors(id int) error {
	query := `DELETE FROM recursors WHERE id = ?;`
	if _, err := ss.Exec(query, id); err != nil {
		// Go error strings are lowercase so they compose when wrapped.
		return fmt.Errorf("could not delete recursor: %w", err)
	}
	return nil
}
// ValidIp parses rr.IpAddress as a recursor address.
//
// Accepted forms are "ip", "ip:port", "[ipv6]" and "[ipv6]:port". A bare
// IPv6 address such as "::1" is also accepted.
//
// On success it returns the parsed IP, the port (53 when none is given) and
// true. It returns nil, -1 and false when the host is not an IP literal, or
// when a port is present but is not a number in 1-65535.
func (rr RecursorRow) ValidIp() (net.IP, int, bool) {
	// SplitHostPort understands bracketed IPv6 ("[::1]:53"), which a plain
	// split on ":" cannot.
	host, portStr, err := net.SplitHostPort(rr.IpAddress)
	hasPort := err == nil
	if !hasPort {
		// No port: the whole value is the host, possibly bracketed IPv6.
		host = rr.IpAddress
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
	}
	ip := net.ParseIP(host)
	if ip == nil {
		return nil, -1, false
	}
	port := 53
	if hasPort {
		p, err := strconv.Atoi(portStr)
		if err != nil || p < 1 || p > 65535 {
			return nil, -1, false
		}
		port = p
	}
	return ip, port, true
}
// UpdateRecursor overwrites the address, timeout and weight of the recursor
// with the given id. Updating an id that does not exist is not reported as an
// error.
func (ss *Sqlite) UpdateRecursor(id int, in RecursorRow) error {
	query := `UPDATE recursors SET ipAddress = ?, timeoutMs = ?, weight = ? WHERE id = ?;`
	// The statement has four placeholders; id binds the WHERE clause.
	if _, err := ss.Exec(query, in.IpAddress, in.TimeoutMs, in.Weight, id); err != nil {
		return fmt.Errorf("could not update recursor: %w", err)
	}
	return nil
}
// AddRecursors stores a new recursor with the given timeout and weight.
//
// The address is saved as host:port via net.JoinHostPort. IPv4 gives
// "1.2.3.4:53". IPv6 is bracketed, e.g. "[::1]:53", so the port cannot be
// confused with the address's own colons.
func (ss *Sqlite) AddRecursors(ip net.IP, port, timeout, weight int) error {
	query := `INSERT INTO recursors (ipAddress, timeoutMs, weight) VALUES (?, ?, ?);`
	addr := net.JoinHostPort(ip.String(), strconv.Itoa(port))
	if _, err := ss.Exec(query, addr, timeout, weight); err != nil {
		return fmt.Errorf("could not insert recursor: %w", err)
	}
	return nil
}
// UpdateRule overwrites the rule with the given id using these fields of in:
// name, expression (in.Value), answer type/value, TTL, weight and enabled
// flag. Updating an id that does not exist is not reported as an error.
func (ss *Sqlite) UpdateRule(id int, in RuleRow) error {
	query := `UPDATE rules SET
name = ?,
expression = ?,
answerType = ?,
answerValue = ?,
ttl = ?,
weight = ?,
enabled = ?
WHERE id = ?;`
	// Eight placeholders: seven column values followed by the WHERE id.
	if _, err := ss.Exec(query, in.Name, in.Value, in.Answer.Type, in.Answer.Value, in.TTL, in.Weight, in.Enabled, id); err != nil {
		return fmt.Errorf("could not update rule with id %v: %w", id, err)
	}
	return nil
}
// AddRule inserts rr as a new rule and sets enabled to 1. rr.Value is stored
// in the expression column. The created column is the current UTC time
// formatted with ISO8601.
func (ss *Sqlite) AddRule(rr RuleRow) error {
	query := `INSERT INTO rules (name, expression, answerType, answerValue, ttl, weight, enabled, created)
VALUES (?, ? , ?, ?, ?, ?, 1, ?);`
	created := time.Now().UTC().Format(ISO8601)
	if _, err := ss.Exec(query, rr.Name, rr.Value, rr.Answer.Type, rr.Answer.Value, rr.TTL, rr.Weight, created); err != nil {
		return fmt.Errorf("could not insert rule: %w", err)
	}
	return nil
}
// GetRule loads the rule with the given id.
//
// A Scan failure is returned unwrapped; per database/sql, that includes
// sql.ErrNoRows when no row matches. A created column that does not parse as
// ISO8601 is reported as "could not parse time".
func (ss *Sqlite) GetRule(ruleId int) (RuleRow, error) {
	const selectRule = `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules WHERE id = ?;`
	var (
		rule    RuleRow
		created string
	)
	scanErr := ss.QueryRow(selectRule, ruleId).Scan(
		&rule.ID, &rule.Name, &rule.Value,
		&rule.Answer.Type, &rule.Answer.Value,
		&rule.TTL, &rule.Weight, &rule.Enabled, &created,
	)
	if scanErr != nil {
		return rule, scanErr
	}
	parsed, parseErr := time.Parse(ISO8601, created)
	rule.Created = parsed
	if parseErr != nil {
		return rule, fmt.Errorf("could not parse time: %w", parseErr)
	}
	return rule, nil
}
// DeleteRule removes the rule with the given id. The Exec result is not
// inspected, so deleting an id that does not exist is not an error.
func (ss *Sqlite) DeleteRule(ruleId int) error {
	const deleteRule = `DELETE FROM rules WHERE id = ?;`
	_, err := ss.Exec(deleteRule, ruleId)
	if err == nil {
		return nil
	}
	return fmt.Errorf("could not delete rule: %w", err)
}
// GetRules returns every stored rule, ordered by ascending weight. The created
// column is parsed with ISO8601. When there are no rules it returns an empty,
// non-nil slice.
func (ss *Sqlite) GetRules() ([]RuleRow, error) {
	query := `SELECT id, name, expression, answerType, answerValue, ttl, weight, enabled, created FROM rules ORDER BY weight ASC;`
	rows, err := ss.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	results := []RuleRow{}
	for rows.Next() {
		var rule RuleRow
		var createdTime string
		if err := rows.Scan(
			&rule.ID,
			&rule.Name,
			&rule.Value,
			&rule.Answer.Type,
			&rule.Answer.Value,
			&rule.TTL,
			&rule.Weight,
			&rule.Enabled,
			&createdTime,
		); err != nil {
			return nil, fmt.Errorf("could not read from db: %w", err)
		}
		if rule.Created, err = time.Parse(ISO8601, createdTime); err != nil {
			return nil, fmt.Errorf("could not parse created time %q: %w", createdTime, err)
		}
		results = append(results, rule)
	}
	// rows.Err reports failures hit while iterating, so it is checked after
	// the loop, and its own error is the one returned.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("could not iterate rules: %w", err)
	}
	return results, nil
}
// GetLog returns one page of query-log entries whose start time lies within
// [in.Start, in.End], newest id first. Times are compared with
// strftime('%s', ...), i.e. at whole-second precision.
//
// Defaults: Limit is 25 when <= 0, and End is now when zero. Start is
// time.Hour*-86400 from now when zero.
// NOTE(review): 86400 looks like a seconds-per-day constant, but it is
// multiplied by time.Hour, so the default window is ~3600 days rather than
// one day. Confirm which was intended.
//
// The result echoes the defaulted input. It also carries the total number of
// matching rows and the number of Limit-sized pages; that count is rounded up
// and is at least 1, even when no rows match.
func (ss *Sqlite) GetLog(in GetLogInput) (GetLogResult, error) {
	if in.Limit <= 0 {
		in.Limit = 25
	}
	if in.Start.IsZero() {
		in.Start = time.Now().Add(time.Hour * -86400)
	}
	if in.End.IsZero() {
		in.End = time.Now()
	}
	glr := GetLogResult{
		GetLogInput: in,
		Logs:        []QueryLog{},
	}
	lpi, err := ss.GetPagingInfo(in)
	if err != nil {
		return glr, err
	}
	glr.TotalResults = lpi.Total
	// Round up so an exact multiple of Limit doesn't report an extra, empty
	// page; an empty window still counts as one page.
	pageCount := (lpi.Total + in.Limit - 1) / in.Limit
	if pageCount < 1 {
		pageCount = 1
	}
	glr.PageCount = pageCount
	query := `
SELECT
started, clientIp, protocol, domain, totalTimeMs,
error, recurseRoundTripTimeMs, recurseUpstreamIp, status
FROM (
SELECT id,
started,
clientIp,
protocol,
domain,
totalTimeMs,
error,
recurseRoundTripTimeMs,
recurseUpstreamIp,
status
FROM log
WHERE strftime('%s', started) >= strftime('%s', ?)
AND strftime('%s', started) <= strftime('%s', ?)
ORDER BY started DESC
) WHERE id <= ? ORDER BY id DESC LIMIT ?;
`
	rows, err := ss.DB.Query(query, in.Start.UTC().Format(ISO8601), in.End.UTC().Format(ISO8601), lpi.FirstItemID, in.Limit)
	if err != nil {
		return glr, fmt.Errorf("issue with GetLog sql query: %w", err)
	}
	defer rows.Close()
	for rows.Next() {
		var q QueryLog
		var started string
		if err := rows.Scan(
			&started,
			&q.ClientIP,
			&q.Protocol,
			&q.Domain,
			&q.TotalTimeMs,
			&q.Error,
			&q.RecurseRoundTripTimeMs,
			&q.RecurseUpstreamIP,
			&q.Status,
		); err != nil {
			return glr, fmt.Errorf("issues scanning rows: %w", err)
		}
		if q.Started, err = time.Parse(ISO8601, started); err != nil {
			return glr, fmt.Errorf("could not parse time '%s': %w", started, err)
		}
		glr.Logs = append(glr.Logs, q)
	}
	// rows.Err reports failures hit while iterating, so check it after the
	// loop rather than before it.
	if err := rows.Err(); err != nil {
		return glr, fmt.Errorf("issue with rows object: %w", err)
	}
	return glr, nil
}
// LogPageInfo holds the paging values GetPagingInfo computes for a log query
// window; GetLog consumes them.
type LogPageInfo struct {
// Total is the number of log rows counted in the queried time window.
Total int
// PageCount is Total divided by the page size (integer division).
PageCount int
// FirstItemID is the highest id in the window minus the page offset
// (page size * page number). GetLog uses it as an inclusive upper bound on
// ids for the requested page.
FirstItemID int
}
// GetPagingInfo computes paging values for page in.Page (of size in.Limit)
// over the log rows whose start time lies within [in.Start, in.End]. Times are
// compared at whole-second precision, the same filter GetLog uses.
//
// The returned LogPageInfo holds:
//   - Total: the number of matching rows.
//   - PageCount: Total / in.Limit (integer division; 0 when in.Limit <= 0).
//   - FirstItemID: the highest matching id minus in.Limit*in.Page. The
//     highest id is taken as 0 when nothing matches.
//
// When the page offset exceeds Total, it returns the populated LogPageInfo
// together with a "page number too high" error.
func (ss *Sqlite) GetPagingInfo(in GetLogInput) (LogPageInfo, error) {
	// COALESCE keeps MAX(id) from being NULL on an empty window, which would
	// make Scan into an int fail. The bounds are inclusive to match GetLog.
	query := `
SELECT
COUNT(*) AS totalLogsEntries,
COALESCE(MAX(id), 0) AS maxId
FROM
log
WHERE
strftime('%s', started) >= strftime('%s', ?)
AND strftime('%s', started) <= strftime('%s', ?)
`
	var lpi LogPageInfo
	var maxID int
	row := ss.QueryRow(query, in.Start.UTC().Format(ISO8601), in.End.UTC().Format(ISO8601))
	if err := row.Scan(&lpi.Total, &maxID); err != nil {
		return lpi, fmt.Errorf("could not read log paging info: %w", err)
	}
	pageOffset := in.Limit * in.Page
	// Divide here rather than in SQL: SQLite yields NULL when dividing by a
	// zero Limit, which would fail the Scan.
	if in.Limit > 0 {
		lpi.PageCount = lpi.Total / in.Limit
	}
	lpi.FirstItemID = maxID - pageOffset
	if pageOffset > lpi.Total {
		return lpi, errors.New("page number too high")
	}
	return lpi, nil
}
// initTable creates the tables and indexes the store relies on: log, rules,
// recursors, ruleslist and ruleslist_entry. Every statement uses
// IF NOT EXISTS, so running it against an already-initialized database
// creates nothing new.
//
// NOTE(review): idx_ruleslist_entry_ruleslistId is a UNIQUE index on
// (id, ruleslistId). Since id is already the primary key, the uniqueness adds
// nothing, and a plain index on ruleslistId may have been intended. Confirm
// before changing it: IF NOT EXISTS will not rebuild an index that already
// exists.
func initTable(db *sql.DB) error {
	const schema = `
CREATE TABLE IF NOT EXISTS log (
id INTEGER PRIMARY KEY,
started TEXT NOT NULL,
clientIp TEXT NOT NULL,
protocol TEXT NOT NULL,
domain TEXT NOT NULL,
totalTimeMs int NOT NULL,
error TEXT,
recurseRoundTripTimeMs INT,
recurseUpStreamIP TEXT,
status TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS rules (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
expression TEXT NOT NULL,
answerType TEXT NOT NULL,
answerValue TEXT NOT NULL,
ttl INT NOT NULL,
weight INT NOT NULL,
enabled INT NOT NULL,
created TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_name ON rules (name);
CREATE UNIQUE INDEX IF NOT EXISTS idx_rules_expression ON rules (expression);
CREATE TABLE IF NOT EXISTS recursors (
id INTEGER PRIMARY KEY,
ipAddress TEXT NOT NULL,
timeoutMs INT NOT NULL,
weight INT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_recursors_ipAddress ON recursors (ipAddress);
CREATE TABLE IF NOT EXISTS ruleslist (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
url TEXT NOT NULL,
lastLoadedTs TEXT NOT NULL,
weight INT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_name ON ruleslist (name);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_url ON ruleslist (url);
CREATE TABLE IF NOT EXISTS ruleslist_entry (
id INTEGER PRIMARY KEY,
ruleslistId INTEGER NOT NULL,
expression TEXT NOT NULL,
ipAddress TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_entry_expression ON ruleslist_entry (expression);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ruleslist_entry_ruleslistId ON ruleslist_entry (id, ruleslistId);
`
	_, execErr := db.Exec(schema)
	if execErr == nil {
		return nil
	}
	return fmt.Errorf("could not initialize db: %w", execErr)
}