chore: migrate to database/sql

Since I started simplifying, I decided to abandon ORM.

I won’t say that this makes much sense, everything works more or less as before.
Except that the size of the program has decreased slightly again, by about a megabyte.
This commit is contained in:
Ivan R. 2024-03-26 00:40:52 +05:00
parent 66701b2687
commit 5aa2cee5b1
No known key found for this signature in database
GPG key ID: 56C7BAAE859B302C
22 changed files with 694 additions and 135 deletions

View file

@ -5,49 +5,91 @@ import (
)
type Admin struct {
ID uint64 `gorm:"primaryKey"`
Username string `gorm:"unique;notNull"`
Bcrypt string `gorm:"notNull"`
ID int
Username string
Bcrypt string
}
func CountAdmins() int64 {
var admins []Admin
func CountAdmins() (int64, error) {
var count int64
DB.Model(&admins).Count(&count)
return count
query := `SELECT COUNT(*) FROM admins`
if err := DB.QueryRow(query).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
func CreateAdmin(username string, password string) (Admin, error) {
func CreateAdmin(username string, password string) (*Admin, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), 10)
if err != nil {
return Admin{}, err
return nil, err
}
admin := Admin{
Username: username,
Bcrypt: string(hash),
}
result := DB.Create(&admin)
query := `
INSERT INTO admins(username, bcrypt)
VALUES (?, ?)
RETURNING id
`
if result.Error != nil {
return Admin{}, result.Error
}
return admin, nil
}
func AuthorizeAdmin(username string, password string) (Admin, error) {
var admin Admin
result := DB.Where("username = ?", username).First(&admin)
admin.Username = username
admin.Bcrypt = string(hash)
if result.Error != nil {
return Admin{}, result.Error
}
err = DB.
QueryRow(query, admin.Username, admin.Bcrypt).
Scan(&admin.ID)
err := bcrypt.CompareHashAndPassword([]byte(admin.Bcrypt), []byte(password))
if err != nil {
return Admin{}, err
return nil, err
}
return admin, nil
return &admin, nil
}
func GetAdminIfPasswordMatches(username string, password string) (*Admin, error) {
query := `
SELECT id, username, bcrypt
FROM admins
WHERE username = ?
`
var admin Admin
err := DB.
QueryRow(query, username).
Scan(&admin.ID, &admin.Username, &admin.Bcrypt)
if err != nil {
return nil, err
}
err = bcrypt.CompareHashAndPassword([]byte(admin.Bcrypt), []byte(password))
if err != nil {
return nil, err
}
return &admin, nil
}
func DeleteAdmin(id int) error {
query := `
DELETE FROM admins
WHERE id = ?
`
res, err := DB.Exec(query, id)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return ErrWrongNumberOfAffectedRows
}
return nil
}

58
database/admins_test.go Normal file
View file

@ -0,0 +1,58 @@
package database
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func TestAdmins(t *testing.T) {
initTestDatabase(t)
defer deleteTestDatabase(t)
// We should have no admins.
count, err := CountAdmins()
if err != nil {
t.Fatal(err)
}
if count != 0 {
t.Fatal("user count is not zero")
}
// Create the first user.
username := "test"
password := "test"
admin, err := CreateAdmin(username, password)
if err != nil {
t.Fatal(err)
}
// Check password and get admin.
dbAdmin, err := GetAdminIfPasswordMatches(username, password)
if err != nil {
t.Fatal(err)
}
if dbAdmin.ID != admin.ID {
t.Fatal("wrong admin id")
}
// Check wrong password handling.
if _, err := GetAdminIfPasswordMatches("test", "wrong-password"); err == nil {
t.Fatal("wrong password was accepted")
}
// Count users again.
count, err = CountAdmins()
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatal("user count is not one")
}
// Delete user.
if err := DeleteAdmin(admin.ID); err != nil {
t.Fatal(err)
}
}

20
database/connection.go Normal file
View file

@ -0,0 +1,20 @@
package database
import (
"database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/ordinary-dev/phoenix/config"
)
var DB *sql.DB
func EstablishDatabaseConnection(cfg *config.Config) error {
var err error
DB, err = sql.Open("sqlite3", cfg.DBPath)
if err := DB.Ping(); err != nil {
return err
}
return err
}

View file

