mirror of
https://github.com/ordinary-dev/phoenix
synced 2024-09-19 19:30:28 +05:00
chore: migrate to database/sql
Since I started simplifying, I decided to abandon ORM. I won’t say that this makes much sense, everything works more or less as before. Except that the size of the program has decreased slightly again, by about a megabyte.
This commit is contained in:
parent
66701b2687
commit
5aa2cee5b1
|
@ -5,49 +5,91 @@ import (
|
|||
)
|
||||
|
||||
type Admin struct {
|
||||
ID uint64 `gorm:"primaryKey"`
|
||||
Username string `gorm:"unique;notNull"`
|
||||
Bcrypt string `gorm:"notNull"`
|
||||
ID int
|
||||
Username string
|
||||
Bcrypt string
|
||||
}
|
||||
|
||||
func CountAdmins() int64 {
|
||||
var admins []Admin
|
||||
func CountAdmins() (int64, error) {
|
||||
var count int64
|
||||
DB.Model(&admins).Count(&count)
|
||||
return count
|
||||
query := `SELECT COUNT(*) FROM admins`
|
||||
if err := DB.QueryRow(query).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func CreateAdmin(username string, password string) (Admin, error) {
|
||||
func CreateAdmin(username string, password string) (*Admin, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), 10)
|
||||
if err != nil {
|
||||
return Admin{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
admin := Admin{
|
||||
Username: username,
|
||||
Bcrypt: string(hash),
|
||||
}
|
||||
result := DB.Create(&admin)
|
||||
query := `
|
||||
INSERT INTO admins(username, bcrypt)
|
||||
VALUES (?, ?)
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
if result.Error != nil {
|
||||
return Admin{}, result.Error
|
||||
}
|
||||
|
||||
return admin, nil
|
||||
}
|
||||
|
||||
func AuthorizeAdmin(username string, password string) (Admin, error) {
|
||||
var admin Admin
|
||||
result := DB.Where("username = ?", username).First(&admin)
|
||||
admin.Username = username
|
||||
admin.Bcrypt = string(hash)
|
||||
|
||||
if result.Error != nil {
|
||||
return Admin{}, result.Error
|
||||
}
|
||||
err = DB.
|
||||
QueryRow(query, admin.Username, admin.Bcrypt).
|
||||
Scan(&admin.ID)
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(admin.Bcrypt), []byte(password))
|
||||
if err != nil {
|
||||
return Admin{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return admin, nil
|
||||
return &admin, nil
|
||||
}
|
||||
|
||||
func GetAdminIfPasswordMatches(username string, password string) (*Admin, error) {
|
||||
query := `
|
||||
SELECT id, username, bcrypt
|
||||
FROM admins
|
||||
WHERE username = ?
|
||||
`
|
||||
|
||||
var admin Admin
|
||||
err := DB.
|
||||
QueryRow(query, username).
|
||||
Scan(&admin.ID, &admin.Username, &admin.Bcrypt)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(admin.Bcrypt), []byte(password))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &admin, nil
|
||||
}
|
||||
|
||||
func DeleteAdmin(id int) error {
|
||||
query := `
|
||||
DELETE FROM admins
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
res, err := DB.Exec(query, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected != 1 {
|
||||
return ErrWrongNumberOfAffectedRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
58
database/admins_test.go
Normal file
58
database/admins_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdmins(t *testing.T) {
|
||||
initTestDatabase(t)
|
||||
defer deleteTestDatabase(t)
|
||||
|
||||
// We should have no admins.
|
||||
count, err := CountAdmins()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if count != 0 {
|
||||
t.Fatal("user count is not zero")
|
||||
}
|
||||
|
||||
// Create the first user.
|
||||
username := "test"
|
||||
password := "test"
|
||||
admin, err := CreateAdmin(username, password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check password and get admin.
|
||||
dbAdmin, err := GetAdminIfPasswordMatches(username, password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if dbAdmin.ID != admin.ID {
|
||||
t.Fatal("wrong admin id")
|
||||
}
|
||||
|
||||
// Check wrong password handling.
|
||||
if _, err := GetAdminIfPasswordMatches("test", "wrong-password"); err == nil {
|
||||
t.Fatal("wrong password was accepted")
|
||||
}
|
||||
|
||||
// Count users again.
|
||||
count, err = CountAdmins()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Fatal("user count is not one")
|
||||
}
|
||||
|
||||
// Delete user.
|
||||
if err := DeleteAdmin(admin.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
20
database/connection.go
Normal file
20
database/connection.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/ordinary-dev/phoenix/config"
|
||||
)
|
||||
|
||||
var DB *sql.DB
|
||||
|
||||
func EstablishDatabaseConnection(cfg *config.Config) error {
|
||||
var err error
|
||||
DB, err = sql.Open("sqlite3", cfg.DBPath)
|
||||
|
||||
if err := DB.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
28
database/connection_test.go
Normal file
28
database/connection_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const TEST_DB_PATH = "/tmp/phoenix.sqlite3"
|
||||
|
||||
func initTestDatabase(t *testing.T) {
|
||||
var err error
|
||||
DB, err = sql.Open("sqlite3", TEST_DB_PATH)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := ApplyMigrations(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteTestDatabase(t *testing.T) {
|
||||
if err := os.Remove(TEST_DB_PATH); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"github.com/ordinary-dev/phoenix/config"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func EstablishDatabaseConnection(cfg *config.Config) error {
|
||||
var err error
|
||||
DB, err = gorm.Open(sqlite.Open(cfg.DBPath), &gorm.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Migrate the schema
|
||||
DB.AutoMigrate(&Admin{}, &Group{}, &Link{})
|
||||
|
||||
return nil
|
||||
}
|
9
database/errors.go
Normal file
9
database/errors.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWrongNumberOfAffectedRows = errors.New("wrong number of affected rows")
|
||||
)
|
|
@ -1,7 +1,106 @@
|
|||
package database
|
||||
|
||||
type Group struct {
|
||||
ID uint64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"unique,notNull"`
|
||||
Links []Link `gorm:"constraint:OnDelete:CASCADE;"`
|
||||
ID int
|
||||
Name string
|
||||
Links []Link
|
||||
}
|
||||
|
||||
func GetGroupsWithLinks() ([]Group, error) {
|
||||
query := `
|
||||
SELECT id, name
|
||||
FROM groups
|
||||
ORDER BY groups.id
|
||||
`
|
||||
|
||||
rows, err := DB.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var groups []Group
|
||||
for rows.Next() {
|
||||
var group Group
|
||||
if err := rows.Scan(&group.ID, &group.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groups = append(groups, group)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range groups {
|
||||
groups[i].Links, err = GetLinksFromGroup(groups[i].ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// Create a new group in the database.
|
||||
// The function fills in the ID.
|
||||
func CreateGroup(group *Group) error {
|
||||
query := `
|
||||
INSERT INTO groups (name)
|
||||
VALUES (?)
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
if err := DB.QueryRow(query, group.Name).Scan(&group.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateGroup(id int, name string) error {
|
||||
query := `
|
||||
UPDATE groups
|
||||
SET name = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
res, err := DB.Exec(query, name, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected != 1 {
|
||||
return ErrWrongNumberOfAffectedRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteGroup(groupID int) error {
|
||||
query := `
|
||||
DELETE FROM groups
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
res, err := DB.Exec(query, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected != 1 {
|
||||
return ErrWrongNumberOfAffectedRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
46
database/groups_test.go
Normal file
46
database/groups_test.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGroups(t *testing.T) {
|
||||
initTestDatabase(t)
|
||||
defer deleteTestDatabase(t)
|
||||
|
||||
// Create the first group.
|
||||
group := Group{
|
||||
Name: "test",
|
||||
}
|
||||
if err := CreateGroup(&group); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if group.ID == 0 {
|
||||
t.Fatal("group id is zero")
|
||||
}
|
||||
|
||||
// Update group.
|
||||
if err := UpdateGroup(group.ID, "new-name"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read groups.
|
||||
groupList, err := GetGroupsWithLinks()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(groupList) != 1 {
|
||||
t.Fatal("group list length is not one")
|
||||
}
|
||||
|
||||
if groupList[0].Name != "new-name" {
|
||||
t.Fatal("wrong group name")
|
||||
}
|
||||
|
||||
// Delete group.
|
||||
if err := DeleteGroup(group.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -1,9 +1,124 @@
|
|||
package database
|
||||
|
||||
type Link struct {
|
||||
ID uint64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"notNull"`
|
||||
Href string `gorm:"notNull"`
|
||||
GroupID uint64 `gorm:"notNull"`
|
||||
ID int
|
||||
Name string
|
||||
Href string
|
||||
GroupID int
|
||||
Icon *string
|
||||
}
|
||||
|
||||
func GetLinksFromGroup(groupID int) ([]Link, error) {
|
||||
query := `
|
||||
SELECT id, name, href, group_id, icon
|
||||
FROM links
|
||||
WHERE group_id = ?
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := DB.Query(query, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var links []Link
|
||||
for rows.Next() {
|
||||
var link Link
|
||||
if err := rows.Scan(&link.ID, &link.Name, &link.Href, &link.GroupID, &link.Icon); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
links = append(links, link)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return links, nil
|
||||
}
|
||||
|
||||
func GetLink(id int) (*Link, error) {
|
||||
query := `
|
||||
SELECT id, name, href, group_id, icon
|
||||
FROM links
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
var link Link
|
||||
err := DB.
|
||||
QueryRow(query, id).
|
||||
Scan(&link.ID, &link.Name, &link.Href, &link.GroupID, &link.Icon)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &link, nil
|
||||
}
|
||||
|
||||
// Create a new link in the database.
|
||||
// The function fills in the ID.
|
||||
func CreateLink(link *Link) error {
|
||||
query := `
|
||||
INSERT INTO links (name, href, group_id, icon)
|
||||
VALUES (?, ?, ?, ?)
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
err := DB.
|
||||
QueryRow(query, link.Name, link.Href, link.GroupID, link.Icon).
|
||||
Scan(&link.ID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateLink(link *Link) error {
|
||||
query := `
|
||||
UPDATE links
|
||||
SET name = ?, href = ?, group_id = ?, icon = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
res, err := DB.Exec(query, link.Name, link.Href, link.GroupID, link.Icon, link.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected != 1 {
|
||||
return ErrWrongNumberOfAffectedRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteLink(linkID int) error {
|
||||
query := `
|
||||
DELETE FROM links
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
res, err := DB.Exec(query, linkID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rowsAffected != 1 {
|
||||
return ErrWrongNumberOfAffectedRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
50
database/links_test.go
Normal file
50
database/links_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLinks(t *testing.T) {
|
||||
initTestDatabase(t)
|
||||
defer deleteTestDatabase(t)
|
||||
|
||||
// Create the first group.
|
||||
group := Group{
|
||||
Name: "test",
|
||||
}
|
||||
if err := CreateGroup(&group); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the first link.
|
||||
icon := "test/icon"
|
||||
link := Link{
|
||||
Name: "test",
|
||||
Href: "/test",
|
||||
GroupID: group.ID,
|
||||
Icon: &icon,
|
||||
}
|
||||
if err := CreateLink(&link); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if link.ID == 0 {
|
||||
t.Fatal("link id is zero")
|
||||
}
|
||||
|
||||
// Update link.
|
||||
link.Href = "/new-href"
|
||||
if err := UpdateLink(&link); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete link.
|
||||
if err := DeleteLink(link.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete group.
|
||||
if err := DeleteGroup(group.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
98
database/migrations.go
Normal file
98
database/migrations.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// List of migrations that should be applied.
|
||||
// Migration ID = index + 1.
|
||||
var migrations = []string{
|
||||
`CREATE TABLE IF NOT EXISTS admins (
|
||||
id INTEGER PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
bcrypt TEXT NOT NULL
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS groups (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS links (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
href TEXT NOT NULL,
|
||||
group_id INTEGER NOT NULL,
|
||||
icon TEXT,
|
||||
CONSTRAINT fk_groups_links
|
||||
FOREIGN KEY (group_id)
|
||||
REFERENCES groups(id)
|
||||
ON DELETE CASCADE
|
||||
)`,
|
||||
}
|
||||
|
||||
func ApplyMigrations() error {
|
||||
// Create a table to record applied migrations and retrieve the saved data.
|
||||
_, err := DB.Exec(`CREATE TABLE IF NOT EXISTS migrations (
|
||||
version INTEGER NOT NULL DEFAULT 0
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var currentVersion int
|
||||
err = DB.
|
||||
QueryRow("SELECT version FROM migrations").
|
||||
Scan(¤tVersion)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
// The table is empty, create a record.
|
||||
_, err = DB.Exec("INSERT INTO migrations (version) VALUES (0)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Apply all migrations.
|
||||
for i, migration := range migrations {
|
||||
migrationID := i + 1
|
||||
if migrationID <= currentVersion {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := applyMigration(migrationID, migration); err != nil {
|
||||
return fmt.Errorf("migration #%d: %w", migrationID, err)
|
||||
}
|
||||
|
||||
log.Infof("Migration #%v has been applied", migrationID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyMigration(migrationID int, query string) error {
|
||||
tx, err := DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
return fmt.Errorf("error when applying migration: %w", err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec("UPDATE migrations SET version = ?", migrationID); err != nil {
|
||||
return fmt.Errorf("error when updating schema version: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
16
database/migrations_test.go
Normal file
16
database/migrations_test.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigrations(t *testing.T) {
|
||||
initTestDatabase(t)
|
||||
defer deleteTestDatabase(t)
|
||||
|
||||
// We should be able to call the function multiple times.
|
||||
if err := ApplyMigrations(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
6
go.mod
6
go.mod
|
@ -6,16 +6,12 @@ require (
|
|||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/kelseyhightower/envconfig v1.4.0
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
golang.org/x/crypto v0.21.0
|
||||
gorm.io/driver/sqlite v1.5.5
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/stretchr/testify v1.8.3 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
)
|
||||
|
|
12
go.sum
12
go.sum
|
@ -3,16 +3,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
|
||||
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
|
@ -30,7 +26,3 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
|
|||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
|
||||
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg=
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
|
|
30
main.go
30
main.go
|
@ -4,43 +4,55 @@ import (
|
|||
"github.com/ordinary-dev/phoenix/config"
|
||||
"github.com/ordinary-dev/phoenix/database"
|
||||
"github.com/ordinary-dev/phoenix/views"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Configure logger
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
})
|
||||
|
||||
// Read config
|
||||
cfg, err := config.GetConfig()
|
||||
if err != nil {
|
||||
logrus.Fatalf("%v", err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Set log level
|
||||
logLevel := cfg.GetLogLevel()
|
||||
logrus.SetLevel(logLevel)
|
||||
logrus.Infof("Setting log level to %v", logLevel)
|
||||
log.SetLevel(logLevel)
|
||||
log.Infof("Setting log level to %v", logLevel)
|
||||
|
||||
// Connect to the database
|
||||
err = database.EstablishDatabaseConnection(cfg)
|
||||
if err != nil {
|
||||
logrus.Fatalf("%v", err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Apply migrations.
|
||||
if err := database.ApplyMigrations(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the first user
|
||||
if cfg.DefaultUsername != "" && cfg.DefaultPassword != "" && database.CountAdmins() < 1 {
|
||||
if cfg.DefaultUsername != "" && cfg.DefaultPassword != "" {
|
||||
adminCount, err := database.CountAdmins()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if adminCount < 1 {
|
||||
_, err := database.CreateAdmin(cfg.DefaultUsername, cfg.DefaultPassword)
|
||||
if err != nil {
|
||||
logrus.Errorf("%v", err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
server, err := views.GetHttpServer()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
server.ListenAndServe()
|
||||
|
|
|
@ -57,7 +57,13 @@ func RequireAuth(next http.Handler) http.Handler {
|
|||
|
||||
// Most likely the user is not authorized.
|
||||
if err != nil {
|
||||
if database.CountAdmins() < 1 {
|
||||
count, err := database.CountAdmins()
|
||||
if err != nil {
|
||||
pages.ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
if count < 1 {
|
||||
http.Redirect(w, r, "/registration", http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, "/signin", http.StatusFound)
|
||||
|
|
|
@ -14,8 +14,8 @@ func CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||
Name: r.FormValue("groupName"),
|
||||
}
|
||||
|
||||
if result := database.DB.Create(&group); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
if err := database.CreateGroup(&group); err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -24,37 +24,30 @@ func CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
|
||||
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
var group database.Group
|
||||
if result := database.DB.First(&group, id); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
return
|
||||
}
|
||||
|
||||
group.Name = r.FormValue("groupName")
|
||||
if result := database.DB.Save(&group); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
if err := database.UpdateGroup(int(id), r.FormValue("groupName")); err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
// This page is called from the settings, return the user back.
|
||||
http.Redirect(w, r, fmt.Sprintf("/settings#group-%v", group.ID), http.StatusFound)
|
||||
http.Redirect(w, r, fmt.Sprintf("/settings#group-%v", id), http.StatusFound)
|
||||
}
|
||||
|
||||
func DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
|
||||
id, err := strconv.ParseInt(r.PathValue("id"), 10, 64)
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
if result := database.DB.Delete(&database.Group{}, id); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
if err := database.DeleteGroup(int(id)); err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -8,19 +8,13 @@ import (
|
|||
)
|
||||
|
||||
func ShowMainPage(w http.ResponseWriter, _ *http.Request) {
|
||||
// Get a list of groups with links
|
||||
var groups []database.Group
|
||||
result := database.DB.
|
||||
Model(&database.Group{}).
|
||||
Preload("Links").
|
||||
Find(&groups)
|
||||
|
||||
if result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
groups, err := database.GetGroupsWithLinks()
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
err := Render("index.html.tmpl", w, map[string]any{
|
||||
err = Render("index.html.tmpl", w, map[string]any{
|
||||
"description": "Self-hosted start page.",
|
||||
"groups": groups,
|
||||
})
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func CreateLink(w http.ResponseWriter, r *http.Request) {
|
||||
groupID, err := strconv.ParseUint(r.FormValue("groupID"), 10, 32)
|
||||
groupID, err := strconv.Atoi(r.FormValue("groupID"))
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
|
@ -26,8 +26,8 @@ func CreateLink(w http.ResponseWriter, r *http.Request) {
|
|||
} else {
|
||||
link.Icon = &icon
|
||||
}
|
||||
if result := database.DB.Create(&link); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
if err := database.CreateLink(&link); err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -36,15 +36,15 @@ func CreateLink(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func UpdateLink(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
var link database.Link
|
||||
if result := database.DB.First(&link, id); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
link, err := database.GetLink(id)
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -56,8 +56,9 @@ func UpdateLink(w http.ResponseWriter, r *http.Request) {
|
|||
} else {
|
||||
link.Icon = &icon
|
||||
}
|
||||
if result := database.DB.Save(&link); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
|
||||
if err := database.UpdateLink(link); err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -66,14 +67,14 @@ func UpdateLink(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func DeleteLink(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
if result := database.DB.Delete(&database.Link{}, id); result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
if err := database.DeleteLink(id); err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,13 @@ import (
|
|||
)
|
||||
|
||||
func ShowRegistrationForm(w http.ResponseWriter, _ *http.Request) {
|
||||
if database.CountAdmins() > 0 {
|
||||
userCount, err := database.CountAdmins()
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
if userCount > 0 {
|
||||
ShowError(w, http.StatusBadRequest, errors.New("at least 1 user already exists"))
|
||||
return
|
||||
}
|
||||
|
@ -23,7 +29,13 @@ func ShowRegistrationForm(w http.ResponseWriter, _ *http.Request) {
|
|||
}
|
||||
|
||||
func CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if database.CountAdmins() > 0 {
|
||||
userCount, err := database.CountAdmins()
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
if userCount > 0 {
|
||||
ShowError(w, http.StatusBadRequest, errors.New("at least 1 user already exists"))
|
||||
return
|
||||
}
|
||||
|
@ -31,7 +43,7 @@ func CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||
// Try to create a user.
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
_, err := database.CreateAdmin(username, password)
|
||||
_, err = database.CreateAdmin(username, password)
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
|
|
|
@ -7,15 +7,9 @@ import (
|
|||
)
|
||||
|
||||
func ShowSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
// Get a list of groups with links
|
||||
var groups []database.Group
|
||||
result := database.DB.
|
||||
Model(&database.Group{}).
|
||||
Preload("Links").
|
||||
Find(&groups)
|
||||
|
||||
if result.Error != nil {
|
||||
ShowError(w, http.StatusInternalServerError, result.Error)
|
||||
groups, err := database.GetGroupsWithLinks()
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func AuthorizeUser(w http.ResponseWriter, r *http.Request) {
|
|||
// Check credentials.
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
_, err := database.AuthorizeAdmin(username, password)
|
||||
_, err := database.GetAdminIfPasswordMatches(username, password)
|
||||
if err != nil {
|
||||
ShowError(w, http.StatusUnauthorized, err)
|
||||
return
|
||||
|
|
Loading…
Reference in a new issue