diff --git a/database/admins.go b/database/admins.go index fd57d46..a06a9bd 100644 --- a/database/admins.go +++ b/database/admins.go @@ -5,49 +5,91 @@ import ( ) type Admin struct { - ID uint64 `gorm:"primaryKey"` - Username string `gorm:"unique;notNull"` - Bcrypt string `gorm:"notNull"` + ID int + Username string + Bcrypt string } -func CountAdmins() int64 { - var admins []Admin +func CountAdmins() (int64, error) { var count int64 - DB.Model(&admins).Count(&count) - return count + query := `SELECT COUNT(*) FROM admins` + if err := DB.QueryRow(query).Scan(&count); err != nil { + return 0, err + } + + return count, nil } -func CreateAdmin(username string, password string) (Admin, error) { +func CreateAdmin(username string, password string) (*Admin, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), 10) if err != nil { - return Admin{}, err + return nil, err } - admin := Admin{ - Username: username, - Bcrypt: string(hash), - } - result := DB.Create(&admin) + query := ` + INSERT INTO admins(username, bcrypt) + VALUES (?, ?) + RETURNING id + ` - if result.Error != nil { - return Admin{}, result.Error - } - - return admin, nil -} - -func AuthorizeAdmin(username string, password string) (Admin, error) { var admin Admin - result := DB.Where("username = ?", username).First(&admin) + admin.Username = username + admin.Bcrypt = string(hash) - if result.Error != nil { - return Admin{}, result.Error - } + err = DB. + QueryRow(query, admin.Username, admin.Bcrypt). + Scan(&admin.ID) - err := bcrypt.CompareHashAndPassword([]byte(admin.Bcrypt), []byte(password)) if err != nil { - return Admin{}, err + return nil, err } - return admin, nil + return &admin, nil +} + +func GetAdminIfPasswordMatches(username string, password string) (*Admin, error) { + query := ` + SELECT id, username, bcrypt + FROM admins + WHERE username = ? + ` + + var admin Admin + err := DB. + QueryRow(query, username). + Scan(&admin.ID, &admin.Username, &admin.Bcrypt) + + if err != nil { + return nil, err + } + + err = bcrypt.CompareHashAndPassword([]byte(admin.Bcrypt), []byte(password)) + if err != nil { + return nil, err + } + + return &admin, nil +} + +func DeleteAdmin(id int) error { + query := ` + DELETE FROM admins + WHERE id = ? + ` + + res, err := DB.Exec(query, id) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected != 1 { + return ErrWrongNumberOfAffectedRows + } + + return nil } diff --git a/database/admins_test.go b/database/admins_test.go new file mode 100644 index 0000000..b5251c4 --- /dev/null +++ b/database/admins_test.go @@ -0,0 +1,58 @@ +package database + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func TestAdmins(t *testing.T) { + initTestDatabase(t) + defer deleteTestDatabase(t) + + // We should have no admins. + count, err := CountAdmins() + if err != nil { + t.Fatal(err) + } + + if count != 0 { + t.Fatal("user count is not zero") + } + + // Create the first user. + username := "test" + password := "test" + admin, err := CreateAdmin(username, password) + if err != nil { + t.Fatal(err) + } + + // Check password and get admin. + dbAdmin, err := GetAdminIfPasswordMatches(username, password) + if err != nil { + t.Fatal(err) + } + if dbAdmin.ID != admin.ID { + t.Fatal("wrong admin id") + } + + // Check wrong password handling. + if _, err := GetAdminIfPasswordMatches("test", "wrong-password"); err == nil { + t.Fatal("wrong password was accepted") + } + + // Count users again. + count, err = CountAdmins() + if err != nil { + t.Fatal(err) + } + + if count != 1 { + t.Fatal("user count is not one") + } + + // Delete user. + if err := DeleteAdmin(admin.ID); err != nil { + t.Fatal(err) + } +} diff --git a/database/connection.go b/database/connection.go new file mode 100644 index 0000000..76ce278 --- /dev/null +++ b/database/connection.go @@ -0,0 +1,20 @@ +package database + +import ( + "database/sql" + _ "github.com/mattn/go-sqlite3" + "github.com/ordinary-dev/phoenix/config" +) + +var DB *sql.DB + +func EstablishDatabaseConnection(cfg *config.Config) error { + var err error + DB, err = sql.Open("sqlite3", cfg.DBPath) + + if err := DB.Ping(); err != nil { + return err + } + + return err +} diff --git a/database/connection_test.go b/database/connection_test.go new file mode 100644 index 0000000..a361def --- /dev/null +++ b/database/connection_test.go @@ -0,0 +1,28 @@ +package database + +import ( + "database/sql" + _ "github.com/mattn/go-sqlite3" + "os" + "testing" +) + +const TEST_DB_PATH = "/tmp/phoenix.sqlite3" + +func initTestDatabase(t *testing.T) { + var err error + DB, err = sql.Open("sqlite3", TEST_DB_PATH) + if err != nil { + t.Fatal(err) + } + + if err := ApplyMigrations(); err != nil { + t.Fatal(err) + } +} + +func deleteTestDatabase(t *testing.T) { + if err := os.Remove(TEST_DB_PATH); err != nil { + t.Fatal(err) + } +} diff --git a/database/db.go b/database/db.go deleted file mode 100644 index 22e551d..0000000 --- a/database/db.go +++ /dev/null @@ -1,22 +0,0 @@ -package database - -import ( - "github.com/ordinary-dev/phoenix/config" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -var DB *gorm.DB - -func EstablishDatabaseConnection(cfg *config.Config) error { - var err error - DB, err = gorm.Open(sqlite.Open(cfg.DBPath), &gorm.Config{}) - if err != nil { - return err - } - - // Migrate the schema - DB.AutoMigrate(&Admin{}, &Group{}, &Link{}) - - return nil -} diff --git a/database/errors.go b/database/errors.go new file mode 100644 index 0000000..e1e9715 --- /dev/null +++ b/database/errors.go @@ -0,0 +1,9 @@ +package database + +import ( + "errors" +) + +var ( + ErrWrongNumberOfAffectedRows = errors.New("wrong number of affected rows") +) diff --git a/database/groups.go b/database/groups.go index cdeae45..3f744ea 100644 --- a/database/groups.go +++ b/database/groups.go @@ -1,7 +1,106 @@ package database type Group struct { - ID uint64 `gorm:"primaryKey"` - Name string `gorm:"unique,notNull"` - Links []Link `gorm:"constraint:OnDelete:CASCADE;"` + ID int + Name string + Links []Link +} + +func GetGroupsWithLinks() ([]Group, error) { + query := ` + SELECT id, name + FROM groups + ORDER BY groups.id + ` + + rows, err := DB.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + var groups []Group + for rows.Next() { + var group Group + if err := rows.Scan(&group.ID, &group.Name); err != nil { + return nil, err + } + groups = append(groups, group) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + for i := range groups { + groups[i].Links, err = GetLinksFromGroup(groups[i].ID) + if err != nil { + return nil, err + } + } + + return groups, nil +} + +// Create a new group in the database. +// The function fills in the ID. +func CreateGroup(group *Group) error { + query := ` + INSERT INTO groups (name) + VALUES (?) + RETURNING id + ` + + if err := DB.QueryRow(query, group.Name).Scan(&group.ID); err != nil { + return err + } + + return nil +} + +func UpdateGroup(id int, name string) error { + query := ` + UPDATE groups + SET name = ? + WHERE id = ? + ` + + res, err := DB.Exec(query, name, id) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected != 1 { + return ErrWrongNumberOfAffectedRows + } + + return nil +} + +func DeleteGroup(groupID int) error { + query := ` + DELETE FROM groups + WHERE id = ? + ` + + res, err := DB.Exec(query, groupID) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected != 1 { + return ErrWrongNumberOfAffectedRows + } + + return nil } diff --git a/database/groups_test.go b/database/groups_test.go new file mode 100644 index 0000000..a280a59 --- /dev/null +++ b/database/groups_test.go @@ -0,0 +1,46 @@ +package database + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func TestGroups(t *testing.T) { + initTestDatabase(t) + defer deleteTestDatabase(t) + + // Create the first group. + group := Group{ + Name: "test", + } + if err := CreateGroup(&group); err != nil { + t.Fatal(err) + } + if group.ID == 0 { + t.Fatal("group id is zero") + } + + // Update group. + if err := UpdateGroup(group.ID, "new-name"); err != nil { + t.Fatal(err) + } + + // Read groups. + groupList, err := GetGroupsWithLinks() + if err != nil { + t.Fatal(err) + } + + if len(groupList) != 1 { + t.Fatal("group list length is not one") + } + + if groupList[0].Name != "new-name" { + t.Fatal("wrong group name") + } + + // Delete group. + if err := DeleteGroup(group.ID); err != nil { + t.Fatal(err) + } +} diff --git a/database/links.go b/database/links.go index 3b44042..aed805c 100644 --- a/database/links.go +++ b/database/links.go @@ -1,9 +1,124 @@ package database type Link struct { - ID uint64 `gorm:"primaryKey"` - Name string `gorm:"notNull"` - Href string `gorm:"notNull"` - GroupID uint64 `gorm:"notNull"` + ID int + Name string + Href string + GroupID int Icon *string } + +func GetLinksFromGroup(groupID int) ([]Link, error) { + query := ` + SELECT id, name, href, group_id, icon + FROM links + WHERE group_id = ? + ORDER BY id + ` + + rows, err := DB.Query(query, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + + var links []Link + for rows.Next() { + var link Link + if err := rows.Scan(&link.ID, &link.Name, &link.Href, &link.GroupID, &link.Icon); err != nil { + return nil, err + } + links = append(links, link) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return links, nil +} + +func GetLink(id int) (*Link, error) { + query := ` + SELECT id, name, href, group_id, icon + FROM links + WHERE id = ? + ` + + var link Link + err := DB. + QueryRow(query, id). + Scan(&link.ID, &link.Name, &link.Href, &link.GroupID, &link.Icon) + if err != nil { + return nil, err + } + + return &link, nil +} + +// Create a new link in the database. +// The function fills in the ID. +func CreateLink(link *Link) error { + query := ` + INSERT INTO links (name, href, group_id, icon) + VALUES (?, ?, ?, ?) + RETURNING id + ` + + err := DB. + QueryRow(query, link.Name, link.Href, link.GroupID, link.Icon). + Scan(&link.ID) + + if err != nil { + return err + } + + return nil +} + +func UpdateLink(link *Link) error { + query := ` + UPDATE links + SET name = ?, href = ?, group_id = ?, icon = ? + WHERE id = ? + ` + + res, err := DB.Exec(query, link.Name, link.Href, link.GroupID, link.Icon, link.ID) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected != 1 { + return ErrWrongNumberOfAffectedRows + } + + return nil +} + +func DeleteLink(linkID int) error { + query := ` + DELETE FROM links + WHERE id = ? + ` + + res, err := DB.Exec(query, linkID) + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + + if rowsAffected != 1 { + return ErrWrongNumberOfAffectedRows + } + + return nil +} diff --git a/database/links_test.go b/database/links_test.go new file mode 100644 index 0000000..e1c0a21 --- /dev/null +++ b/database/links_test.go @@ -0,0 +1,50 @@ +package database + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func TestLinks(t *testing.T) { + initTestDatabase(t) + defer deleteTestDatabase(t) + + // Create the first group. + group := Group{ + Name: "test", + } + if err := CreateGroup(&group); err != nil { + t.Fatal(err) + } + + // Create the first link. + icon := "test/icon" + link := Link{ + Name: "test", + Href: "/test", + GroupID: group.ID, + Icon: &icon, + } + if err := CreateLink(&link); err != nil { + t.Fatal(err) + } + if link.ID == 0 { + t.Fatal("link id is zero") + } + + // Update link. + link.Href = "/new-href" + if err := UpdateLink(&link); err != nil { + t.Fatal(err) + } + + // Delete link. + if err := DeleteLink(link.ID); err != nil { + t.Fatal(err) + } + + // Delete group. + if err := DeleteGroup(group.ID); err != nil { + t.Fatal(err) + } +} diff --git a/database/migrations.go b/database/migrations.go new file mode 100644 index 0000000..e1854e0 --- /dev/null +++ b/database/migrations.go @@ -0,0 +1,98 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + + log "github.com/sirupsen/logrus" +) + +// List of migrations that should be applied. +// Migration ID = index + 1. +var migrations = []string{ + `CREATE TABLE IF NOT EXISTS admins ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + bcrypt TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS groups ( + id INTEGER PRIMARY KEY, + name TEXT + )`, + `CREATE TABLE IF NOT EXISTS links ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + href TEXT NOT NULL, + group_id INTEGER NOT NULL, + icon TEXT, + CONSTRAINT fk_groups_links + FOREIGN KEY (group_id) + REFERENCES groups(id) + ON DELETE CASCADE + )`, +} + +func ApplyMigrations() error { + // Create a table to record applied migrations and retrieve the saved data. + _, err := DB.Exec(`CREATE TABLE IF NOT EXISTS migrations ( + version INTEGER NOT NULL DEFAULT 0 + )`) + if err != nil { + return err + } + + var currentVersion int + err = DB. + QueryRow("SELECT version FROM migrations"). + Scan(¤tVersion) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return err + } + + // The table is empty, create a record. + _, err = DB.Exec("INSERT INTO migrations (version) VALUES (0)") + if err != nil { + return err + } + } + + // Apply all migrations. + for i, migration := range migrations { + migrationID := i + 1 + if migrationID <= currentVersion { + continue + } + + if err := applyMigration(migrationID, migration); err != nil { + return fmt.Errorf("migration #%d: %w", migrationID, err) + } + + log.Infof("Migration #%v has been applied", migrationID) + } + + return nil +} + +func applyMigration(migrationID int, query string) error { + tx, err := DB.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec(query); err != nil { + return fmt.Errorf("error when applying migration: %w", err) + } + + if _, err := tx.Exec("UPDATE migrations SET version = ?", migrationID); err != nil { + return fmt.Errorf("error when updating schema version: %w", err) + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/database/migrations_test.go b/database/migrations_test.go new file mode 100644 index 0000000..411f21d --- /dev/null +++ b/database/migrations_test.go @@ -0,0 +1,16 @@ +package database + +import ( + _ "github.com/mattn/go-sqlite3" + "testing" +) + +func TestMigrations(t *testing.T) { + initTestDatabase(t) + defer deleteTestDatabase(t) + + // We should be able to call the function multiple times. + if err := ApplyMigrations(); err != nil { + t.Fatal(err) + } +} diff --git a/go.mod b/go.mod index 9f2208b..ba7560b 100644 --- a/go.mod +++ b/go.mod @@ -6,16 +6,12 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/joho/godotenv v1.5.1 github.com/kelseyhightower/envconfig v1.4.0 + github.com/mattn/go-sqlite3 v1.14.22 github.com/sirupsen/logrus v1.9.3 golang.org/x/crypto v0.21.0 - gorm.io/driver/sqlite v1.5.5 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde ) require ( - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/stretchr/testify v1.8.3 // indirect golang.org/x/sys v0.18.0 // indirect ) diff --git a/go.sum b/go.sum index 8e9139c..d4a3009 100644 --- a/go.sum +++ b/go.sum @@ -3,16 +3,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= -github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= -github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= -github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -30,7 +26,3 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= -gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/main.go b/main.go index 4a1be63..fa55756 100644 --- a/main.go +++ b/main.go @@ -4,43 +4,55 @@ import ( "github.com/ordinary-dev/phoenix/config" "github.com/ordinary-dev/phoenix/database" "github.com/ordinary-dev/phoenix/views" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) func main() { // Configure logger - logrus.SetFormatter(&logrus.TextFormatter{ + log.SetFormatter(&log.TextFormatter{ FullTimestamp: true, }) // Read config cfg, err := config.GetConfig() if err != nil { - logrus.Fatalf("%v", err) + log.Fatal(err) } // Set log level logLevel := cfg.GetLogLevel() - logrus.SetLevel(logLevel) - logrus.Infof("Setting log level to %v", logLevel) + log.SetLevel(logLevel) + log.Infof("Setting log level to %v", logLevel) // Connect to the database err = database.EstablishDatabaseConnection(cfg) if err != nil { - logrus.Fatalf("%v", err) + log.Fatal(err) + } + + // Apply migrations. + if err := database.ApplyMigrations(); err != nil { + log.Fatal(err) } // Create the first user - if cfg.DefaultUsername != "" && cfg.DefaultPassword != "" && database.CountAdmins() < 1 { - _, err := database.CreateAdmin(cfg.DefaultUsername, cfg.DefaultPassword) + if cfg.DefaultUsername != "" && cfg.DefaultPassword != "" { + adminCount, err := database.CountAdmins() if err != nil { - logrus.Errorf("%v", err) + log.Fatal(err) + } + + if adminCount < 1 { + _, err := database.CreateAdmin(cfg.DefaultUsername, cfg.DefaultPassword) + if err != nil { + log.Fatal(err) + } } } server, err := views.GetHttpServer() if err != nil { - logrus.Fatal(err) + log.Fatal(err) } server.ListenAndServe() diff --git a/views/middleware/auth.go b/views/middleware/auth.go index 8f586f3..1189e04 100644 --- a/views/middleware/auth.go +++ b/views/middleware/auth.go @@ -57,7 +57,13 @@ func RequireAuth(next http.Handler) http.Handler { // Most likely the user is not authorized. if err != nil { - if database.CountAdmins() < 1 { + count, err := database.CountAdmins() + if err != nil { + pages.ShowError(w, http.StatusInternalServerError, err) + return + } + + if count < 1 { http.Redirect(w, r, "/registration", http.StatusFound) } else { http.Redirect(w, r, "/signin", http.StatusFound) diff --git a/views/pages/groups.go b/views/pages/groups.go index fa330fa..dbe4260 100644 --- a/views/pages/groups.go +++ b/views/pages/groups.go @@ -14,8 +14,8 @@ func CreateGroup(w http.ResponseWriter, r *http.Request) { Name: r.FormValue("groupName"), } - if result := database.DB.Create(&group); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + if err := database.CreateGroup(&group); err != nil { + ShowError(w, http.StatusInternalServerError, err) return } @@ -24,37 +24,30 @@ func CreateGroup(w http.ResponseWriter, r *http.Request) { } func UpdateGroup(w http.ResponseWriter, r *http.Request) { - id, err := strconv.ParseUint(r.PathValue("id"), 10, 64) + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) if err != nil { ShowError(w, http.StatusBadRequest, err) return } - var group database.Group - if result := database.DB.First(&group, id); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) - return - } - - group.Name = r.FormValue("groupName") - if result := database.DB.Save(&group); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + if err := database.UpdateGroup(int(id), r.FormValue("groupName")); err != nil { + ShowError(w, http.StatusInternalServerError, err) return } // This page is called from the settings, return the user back. - http.Redirect(w, r, fmt.Sprintf("/settings#group-%v", group.ID), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("/settings#group-%v", id), http.StatusFound) } func DeleteGroup(w http.ResponseWriter, r *http.Request) { - id, err := strconv.ParseUint(r.PathValue("id"), 10, 64) + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) if err != nil { ShowError(w, http.StatusBadRequest, err) return } - if result := database.DB.Delete(&database.Group{}, id); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + if err := database.DeleteGroup(int(id)); err != nil { + ShowError(w, http.StatusInternalServerError, err) return } diff --git a/views/pages/index.go b/views/pages/index.go index 5f36973..29ff2f0 100644 --- a/views/pages/index.go +++ b/views/pages/index.go @@ -8,19 +8,13 @@ import ( ) func ShowMainPage(w http.ResponseWriter, _ *http.Request) { - // Get a list of groups with links - var groups []database.Group - result := database.DB. - Model(&database.Group{}). - Preload("Links"). - Find(&groups) - - if result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + groups, err := database.GetGroupsWithLinks() + if err != nil { + ShowError(w, http.StatusInternalServerError, err) return } - err := Render("index.html.tmpl", w, map[string]any{ + err = Render("index.html.tmpl", w, map[string]any{ "description": "Self-hosted start page.", "groups": groups, }) diff --git a/views/pages/links.go b/views/pages/links.go index d40ad37..0a5f9c6 100644 --- a/views/pages/links.go +++ b/views/pages/links.go @@ -9,7 +9,7 @@ import ( ) func CreateLink(w http.ResponseWriter, r *http.Request) { - groupID, err := strconv.ParseUint(r.FormValue("groupID"), 10, 32) + groupID, err := strconv.Atoi(r.FormValue("groupID")) if err != nil { ShowError(w, http.StatusBadRequest, err) return @@ -26,8 +26,8 @@ func CreateLink(w http.ResponseWriter, r *http.Request) { } else { link.Icon = &icon } - if result := database.DB.Create(&link); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + if err := database.CreateLink(&link); err != nil { + ShowError(w, http.StatusInternalServerError, err) return } @@ -36,15 +36,15 @@ func CreateLink(w http.ResponseWriter, r *http.Request) { } func UpdateLink(w http.ResponseWriter, r *http.Request) { - id, err := strconv.ParseUint(r.PathValue("id"), 10, 64) + id, err := strconv.Atoi(r.PathValue("id")) if err != nil { ShowError(w, http.StatusBadRequest, err) return } - var link database.Link - if result := database.DB.First(&link, id); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + link, err := database.GetLink(id) + if err != nil { + ShowError(w, http.StatusInternalServerError, err) return } @@ -56,8 +56,9 @@ func UpdateLink(w http.ResponseWriter, r *http.Request) { } else { link.Icon = &icon } - if result := database.DB.Save(&link); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + + if err := database.UpdateLink(link); err != nil { + ShowError(w, http.StatusInternalServerError, err) return } @@ -66,14 +67,14 @@ func UpdateLink(w http.ResponseWriter, r *http.Request) { } func DeleteLink(w http.ResponseWriter, r *http.Request) { - id, err := strconv.ParseUint(r.PathValue("id"), 10, 64) + id, err := strconv.Atoi(r.PathValue("id")) if err != nil { ShowError(w, http.StatusBadRequest, err) return } - if result := database.DB.Delete(&database.Link{}, id); result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + if err := database.DeleteLink(id); err != nil { + ShowError(w, http.StatusInternalServerError, err) return } diff --git a/views/pages/registration.go b/views/pages/registration.go index 0dcb34a..2182aea 100644 --- a/views/pages/registration.go +++ b/views/pages/registration.go @@ -9,7 +9,13 @@ import ( ) func ShowRegistrationForm(w http.ResponseWriter, _ *http.Request) { - if database.CountAdmins() > 0 { + userCount, err := database.CountAdmins() + if err != nil { + ShowError(w, http.StatusInternalServerError, err) + return + } + + if userCount > 0 { ShowError(w, http.StatusBadRequest, errors.New("at least 1 user already exists")) return } @@ -23,7 +29,13 @@ func ShowRegistrationForm(w http.ResponseWriter, _ *http.Request) { } func CreateUser(w http.ResponseWriter, r *http.Request) { - if database.CountAdmins() > 0 { + userCount, err := database.CountAdmins() + if err != nil { + ShowError(w, http.StatusInternalServerError, err) + return + } + + if userCount > 0 { ShowError(w, http.StatusBadRequest, errors.New("at least 1 user already exists")) return } @@ -31,7 +43,7 @@ func CreateUser(w http.ResponseWriter, r *http.Request) { // Try to create a user. username := r.FormValue("username") password := r.FormValue("password") - _, err := database.CreateAdmin(username, password) + _, err = database.CreateAdmin(username, password) if err != nil { ShowError(w, http.StatusInternalServerError, err) return diff --git a/views/pages/settings.go b/views/pages/settings.go index 4fd3d5e..7eefd06 100644 --- a/views/pages/settings.go +++ b/views/pages/settings.go @@ -7,15 +7,9 @@ import ( ) func ShowSettings(w http.ResponseWriter, _ *http.Request) { - // Get a list of groups with links - var groups []database.Group - result := database.DB. - Model(&database.Group{}). - Preload("Links"). - Find(&groups) - - if result.Error != nil { - ShowError(w, http.StatusInternalServerError, result.Error) + groups, err := database.GetGroupsWithLinks() + if err != nil { + ShowError(w, http.StatusInternalServerError, err) return } diff --git a/views/pages/signin.go b/views/pages/signin.go index 2190d71..c848bf7 100644 --- a/views/pages/signin.go +++ b/views/pages/signin.go @@ -25,7 +25,7 @@ func AuthorizeUser(w http.ResponseWriter, r *http.Request) { // Check credentials. username := r.FormValue("username") password := r.FormValue("password") - _, err := database.AuthorizeAdmin(username, password) + _, err := database.GetAdminIfPasswordMatches(username, password) if err != nil { ShowError(w, http.StatusUnauthorized, err) return