Compare commits

...

8 Commits

Author SHA1 Message Date
7e27745bf7 added web_test
All checks were successful
continuous-integration/drone/push Build is passing
2022-06-03 19:56:29 +02:00
3de0aaeec8 fixed auth check
thats why we test
2022-06-03 19:56:07 +02:00
339dbe617f commented client functions 2022-06-03 14:28:00 +02:00
2f5cd563ad better error handling 2022-06-03 14:27:40 +02:00
2be2cc0626 minor fixes in webserver 2022-06-01 15:02:52 +02:00
22946b9dfd create slug now private 2022-06-01 11:56:15 +02:00
816d63ea49 better doc for interfaces 2022-06-01 11:55:57 +02:00
fc19fd867c added unit testing 2022-06-01 11:55:33 +02:00
15 changed files with 762 additions and 13 deletions

4
go.mod
View File

@@ -12,6 +12,7 @@ require (
require ( require (
github.com/alexflint/go-scalar v1.1.0 // indirect github.com/alexflint/go-scalar v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect
github.com/google/uuid v1.1.1 // indirect github.com/google/uuid v1.1.1 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect
@@ -24,11 +25,14 @@ require (
github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/xid v1.2.1 // indirect github.com/rs/xid v1.2.1 // indirect
github.com/smartystreets/assertions v1.13.0 // indirect github.com/smartystreets/assertions v1.13.0 // indirect
github.com/stretchr/testify v1.7.1 // indirect
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect
golang.org/x/text v0.3.3 // indirect golang.org/x/text v0.3.3 // indirect
gopkg.in/ini.v1 v1.57.0 // indirect gopkg.in/ini.v1 v1.57.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
) )

2
go.sum
View File

@@ -52,6 +52,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=

View File

@@ -27,6 +27,7 @@ func NewClient(db db.DB, s3 s3.S3) *Client {
} }
} }
// createRandomString creates a random string of length 6 to be used as a slug
func createRandomString() string { func createRandomString() string {
s := make([]rune, 6) s := make([]rune, 6)
for i := range s { for i := range s {
@@ -35,7 +36,8 @@ func createRandomString() string {
return string(s) return string(s)
} }
func (c *Client) CreateValidSlug(ctx context.Context) (string, error) { // createValidSlug creates a valid slug that is not yet in use
func (c *Client) createValidSlug(ctx context.Context) (string, error) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
slug := createRandomString() slug := createRandomString()
@@ -52,12 +54,14 @@ func (c *Client) CreateValidSlug(ctx context.Context) (string, error) {
return "", errors.New("could not create valid slug after 10 tries") return "", errors.New("could not create valid slug after 10 tries")
} }
// GetShare returns the share with the given slug, nil if not found
func (c *Client) GetShare(ctx context.Context, slug string) (*types.Share, error) { func (c *Client) GetShare(ctx context.Context, slug string) (*types.Share, error) {
return c.db.GetShare(ctx, slug) return c.db.GetShare(ctx, slug)
} }
// CreateShare creates a new share with the given key and returns the share with the slug
func (c *Client) CreateShare(ctx context.Context, key string) (*types.Share, error) { func (c *Client) CreateShare(ctx context.Context, key string) (*types.Share, error) {
slug, err := c.CreateValidSlug(ctx) slug, err := c.createValidSlug(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -73,7 +77,7 @@ func (c *Client) CreateShare(ctx context.Context, key string) (*types.Share, err
} }
if !exists { if !exists {
return nil, errors.New("key does not exist") return nil, types.ErrKeyNotFound
} }
err = c.db.CreateShare(ctx, share) err = c.db.CreateShare(ctx, share)
@@ -84,20 +88,27 @@ func (c *Client) CreateShare(ctx context.Context, key string) (*types.Share, err
return share, nil return share, nil
} }
// GetObjectFromShare returns the s3 object to the given share
func (c *Client) GetObjectFromShare(ctx context.Context, share *types.Share) (s3.ObjectReader, error) { func (c *Client) GetObjectFromShare(ctx context.Context, share *types.Share) (s3.ObjectReader, error) {
return c.s3.GetObject(ctx, share.Key) return c.s3.GetObject(ctx, share.Key)
} }
// DeleteShare deletes the share with the given slug
func (c *Client) DeleteShare(ctx context.Context, slug string) error { func (c *Client) DeleteShare(ctx context.Context, slug string) error {
return c.db.DeleteShare(ctx, slug) return c.db.DeleteShare(ctx, slug)
} }
// GetObjectMetadata returns the metadata of the object with the given key
func (c *Client) GetObjectMetadata(ctx context.Context, key string) (*types.Metadata, error) { func (c *Client) GetObjectMetadata(ctx context.Context, key string) (*types.Metadata, error) {
metadata, err := c.s3.GetObjectMetadata(ctx, key) metadata, err := c.s3.GetObjectMetadata(ctx, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if metadata == nil {
return nil, nil
}
if metadata.Filename == "" { if metadata.Filename == "" {
metadata.Filename = filepath.Base(key) metadata.Filename = filepath.Base(key)
} }
@@ -105,6 +116,7 @@ func (c *Client) GetObjectMetadata(ctx context.Context, key string) (*types.Meta
return metadata, nil return metadata, nil
} }
// GetAllShares returns all shares
func (c *Client) GetAllShares(ctx context.Context) ([]*types.Share, error) { func (c *Client) GetAllShares(ctx context.Context) ([]*types.Share, error) {
return c.db.GetAllShares(ctx) return c.db.GetAllShares(ctx)
} }

View File

@@ -0,0 +1,143 @@
package client_test
import (
"context"
"io/ioutil"
"testing"
"git.kapelle.org/niklas/s3share/internal/client"
"git.kapelle.org/niklas/s3share/internal/db"
"git.kapelle.org/niklas/s3share/internal/s3"
"git.kapelle.org/niklas/s3share/internal/types"
"github.com/stretchr/testify/assert"
)
func setup(t *testing.T) (*client.Client, context.Context, *assert.Assertions) {
mockDb := db.NewMock()
mockS3 := s3.NewMockS3()
service := client.NewClient(mockDb, mockS3)
ctx := context.Background()
assert := assert.New(t)
return service, ctx, assert
}
func TestCreateShare(t *testing.T) {
service, ctx, assert := setup(t)
share, err := service.CreateShare(ctx, "test.txt")
assert.NoError(err)
assert.NotNil(share)
assert.NotEmpty(share.Slug)
assert.Equal("test.txt", share.Key)
}
func TestCreateShareInvalidKey(t *testing.T) {
service, ctx, assert := setup(t)
_, err := service.CreateShare(ctx, "not_existing.txt")
assert.Error(err)
}
func TestGetShare(t *testing.T) {
service, ctx, assert := setup(t)
createdShare, err := service.CreateShare(ctx, "test.txt")
share, err := service.GetShare(ctx, createdShare.Slug)
assert.NoError(err)
assert.NotNil(share)
assert.Equal(createdShare.Slug, share.Slug)
assert.Equal("test.txt", share.Key)
}
func TestGetShareNotFound(t *testing.T) {
service, ctx, assert := setup(t)
share, err := service.GetShare(ctx, "123456")
assert.NoError(err)
assert.Nil(share)
}
func TestGetObjFromShare(t *testing.T) {
service, ctx, assert := setup(t)
createdShare, _ := service.CreateShare(ctx, "test.txt")
reader, err := service.GetObjectFromShare(ctx, createdShare)
assert.NoError(err)
assert.NotNil(reader)
content, err := ioutil.ReadAll(reader)
assert.NoError(err)
assert.Equal("test.txt", string(content))
}
func TestGetObjFromShareNotFound(t *testing.T) {
service, ctx, assert := setup(t)
_, err := service.GetObjectFromShare(ctx, &types.Share{Slug: "123456"})
assert.Error(err)
}
func TestDeleteShare(t *testing.T) {
service, ctx, assert := setup(t)
createdShare, _ := service.CreateShare(ctx, "test.txt")
err := service.DeleteShare(ctx, createdShare.Slug)
assert.NoError(err)
share, err := service.GetShare(ctx, createdShare.Slug)
assert.NoError(err)
assert.Nil(share)
}
func TestDeleteShareNotFound(t *testing.T) {
service, ctx, assert := setup(t)
err := service.DeleteShare(ctx, "123456")
assert.Error(err)
}
func TestGetMetadata(t *testing.T) {
service, ctx, assert := setup(t)
metadata, err := service.GetObjectMetadata(ctx, "test.txt")
assert.NoError(err)
assert.NotNil(metadata)
assert.Equal("test.txt", metadata.Filename)
assert.Equal("text/plain", metadata.ContentType)
assert.Equal(int64(8), metadata.Size)
}
func TestGetMetadataNotFound(t *testing.T) {
service, ctx, assert := setup(t)
metadata, err := service.GetObjectMetadata(ctx, "not_existing.txt")
assert.NoError(err)
assert.Nil(metadata)
}
func TestGetAllShares(t *testing.T) {
service, ctx, assert := setup(t)
share, err := service.CreateShare(ctx, "test.txt")
assert.NoError(err)
shares, err := service.GetAllShares(ctx)
assert.NoError(err)
assert.Len(shares, 1)
assert.Equal(share.Slug, shares[0].Slug)
assert.Equal(share.Key, shares[0].Key)
_, err = service.CreateShare(ctx, "dir/test")
assert.NoError(err)
shares, err = service.GetAllShares(ctx)
assert.NoError(err)
assert.Len(shares, 2)
}

View File

@@ -7,9 +7,18 @@ import (
) )
type DB interface { type DB interface {
// Return nil if share does not exist
GetShare(ctx context.Context, slug string) (*types.Share, error) GetShare(ctx context.Context, slug string) (*types.Share, error)
// Returns error if share already exists
CreateShare(ctx context.Context, share *types.Share) error CreateShare(ctx context.Context, share *types.Share) error
// Returns error if share does not exist
DeleteShare(ctx context.Context, slug string) error DeleteShare(ctx context.Context, slug string) error
// Returns all shares
GetAllShares(ctx context.Context) ([]*types.Share, error) GetAllShares(ctx context.Context) ([]*types.Share, error)
// Close the database
Close() error Close() error
} }

54
internal/db/mock.go Normal file
View File

@@ -0,0 +1,54 @@
package db
import (
"context"
"errors"
"git.kapelle.org/niklas/s3share/internal/types"
)
type mockDB struct {
shares map[string]*types.Share
}
func NewMock() DB {
return &mockDB{
shares: make(map[string]*types.Share),
}
}
func (d *mockDB) Close() error {
return nil
}
func (d *mockDB) CreateShare(ctx context.Context, share *types.Share) error {
if d.shares[share.Slug] != nil {
return errors.New("share already exists")
}
d.shares[share.Slug] = share
return nil
}
func (d *mockDB) DeleteShare(ctx context.Context, slug string) error {
if d.shares[slug] == nil {
return types.ErrShareNotFound
}
delete(d.shares, slug)
return nil
}
func (d *mockDB) GetAllShares(ctx context.Context) ([]*types.Share, error) {
// convert map to slice
shares := make([]*types.Share, 0, len(d.shares))
for _, share := range d.shares {
shares = append(shares, share)
}
return shares, nil
}
func (d *mockDB) GetShare(ctx context.Context, slug string) (*types.Share, error) {
if d.shares[slug] == nil {
return nil, nil
}
return d.shares[slug], nil
}

111
internal/db/mock_test.go Normal file
View File

@@ -0,0 +1,111 @@
package db_test
import (
"context"
"testing"
"git.kapelle.org/niklas/s3share/internal/db"
"git.kapelle.org/niklas/s3share/internal/types"
"github.com/stretchr/testify/assert"
)
func setup(t *testing.T) (db.DB, context.Context, *assert.Assertions) {
service := db.NewMock()
ctx := context.Background()
assert := assert.New(t)
return service, ctx, assert
}
func TestCreateShare(t *testing.T) {
service, ctx, assert := setup(t)
defer service.Close()
share := &types.Share{
Slug: "123456",
Key: "test.txt",
}
err := service.CreateShare(ctx, share)
assert.NoError(err)
assert.NotNil(service.GetShare(ctx, share.Slug))
}
func TestCreateShareDup(t *testing.T) {
service, ctx, assert := setup(t)
defer service.Close()
share := &types.Share{
Slug: "123456",
Key: "test.txt",
}
err := service.CreateShare(ctx, share)
assert.NoError(err)
err = service.CreateShare(ctx, share)
assert.Error(err)
}
func TestDeleteShare(t *testing.T) {
service, ctx, assert := setup(t)
defer service.Close()
share := &types.Share{
Slug: "123456",
Key: "test.txt",
}
err := service.CreateShare(ctx, share)
assert.NoError(err)
err = service.DeleteShare(ctx, share.Slug)
assert.NoError(err)
assert.Nil(service.GetShare(ctx, share.Slug))
}
func TestDeleteShareNotFound(t *testing.T) {
service, ctx, assert := setup(t)
defer service.Close()
share := &types.Share{
Slug: "123456",
Key: "test.txt",
}
err := service.DeleteShare(ctx, share.Slug)
assert.Error(err)
}
func TestGetAllShares(t *testing.T) {
service, ctx, assert := setup(t)
defer service.Close()
share := &types.Share{
Slug: "123456",
Key: "test.txt",
}
err := service.CreateShare(ctx, share)
assert.NoError(err)
shares, err := service.GetAllShares(ctx)
assert.NoError(err)
assert.Len(shares, 1)
assert.Equal(share.Slug, shares[0].Slug)
assert.Equal(share.Key, shares[0].Key)
// Create 2nd share
share2 := &types.Share{
Slug: "abcdef",
Key: "test2",
}
err = service.CreateShare(ctx, share2)
assert.NoError(err)
shares, err = service.GetAllShares(ctx)
assert.NoError(err)
assert.Len(shares, 2)
}

View File

@@ -79,10 +79,20 @@ func (db *sqlDB) CreateShare(ctx context.Context, share *types.Share) error {
} }
func (db *sqlDB) DeleteShare(ctx context.Context, slug string) error { func (db *sqlDB) DeleteShare(ctx context.Context, slug string) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM shares WHERE slug = ?", slug) result, err := db.db.ExecContext(ctx, "DELETE FROM shares WHERE slug = ?", slug)
if err != nil { if err != nil {
return err return err
} }
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return types.ErrShareNotFound
}
return nil return nil
} }

80
internal/s3/mock.go Normal file
View File

@@ -0,0 +1,80 @@
package s3
import (
"bytes"
"context"
"crypto/md5"
"fmt"
"git.kapelle.org/niklas/s3share/internal/types"
)
type mockS3 struct {
objects map[string]mockObject
}
type mockObject struct {
content []byte
contentType string
}
type mockObjectReader struct {
*bytes.Reader
}
func (r mockObjectReader) Close() error {
// NOOP
return nil
}
func NewMockS3() S3 {
return &mockS3{
objects: map[string]mockObject{
"test.txt": {
content: []byte("test.txt"),
contentType: "text/plain",
},
"test.png": {
content: []byte("test.png"),
contentType: "image/png",
},
"dir/test": {
content: []byte("test"),
contentType: "application/octet-stream",
},
},
}
}
func (m *mockS3) GetObject(ctx context.Context, key string) (ObjectReader, error) {
mockObj, exist := m.objects[key]
if !exist {
return nil, fmt.Errorf("Object not found")
}
reader := bytes.NewReader(mockObj.content)
return mockObjectReader{reader}, nil
}
func (m *mockS3) GetObjectMetadata(ctx context.Context, key string) (*types.Metadata, error) {
mockObj, exist := m.objects[key]
if !exist {
return nil, nil
}
return &types.Metadata{
Size: int64(len(mockObj.content)),
ETag: fmt.Sprintf("%x", md5.Sum(mockObj.content)),
ContentType: mockObj.contentType,
}, nil
}
func (m *mockS3) KeyExists(ctx context.Context, key string) (bool, error) {
_, exist := m.objects[key]
return exist, nil
}

72
internal/s3/mock_test.go Normal file
View File

@@ -0,0 +1,72 @@
package s3_test
import (
"context"
"io/ioutil"
"testing"
"git.kapelle.org/niklas/s3share/internal/s3"
"github.com/stretchr/testify/assert"
)
func setup(t *testing.T) (s3.S3, context.Context, *assert.Assertions) {
service := s3.NewMockS3()
ctx := context.Background()
assert := assert.New(t)
return service, ctx, assert
}
func TestGetObject(t *testing.T) {
service, ctx, assert := setup(t)
reader, err := service.GetObject(ctx, "test.txt")
assert.NoError(err)
assert.NotNil(reader)
content, err := ioutil.ReadAll(reader)
assert.NoError(err)
assert.Equal("test.txt", string(content))
}
func TestGetObjectNotFound(t *testing.T) {
service, ctx, assert := setup(t)
reader, err := service.GetObject(ctx, "not_existing.txt")
assert.Error(err)
assert.Nil(reader)
}
func TestGetMetadata(t *testing.T) {
service, ctx, assert := setup(t)
metadata, err := service.GetObjectMetadata(ctx, "test.txt")
assert.NoError(err)
assert.NotNil(metadata)
assert.Equal("text/plain", metadata.ContentType)
assert.Equal(int64(8), metadata.Size)
}
func TestGetMetadataNotFound(t *testing.T) {
service, ctx, assert := setup(t)
metadata, err := service.GetObjectMetadata(ctx, "not_existing.txt")
assert.NoError(err)
assert.Nil(metadata)
}
func TestKeyExists(t *testing.T) {
service, ctx, assert := setup(t)
exists, err := service.KeyExists(ctx, "test.txt")
assert.NoError(err)
assert.True(exists)
}
func TestKeyNotExist(t *testing.T) {
service, ctx, assert := setup(t)
exists, err := service.KeyExists(ctx, "not_existing.txt")
assert.NoError(err)
assert.False(exists)
}

View File

@@ -15,7 +15,12 @@ type ObjectReader interface {
} }
type S3 interface { type S3 interface {
// Get the object from the S3 bucket. Returns an error if the object does not exist.
GetObject(ctx context.Context, key string) (ObjectReader, error) GetObject(ctx context.Context, key string) (ObjectReader, error)
// Check if the given key exists
KeyExists(ctx context.Context, key string) (bool, error) KeyExists(ctx context.Context, key string) (bool, error)
// Get object metadata. The `Filename` field is optional. Returns nil if the object does not exist.
GetObjectMetadata(ctx context.Context, key string) (*types.Metadata, error) GetObjectMetadata(ctx context.Context, key string) (*types.Metadata, error)
} }

View File

@@ -23,7 +23,7 @@ func Start(config *types.AppConfig) {
client := client.NewClient(db, s3Client) client := client.NewClient(db, s3Client)
err = web.StartWebserver(config.Address, *client, config.APIUsername, config.APIPassword) err = web.StartWebserver(config.Address, client, config.APIUsername, config.APIPassword)
if err != nil { if err != nil {
logrus.Fatal(err.Error()) logrus.Fatal(err.Error())
} }

6
internal/types/errors.go Normal file
View File

@@ -0,0 +1,6 @@
package types
import "errors"
var ErrKeyNotFound = errors.New("Key not found")
var ErrShareNotFound = errors.New("Share not found")

View File

@@ -17,11 +17,17 @@ type createShare struct {
Key string `json:"key"` Key string `json:"key"`
} }
func StartWebserver(addr string, client client.Client, username, password string) error { func StartWebserver(addr string, client *client.Client, username, password string) error {
if username == "" || password == "" { if username == "" || password == "" {
return errors.New("API username and password must be set") return errors.New("API username and password must be set")
} }
r := CreateRouter(client, username, password)
return http.ListenAndServe(addr, r)
}
func CreateRouter(client *client.Client, username, password string) *mux.Router {
r := mux.NewRouter() r := mux.NewRouter()
r.HandleFunc("/{slug:[a-zA-Z0-9]{6}}", func(w http.ResponseWriter, r *http.Request) { r.HandleFunc("/{slug:[a-zA-Z0-9]{6}}", func(w http.ResponseWriter, r *http.Request) {
@@ -111,15 +117,19 @@ func StartWebserver(addr string, client client.Client, username, password string
share, err := client.CreateShare(r.Context(), shareParams.Key) share, err := client.CreateShare(r.Context(), shareParams.Key)
if err != nil { if err != nil {
if err == types.ErrKeyNotFound {
http.Error(w, "The specified key does not exist", http.StatusBadRequest)
return
}
logrus.Error(err.Error()) logrus.Error(err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(share) json.NewEncoder(w).Encode(share)
}).Methods("POST") }).Methods("POST", "PUT")
r.HandleFunc("/api/share/{slug:[a-zA-Z0-9]{6}}", func(w http.ResponseWriter, r *http.Request) { r.HandleFunc("/api/share/{slug:[a-zA-Z0-9]{6}}", func(w http.ResponseWriter, r *http.Request) {
if !checkAuth(w, r, username, password) { if !checkAuth(w, r, username, password) {
@@ -130,19 +140,24 @@ func StartWebserver(addr string, client client.Client, username, password string
err := client.DeleteShare(r.Context(), vars["slug"]) err := client.DeleteShare(r.Context(), vars["slug"])
if err != nil { if err != nil {
if err == types.ErrShareNotFound {
http.NotFound(w, r)
return
}
logrus.Error(err.Error()) logrus.Error(err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
w.WriteHeader(http.StatusNoContent)
}).Methods("DELETE") }).Methods("DELETE")
r.PathPrefix("/").Handler(http.FileServer(http.Dir("./public/"))) r.PathPrefix("/").Handler(http.FileServer(http.Dir("./public/")))
logrus.Info("Starting webserver") return r
return http.ListenAndServe(addr, r)
} }
func getShareHead(client client.Client, w http.ResponseWriter, r *http.Request) *types.Share { func getShareHead(client *client.Client, w http.ResponseWriter, r *http.Request) *types.Share {
vars := mux.Vars(r) vars := mux.Vars(r)
slug := vars["path"][0:6] slug := vars["path"][0:6]
share, err := client.GetShare(r.Context(), slug) share, err := client.GetShare(r.Context(), slug)
@@ -178,13 +193,13 @@ func getShareHead(client client.Client, w http.ResponseWriter, r *http.Request)
} }
func checkAuth(w http.ResponseWriter, r *http.Request, username, password string) bool { func checkAuth(w http.ResponseWriter, r *http.Request, username, password string) bool {
username, password, ok := r.BasicAuth() authUsername, authPassword, ok := r.BasicAuth()
if !ok { if !ok {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return false return false
} }
if username != username || password != password { if username != authUsername || password != authPassword {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return false return false
} }

226
internal/web/web_test.go Normal file
View File

@@ -0,0 +1,226 @@
package web_test
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"git.kapelle.org/niklas/s3share/internal/client"
"git.kapelle.org/niklas/s3share/internal/db"
"git.kapelle.org/niklas/s3share/internal/s3"
"git.kapelle.org/niklas/s3share/internal/web"
"github.com/stretchr/testify/assert"
)
func setup(t *testing.T) (*httptest.Server, *assert.Assertions) {
client := client.NewClient(db.NewMock(), s3.NewMockS3())
router := web.CreateRouter(client, "admin", "hunter2")
ts := httptest.NewServer(router)
assert := assert.New(t)
return ts, assert
}
func genCreateShareRequest(ts *httptest.Server, key string) *http.Request {
body := "{\"key\": \"" + key + "\"}"
req, _ := http.NewRequest("POST", ts.URL+"/api/share", bytes.NewReader([]byte(body)))
req.Header.Add("Authorization", "Basic YWRtaW46aHVudGVyMg==")
req.Header.Add("Content-Type", "application/json")
return req
}
func TestCreateShare(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req := genCreateShareRequest(ts, "test.txt")
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusCreated, res.StatusCode)
// check json response
var jsonResponse map[string]interface{}
err = json.NewDecoder(res.Body).Decode(&jsonResponse)
assert.NoError(err)
assert.Equal("test.txt", jsonResponse["key"])
assert.NotNil(jsonResponse["slug"])
assert.NotEmpty(jsonResponse["slug"])
}
func TestCreateShareInvalidKey(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req := genCreateShareRequest(ts, "not_existing.txt")
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusBadRequest, res.StatusCode)
}
func TestGetShare(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req := genCreateShareRequest(ts, "test.txt")
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusCreated, res.StatusCode)
var jsonResponse map[string]string
err = json.NewDecoder(res.Body).Decode(&jsonResponse)
assert.NoError(err)
req, err = http.NewRequest("GET", ts.URL+"/s/"+jsonResponse["slug"], nil)
assert.NoError(err)
res, err = http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusOK, res.StatusCode)
// check response
assert.Equal("inline; filename=\"test.txt\"", res.Header.Get("Content-Disposition"))
assert.Equal("text/plain", res.Header.Get("Content-Type"))
body, err := ioutil.ReadAll(res.Body)
assert.NoError(err)
assert.Equal("test.txt", string(body))
}
func TestGetShareInvalidSlug(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req, err := http.NewRequest("GET", ts.URL+"/s/123456", nil)
assert.NoError(err)
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusNotFound, res.StatusCode)
}
func TestGetShareFileExt(t *testing.T) {
// Basically the same as TestGetShare, but with a file extension in the slug
ts, assert := setup(t)
defer ts.Close()
req := genCreateShareRequest(ts, "test.txt")
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusCreated, res.StatusCode)
var jsonResponse map[string]string
err = json.NewDecoder(res.Body).Decode(&jsonResponse)
assert.NoError(err)
req, err = http.NewRequest("GET", ts.URL+"/s/"+jsonResponse["slug"]+".txt", nil)
assert.NoError(err)
res, err = http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusOK, res.StatusCode)
// check response
assert.Equal("inline; filename=\"test.txt\"", res.Header.Get("Content-Disposition"))
assert.Equal("text/plain", res.Header.Get("Content-Type"))
body, err := ioutil.ReadAll(res.Body)
assert.NoError(err)
assert.Equal("test.txt", string(body))
}
func TestDeleteShare(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req := genCreateShareRequest(ts, "test.txt")
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusCreated, res.StatusCode)
var jsonResponse map[string]string
err = json.NewDecoder(res.Body).Decode(&jsonResponse)
assert.NoError(err)
req, err = http.NewRequest("DELETE", ts.URL+"/api/share/"+jsonResponse["slug"], nil)
req.Header.Add("Authorization", "Basic YWRtaW46aHVudGVyMg==")
assert.NoError(err)
res, err = http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusNoContent, res.StatusCode)
req, err = http.NewRequest("GET", ts.URL+"/s/"+jsonResponse["slug"], nil)
assert.NoError(err)
res, err = http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusNotFound, res.StatusCode)
}
func TestDeleteShareInvalidSlug(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req, err := http.NewRequest("DELETE", ts.URL+"/api/share/123456", nil)
req.Header.Add("Authorization", "Basic YWRtaW46aHVudGVyMg==")
assert.NoError(err)
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusNotFound, res.StatusCode)
}
func TestGetAll(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req := genCreateShareRequest(ts, "test.txt")
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusCreated, res.StatusCode)
var jsonResponse map[string]string
err = json.NewDecoder(res.Body).Decode(&jsonResponse)
assert.NoError(err)
req, err = http.NewRequest("GET", ts.URL+"/api/share", nil)
req.Header.Add("Authorization", "Basic YWRtaW46aHVudGVyMg==")
assert.NoError(err)
res, err = http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusOK, res.StatusCode)
// check response
var jsonResponse2 []map[string]string
err = json.NewDecoder(res.Body).Decode(&jsonResponse2)
assert.NoError(err)
assert.Equal(1, len(jsonResponse2))
assert.Equal(jsonResponse["slug"], jsonResponse2[0]["slug"])
}
func TestInvalidAuth(t *testing.T) {
ts, assert := setup(t)
defer ts.Close()
req, err := http.NewRequest("GET", ts.URL+"/api/share", nil)
req.Header.Add("Authorization", "Basic YWRtaW46aHVudGVyMw==")
assert.NoError(err)
res, err := http.DefaultClient.Do(req)
assert.NoError(err)
assert.Equal(http.StatusUnauthorized, res.StatusCode)
}