diff --git a/internal/db/db.go b/internal/db/db.go index d1b7810..93da0c2 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,94 +1,9 @@ package db -import ( - "context" - "database/sql" - _ "embed" - "time" +import "context" - _ "github.com/go-sql-driver/mysql" - "golang.org/x/crypto/bcrypt" -) - -//go:embed setup.sql -var setupSql string - -const DB_NAME = "s3Browser" - -type DB struct { - dbConn *sql.DB -} - -func NewDB(dataSourceName string) (*DB, error) { - db, err := sql.Open("mysql", dataSourceName) - - if err != nil { - return nil, err - } - - db.SetConnMaxLifetime(time.Minute * 3) - db.SetMaxOpenConns(10) - db.SetMaxIdleConns(10) - - return &DB{ - dbConn: db, - }, nil -} - -func (d *DB) Setup() error { - tx, err := d.dbConn.Begin() - if err != nil { - return err - } - - _, err = tx.Exec(setupSql) - if err != nil { - tx.Rollback() - return err - } - - err = tx.Commit() - if err != nil { - return err - } - - return nil -} - -func (d *DB) CheckLogin(ctx context.Context, username, password string) (bool, error) { - rows, err := d.dbConn.QueryContext(ctx, "SELECT password FROM user WHERE username = ?", username) - if err != nil { - return false, err - } - - if !rows.Next() { - return false, nil - } - - var passwordHash []byte - err = rows.Scan(&passwordHash) - if err != nil { - return false, err - } - - if bcrypt.CompareHashAndPassword(passwordHash, []byte(password)) != nil { - return false, nil - } - - return true, nil -} - -func (d *DB) AddUser(ctx context.Context, username, password string) error { - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return err - } - - _, err = d.dbConn.ExecContext(ctx, "INSERT INTO user (username,password) VALUES (?,?)", username, hash) - - if err != nil { - return err - } - - return nil +type DB interface { + Setup() error + CheckLogin(ctx context.Context, username, password string) (bool, error) + AddUser(ctx context.Context, username, password string) error } diff --git a/internal/db/mysql.go b/internal/db/mysql.go new file mode 100644 index 0000000..2df33e2 --- /dev/null +++ b/internal/db/mysql.go @@ -0,0 +1,94 @@ +package db + +import ( + "context" + "database/sql" + _ "embed" + "time" + + _ "github.com/go-sql-driver/mysql" + "golang.org/x/crypto/bcrypt" +) + +//go:embed setup.sql +var setupSql string + +const DB_NAME = "s3Browser" + +type mysqlDB struct { + dbConn *sql.DB +} + +func NewDB(dataSourceName string) (DB, error) { + db, err := sql.Open("mysql", dataSourceName) + + if err != nil { + return nil, err + } + + db.SetConnMaxLifetime(time.Minute * 3) + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(10) + + return &mysqlDB{ + dbConn: db, + }, nil +} + +func (d *mysqlDB) Setup() error { + tx, err := d.dbConn.Begin() + if err != nil { + return err + } + + _, err = tx.Exec(setupSql) + if err != nil { + tx.Rollback() + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (d *mysqlDB) CheckLogin(ctx context.Context, username, password string) (bool, error) { + rows, err := d.dbConn.QueryContext(ctx, "SELECT password FROM user WHERE username = ?", username) + if err != nil { + return false, err + } + + if !rows.Next() { + return false, nil + } + + var passwordHash []byte + err = rows.Scan(&passwordHash) + if err != nil { + return false, err + } + + if bcrypt.CompareHashAndPassword(passwordHash, []byte(password)) != nil { + return false, nil + } + + return true, nil +} + +func (d *mysqlDB) AddUser(ctx context.Context, username, password string) error { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + + _, err = d.dbConn.ExecContext(ctx, "INSERT INTO user (username,password) VALUES (?,?)", username, hash) + + if err != nil { + return err + } + + return nil +}