@ -0,0 +1,28 @@
package database
import (
"database/sql"
_ "github.com/mattn/go-sqlite3"
"os"
"testing"
)
const TEST_DB_PATH = "/tmp/phoenix.sqlite3"
func initTestDatabase(t *testing.T) {
var err error
DB, err = sql.Open("sqlite3", TEST_DB_PATH)
if err != nil {
t.Fatal(err)
}
if err := ApplyMigrations(); err != nil {
t.Fatal(err)
}
}
func deleteTestDatabase(t *testing.T) {
if err := os.Remove(TEST_DB_PATH); err != nil {
t.Fatal(err)
}
}

View file

@ -1,22 +0,0 @@
package database
import (
"github.com/ordinary-dev/phoenix/config"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var DB *gorm.DB
func EstablishDatabaseConnection(cfg *config.Config) error {
var err error
DB, err = gorm.Open(sqlite.Open(cfg.DBPath), &gorm.Config{})
if err != nil {
return err
}
// Migrate the schema
DB.AutoMigrate(&Admin{}, &Group{}, &Link{})
return nil
}

9
database/errors.go Normal file
View file

@ -0,0 +1,9 @@
package database
import (
"errors"
)
var (
ErrWrongNumberOfAffectedRows = errors.New("wrong number of affected rows")
)

View file

@ -1,7 +1,106 @@
package database
type Group struct {
ID uint64 `gorm:"primaryKey"`
Name string `gorm:"unique,notNull"`
Links []Link `gorm:"constraint:OnDelete:CASCADE;"`
ID int
Name string
Links []Link
}
func GetGroupsWithLinks() ([]Group, error) {
query := `
SELECT id, name
FROM groups
ORDER BY groups.id
`
rows, err := DB.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
var groups []Group
for rows.Next() {
var group Group
if err := rows.Scan(&group.ID, &group.Name); err != nil {
return nil, err
}
groups = append(groups, group)
}
if err := rows.Err(); err != nil {
return nil, err
}
for i := range groups {
groups[i].Links, err = GetLinksFromGroup(groups[i].ID)
if err != nil {
return nil, err
}
}
return groups, nil
}
// Create a new group in the database.
// The function fills in the ID.
func CreateGroup(group *Group) error {
query := `
INSERT INTO groups (name)
VALUES (?)
RETURNING id
`
if err := DB.QueryRow(query, group.Name).Scan(&group.ID); err != nil {
return err
}
return nil
}
func UpdateGroup(id int, name string) error {
query := `
UPDATE groups
SET name = ?
WHERE id = ?
`
res, err := DB.Exec(query, name, id)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return ErrWrongNumberOfAffectedRows
}
return nil
}
func DeleteGroup(groupID int) error {
query := `
DELETE FROM groups
WHERE id = ?
`
res, err := DB.Exec(query, groupID)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return ErrWrongNumberOfAffectedRows
}
return nil
}

46
database/groups_test.go Normal file
View file

@ -0,0 +1,46 @@
package database
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func TestGroups(t *testing.T) {
initTestDatabase(t)
defer deleteTestDatabase(t)
// Create the first group.
group := Group{
Name: "test",
}
if err := CreateGroup(&group); err != nil {
t.Fatal(err)
}
if group.ID == 0 {
t.Fatal("group id is zero")
}
// Update group.
if err := UpdateGroup(group.ID, "new-name"); err != nil {
t.Fatal(err)
}
// Read groups.
groupList, err := GetGroupsWithLinks()
if err != nil {
t.Fatal(err)
}
if len(groupList) != 1 {
t.Fatal("group list length is not one")
}
if groupList[0].Name != "new-name" {
t.Fatal("wrong group name")
}
// Delete group.
if err := DeleteGroup(group.ID); err != nil {
t.Fatal(err)
}
}

View file

