diff --git a/api/internal/handlers/stats.go b/api/internal/handlers/stats.go index 247c989..80c6d3e 100644 --- a/api/internal/handlers/stats.go +++ b/api/internal/handlers/stats.go @@ -19,14 +19,14 @@ func (s Stats) Subscribe(c chi.Router) error { r.Use(services.NewJWTVerifier()) r.Use(jwtauth.Authenticator) - r.Put("/stats/weight", s.logWeight) - r.Get("/stats/weight", s.getWeightLog) + r.Put("/stats", s.putLog) + r.Get("/stats", s.getLog) }) return nil } -func (s Stats) getWeightLog(res http.ResponseWriter, req *http.Request) { +func (s Stats) getLog(res http.ResponseWriter, req *http.Request) { fail := APIResp{Payload: "Unauthorized", Status: http.StatusUnauthorized} userId, err := services.GetUserIDFromClaims(req) if err != nil { @@ -35,8 +35,10 @@ func (s Stats) getWeightLog(res http.ResponseWriter, req *http.Request) { return } - startStr := req.URL.Query().Get("start") - endStr := req.URL.Query().Get("end") + queryStrs := req.URL.Query() + startStr := queryStrs.Get("start") + endStr := queryStrs.Get("end") + typeStr := queryStrs.Get("type") start := time.Now().UTC().Add(-time.Hour * 24 * 30) @@ -49,7 +51,7 @@ func (s Stats) getWeightLog(res http.ResponseWriter, req *http.Request) { end, _ = time.Parse(time.RFC3339, endStr) } - wlog, err := s.Conn.GetWeightLog(userId, start, end) + wlog, err := s.Conn.GetLog(userId, start, end, services.ParseLogTypes(typeStr)) if err != nil { log.Println(err) fail.Write(res) @@ -59,10 +61,10 @@ func (s Stats) getWeightLog(res http.ResponseWriter, req *http.Request) { APIResp{ Success: true, Payload: struct { - Log []services.WeightLog `json:"log"` - Count int `json:"count"` - Start time.Time `json:"start"` - End time.Time `json:"end"` + Log []services.StatLog `json:"log"` + Count int `json:"count"` + Start time.Time `json:"start"` + End time.Time `json:"end"` }{ Log: wlog, Count: len(wlog), @@ -73,7 +75,7 @@ func (s Stats) getWeightLog(res http.ResponseWriter, req *http.Request) { } -func (s Stats) logWeight(res http.ResponseWriter, req *http.Request) { +func (s Stats) putLog(res http.ResponseWriter, req *http.Request) { fail := APIResp{Payload: "Unauthorized", Status: http.StatusUnauthorized} userId, err := services.GetUserIDFromClaims(req) if err != nil { @@ -83,7 +85,7 @@ func (s Stats) logWeight(res http.ResponseWriter, req *http.Request) { } type logWeightInput struct { - services.WeightLog + services.StatLog } var in logWeightInput @@ -95,7 +97,7 @@ func (s Stats) logWeight(res http.ResponseWriter, req *http.Request) { return } - if err := s.Conn.LogWeight(userId, in.WeightLog); err != nil { + if err := s.Conn.AddLog(userId, in.StatLog); err != nil { log.Println(err) fail.Write(res) return diff --git a/api/internal/services/db.go b/api/internal/services/db.go index bea2136..241223b 100644 --- a/api/internal/services/db.go +++ b/api/internal/services/db.go @@ -7,6 +7,7 @@ import ( "log" "time" + "github.com/lib/pq" _ "github.com/lib/pq" ) @@ -14,9 +15,10 @@ var dbConn *sql.DB type DB interface { // VerifyEmail(email string) error + // ResetPassword(email string) (string, error) // RegisterUser(email, username, password string) (User, error) - GetWeightLog(int64, time.Time, time.Time) ([]WeightLog, error) - LogWeight(int64, WeightLog) error + GetLog(int64, time.Time, time.Time, []LogType) ([]StatLog, error) + AddLog(int64, StatLog) error UpdateProfile(int64, User) error GetProfileByID(int64) (User, error) GetProfileByEmail(string) (User, error) @@ -38,44 +40,45 @@ func Connect(host, user, password string) (DB, error) { type pgdb struct{ *sql.DB } -func (db pgdb) GetWeightLog(id int64, start, end time.Time) ([]WeightLog, error) { +func (db pgdb) GetLog(id int64, start, end time.Time, lt []LogType) ([]StatLog, error) { - getWeightLogSql := ` + getLogSql := ` SELECT - value, recordedTs - FROM stats.weightlog + value, logType, recordedTs + FROM stats.log WHERE userId = $1 AND recordedTs >= $2 AND recordedTs <= $3 + AND logType = ANY($4) ORDER BY recordedTs DESC;` - var results []WeightLog + var results []StatLog - rows, err := db.Query(getWeightLogSql, id, start, end) + rows, err := db.Query(getLogSql, id, start, end, pq.Array(lt)) if err != nil { return nil, err } + defer rows.Close() if err := rows.Err(); err != nil { return nil, err } for rows.Next() { - var wl WeightLog - if err := rows.Scan(&wl.Value, &wl.RecordedTS); err != nil { - log.Printf("could no scane val for weightlog: %v", err) + var sl StatLog + if err := rows.Scan(&sl.Value, &sl.Type, &sl.RecordedTS); err != nil { + log.Printf("could not scan value for statlog: %v", err) } - results = append(results, wl) - } - defer rows.Close() + results = append(results, sl) + } return results, nil } -func (db pgdb) LogWeight(id int64, w WeightLog) error { - logWeightSql := ` - INSERT INTO stats.weightlog (userId, value, recordedTs) - VALUES ($1, $2, $3); +func (db pgdb) AddLog(id int64, sl StatLog) error { + statLogInsertSql := ` + INSERT INTO stats.log (userId, value, logType, recordedTs) + VALUES ($1, $2, $3, $4); ` tx, err := db.BeginTx(context.Background(), nil) @@ -84,7 +87,7 @@ func (db pgdb) LogWeight(id int64, w WeightLog) error { } defer tx.Commit() - if _, err := tx.Exec(logWeightSql, id, w.Value, w.RecordedTS); err != nil { + if _, err := tx.Exec(statLogInsertSql, id, sl.Value, sl.Type, sl.RecordedTS); err != nil { tx.Rollback() return err } diff --git a/api/internal/services/stats.go b/api/internal/services/stats.go index 5bd7d90..48f7094 100644 --- a/api/internal/services/stats.go +++ b/api/internal/services/stats.go @@ -1,12 +1,44 @@ package services import ( + "strings" "time" ) -type WeightLog struct { - ID uint64 `json:"id"` - UserID uint64 `json:"userId"` +var ( + AllLogTypes = []LogType{ + LogType("Weight"), + LogType("Height"), + LogType("Calories"), + } +) + +type LogType string + +func ParseLogTypes(lts string) []LogType { + if lts == "" { + return AllLogTypes + } + + ltsArr := []LogType{} + for _, lt := range strings.Split(lts, ",") { + lt_normalized := strings.Title(strings.ToLower(lt)) + ltsArr = append(ltsArr, LogType(lt_normalized)) + } + + return ltsArr +} + +const ( + Weight = LogType("Weight") + Height = LogType("Height") + Calories = LogType("Calories") +) + +type StatLog struct { + ID uint64 `json:"id,omitempty"` + UserID uint64 `json:"userId,omitempty"` Value float64 `json:"value"` + Type LogType `json:"type"` RecordedTS time.Time `json:"recordedTs"` } diff --git a/api/main.go b/api/main.go index 42a6a5c..00e66a6 100644 --- a/api/main.go +++ b/api/main.go @@ -27,6 +27,8 @@ func main() { r := chi.NewRouter() r.Use(middleware.Logger) + r.Use(middleware.Recoverer) + db, err := services.Connect(cfg.PG_Host, cfg.PG_Username, cfg.PG_Password) if err != nil { log.Fatal(err)