@ -1,9 +1,124 @@
package database
type Link struct {
ID uint64 `gorm:"primaryKey"`
Name string `gorm:"notNull"`
Href string `gorm:"notNull"`
GroupID uint64 `gorm:"notNull"`
ID int
Name string
Href string
GroupID int
Icon *string
}
func GetLinksFromGroup(groupID int) ([]Link, error) {
query := `
SELECT id, name, href, group_id, icon
FROM links
WHERE group_id = ?
ORDER BY id
`
rows, err := DB.Query(query, groupID)
if err != nil {
return nil, err
}
defer rows.Close()
var links []Link
for rows.Next() {
var link Link
if err := rows.Scan(&link.ID, &link.Name, &link.Href, &link.GroupID, &link.Icon); err != nil {
return nil, err
}
links = append(links, link)
}
if err := rows.Err(); err != nil {
return nil, err
}
return links, nil
}
func GetLink(id int) (*Link, error) {
query := `
SELECT id, name, href, group_id, icon
FROM links
WHERE id = ?
`
var link Link
err := DB.
QueryRow(query, id).
Scan(&link.ID, &link.Name, &link.Href, &link.GroupID, &link.Icon)
if err != nil {
return nil, err
}
return &link, nil
}
// Create a new link in the database.
// The function fills in the ID.
func CreateLink(link *Link) error {
query := `
INSERT INTO links (name, href, group_id, icon)
VALUES (?, ?, ?, ?)
RETURNING id
`
err := DB.
QueryRow(query, link.Name, link.Href, link.GroupID, link.Icon).
Scan(&link.ID)
if err != nil {
return err
}
return nil
}
func UpdateLink(link *Link) error {
query := `
UPDATE links
SET name = ?, href = ?, group_id = ?, icon = ?
WHERE id = ?
`
res, err := DB.Exec(query, link.Name, link.Href, link.GroupID, link.Icon, link.ID)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return ErrWrongNumberOfAffectedRows
}
return nil
}
func DeleteLink(linkID int) error {
query := `
DELETE FROM links
WHERE id = ?
`
res, err := DB.Exec(query, linkID)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return ErrWrongNumberOfAffectedRows
}
return nil
}

50
database/links_test.go Normal file
View file

@ -0,0 +1,50 @@
package database
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func TestLinks(t *testing.T) {
initTestDatabase(t)
defer deleteTestDatabase(t)
// Create the first group.
group := Group{
Name: "test",
}
if err := CreateGroup(&group); err != nil {
t.Fatal(err)
}
// Create the first link.
icon := "test/icon"
link := Link{
Name: "test",
Href: "/test",
GroupID: group.ID,
Icon: &icon,
}
if err := CreateLink(&link); err != nil {
t.Fatal(err)
}
if link.ID == 0 {
t.Fatal("link id is zero")
}
// Update link.
link.Href = "/new-href"
if err := UpdateLink(&link); err != nil {
t.Fatal(err)
}
// Delete link.
if err := DeleteLink(link.ID); err != nil {
t.Fatal(err)
}
// Delete group.
if err := DeleteGroup(group.ID); err != nil {
t.Fatal(err)
}
}

98
database/migrations.go Normal file
View file

@ -0,0 +1,98 @@
package database
import (
"database/sql"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
)
// List of migrations that should be applied.
// Migration ID = index + 1.
var migrations = []string{
`CREATE TABLE IF NOT EXISTS admins (
id INTEGER PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
bcrypt TEXT NOT NULL
)`,
`CREATE TABLE IF NOT EXISTS groups (
id INTEGER PRIMARY KEY,
name TEXT
)`,
`CREATE TABLE IF NOT EXISTS links (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
href TEXT NOT NULL,
group_id INTEGER NOT NULL,
icon TEXT,
CONSTRAINT fk_groups_links
FOREIGN KEY (group_id)
REFERENCES groups(id)
ON DELETE CASCADE
)`,
}
func ApplyMigrations() error {
// Create a table to record applied migrations and retrieve the saved data.
_, err := DB.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER NOT NULL DEFAULT 0
)`)
if err != nil {
return err
}
var currentVersion int
err = DB.
QueryRow("SELECT version FROM migrations").
Scan(&currentVersion)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return err
}
// The table is empty, create a record.
_, err = DB.Exec("INSERT INTO migrations (version) VALUES (0)")
if err != nil {
return err
}
}
// Apply all migrations.
for i, migration := range migrations {
migrationID := i + 1
if migrationID <= currentVersion {
continue
}
if err := applyMigration(migrationID, migration); err != nil {
return fmt.Errorf("migration #%d: %w", migrationID, err)
}
log.Infof("Migration #%v has been applied", migrationID)
}
return nil
}
func applyMigration(migrationID int, query string) error {
tx, err := DB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(query); err != nil {
return fmt.Errorf("error when applying migration: %w", err)
}
if _, err := tx.Exec("UPDATE migrations SET version = ?", migrationID); err != nil {
return fmt.Errorf("error when updating schema version: %w", err)
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,16 @@
package database
import (
_ "github.com/mattn/go-sqlite3"
"testing"
)
func TestMigrations(t *testing.T) {
initTestDatabase(t)
defer deleteTestDatabase(t)
// We should be able to call the function multiple times.
if err := ApplyMigrations(); err != nil {
t.Fatal(err)
}
}

6
go.mod
View file

@ -6,16 +6,12 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/joho/godotenv v1.5.1
github.com/kelseyhightower/envconfig v1.4.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/sirupsen/logrus v1.9.3
golang.org/x/crypto v0.21.0
gorm.io/driver/sqlite v1.5.5
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
)
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/stretchr/testify v1.8.3 // indirect
golang.org/x/sys v0.18.0 // indirect
)

12
go.sum
View file

@ -3,16 +3,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@ -30,7 +26,3 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

30
main.go
View file

@ -4,43 +4,55 @@ import (
"github.com/ordinary-dev/phoenix/config"
"github.com/ordinary-dev/phoenix/database"
"github.com/ordinary-dev/phoenix/views"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
func main() {
// Configure logger
logrus.SetFormatter(&logrus.TextFormatter{
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
})
// Read config
cfg, err := config.GetConfig()
if err != nil {
logrus.Fatalf("%v", err)
log.Fatal(err)
}
// Set log level
logLevel := cfg.GetLogLevel()
logrus.SetLevel(logLevel)
logrus.Infof("Setting log level to %v", logLevel)
log.SetLevel(logLevel)
log.Infof("Setting log level to %v", logLevel)
// Connect to the database
err = database.EstablishDatabaseConnection(cfg)
if err != nil {
logrus.Fatalf("%v", err)
log.Fatal(err)
}
// Apply migrations.
if err := database.ApplyMigrations(); err != nil {
log.Fatal(err)
}
// Create the first user
if cfg.DefaultUsername != "" && cfg.DefaultPassword != "" && database.CountAdmins() < 1 {
if cfg.DefaultUsername != "" && cfg.DefaultPassword != "" {
adminCount, err := database.CountAdmins()
if err != nil {
log.Fatal(err)
}
if adminCount < 1 {
_, err := database.CreateAdmin(cfg.DefaultUsername, cfg.DefaultPassword)
if err != nil {
logrus.Errorf("%v", err)
log.Fatal(err)
}
}
}
server, err := views.GetHttpServer()
if err != nil {
logrus.Fatal(err)
log.Fatal(err)
}
server.ListenAndServe()

View file

@ -57,7 +57,13 @@ func RequireAuth(next http.Handler) http.Handler {
// Most likely the user is not authorized.
if err != nil {
if database.CountAdmins() < 1 {
count, err := database.CountAdmins()
if err != nil {
pages.ShowError(w, http.StatusInternalServerError, err)
return
}
if count < 1 {
http.Redirect(w, r, "/registration", http.StatusFound)
} else {
http.Redirect(w, r, "/signin", http.StatusFound)

View file

@ -14,8 +14,8 @@ func CreateGroup(w http.ResponseWriter, r *http.Request) {
Name: r.FormValue("groupName"),
}
if result := database.DB.Create(&group); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
if err := database.CreateGroup(&group); err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
@ -24,37 +24,30 @@ func CreateGroup(w http.ResponseWriter, r *http.Request) {
}
func UpdateGroup(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
ShowError(w, http.StatusBadRequest, err)
return
}
var group database.Group
if result := database.DB.First(&group, id); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
return
}
group.Name = r.FormValue("groupName")
if result := database.DB.Save(&group); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
if err := database.UpdateGroup(int(id), r.FormValue("groupName")); err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
// This page is called from the settings, return the user back.
http.Redirect(w, r, fmt.Sprintf("/settings#group-%v", group.ID), http.StatusFound)
http.Redirect(w, r, fmt.Sprintf("/settings#group-%v", id), http.StatusFound)
}
func DeleteGroup(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
if err != nil {
ShowError(w, http.StatusBadRequest, err)
return
}
if result := database.DB.Delete(&database.Group{}, id); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
if err := database.DeleteGroup(int(id)); err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}

View file

@ -8,19 +8,13 @@ import (
)
func ShowMainPage(w http.ResponseWriter, _ *http.Request) {
// Get a list of groups with links
var groups []database.Group
result := database.DB.
Model(&database.Group{}).
Preload("Links").
Find(&groups)
if result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
groups, err := database.GetGroupsWithLinks()
if err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
err := Render("index.html.tmpl", w, map[string]any{
err = Render("index.html.tmpl", w, map[string]any{
"description": "Self-hosted start page.",
"groups": groups,
})

View file

@ -9,7 +9,7 @@ import (
)
func CreateLink(w http.ResponseWriter, r *http.Request) {
groupID, err := strconv.ParseUint(r.FormValue("groupID"), 10, 32)
groupID, err := strconv.Atoi(r.FormValue("groupID"))
if err != nil {
ShowError(w, http.StatusBadRequest, err)
return
@ -26,8 +26,8 @@ func CreateLink(w http.ResponseWriter, r *http.Request) {
} else {
link.Icon = &icon
}
if result := database.DB.Create(&link); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
if err := database.CreateLink(&link); err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
@ -36,15 +36,15 @@ func CreateLink(w http.ResponseWriter, r *http.Request) {
}
func UpdateLink(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
id, err := strconv.Atoi(r.PathValue("id"))
if err != nil {
ShowError(w, http.StatusBadRequest, err)
return
}
var link database.Link
if result := database.DB.First(&link, id); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
link, err := database.GetLink(id)
if err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
@ -56,8 +56,9 @@ func UpdateLink(w http.ResponseWriter, r *http.Request) {
} else {
link.Icon = &icon
}
if result := database.DB.Save(&link); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
if err := database.UpdateLink(link); err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
@ -66,14 +67,14 @@ func UpdateLink(w http.ResponseWriter, r *http.Request) {
}
func DeleteLink(w http.ResponseWriter, r *http.Request) {
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
id, err := strconv.Atoi(r.PathValue("id"))
if err != nil {
ShowError(w, http.StatusBadRequest, err)
return
}
if result := database.DB.Delete(&database.Link{}, id); result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
if err := database.DeleteLink(id); err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}

View file

@ -9,7 +9,13 @@ import (
)
func ShowRegistrationForm(w http.ResponseWriter, _ *http.Request) {
if database.CountAdmins() > 0 {
userCount, err := database.CountAdmins()
if err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
if userCount > 0 {
ShowError(w, http.StatusBadRequest, errors.New("at least 1 user already exists"))
return
}
@ -23,7 +29,13 @@ func ShowRegistrationForm(w http.ResponseWriter, _ *http.Request) {
}
func CreateUser(w http.ResponseWriter, r *http.Request) {
if database.CountAdmins() > 0 {
userCount, err := database.CountAdmins()
if err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}
if userCount > 0 {
ShowError(w, http.StatusBadRequest, errors.New("at least 1 user already exists"))
return
}
@ -31,7 +43,7 @@ func CreateUser(w http.ResponseWriter, r *http.Request) {
// Try to create a user.
username := r.FormValue("username")
password := r.FormValue("password")
_, err := database.CreateAdmin(username, password)
_, err = database.CreateAdmin(username, password)
if err != nil {
ShowError(w, http.StatusInternalServerError, err)
return

View file

@ -7,15 +7,9 @@ import (
)
func ShowSettings(w http.ResponseWriter, _ *http.Request) {
// Get a list of groups with links
var groups []database.Group
result := database.DB.
Model(&database.Group{}).
Preload("Links").
Find(&groups)
if result.Error != nil {
ShowError(w, http.StatusInternalServerError, result.Error)
groups, err := database.GetGroupsWithLinks()
if err != nil {
ShowError(w, http.StatusInternalServerError, err)
return
}

View file

@ -25,7 +25,7 @@ func AuthorizeUser(w http.ResponseWriter, r *http.Request) {
// Check credentials.
username := r.FormValue("username")
password := r.FormValue("password")
_, err := database.AuthorizeAdmin(username, password)
_, err := database.GetAdminIfPasswordMatches(username, password)
if err != nil {
ShowError(w, http.StatusUnauthorized, err)
return