Skip to content

Commit

Permalink
allow user to disable registration using -closed
Browse files Browse the repository at this point in the history
  • Loading branch information
dimkr committed Dec 17, 2023
1 parent 2dc60ef commit 9537395
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 12 deletions.
3 changes: 2 additions & 1 deletion cmd/tootik/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ var (
key = flag.String("key", "key.pem", "HTTPS TLS key")
addr = flag.String("addr", ":8443", "HTTPS listening address")
blockListPath = flag.String("blocklist", "", "Blocklist CSV")
closed = flag.Bool("closed", false, "Disable new user registration")
version = flag.Bool("version", false, "Print version and exit")
)

Expand Down Expand Up @@ -139,7 +140,7 @@ func main() {
wg.Done()
}()

handler := front.NewHandler()
handler := front.NewHandler(*closed)

wg.Add(1)
go func() {
Expand Down
10 changes: 8 additions & 2 deletions front/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,20 @@ func serveStaticFile(w text.Writer, r *request) {
}
}

func NewHandler() Handler {
func NewHandler(closed bool) Handler {
h := Handler{}
var cache sync.Map

h[regexp.MustCompile(`^/$`)] = withUserMenu(home)

h[regexp.MustCompile(`^/users$`)] = withUserMenu(users)
h[regexp.MustCompile(`^/users/register$`)] = register
if closed {
h[regexp.MustCompile(`^/users/register$`)] = func(w text.Writer, r *request) {
w.Status(40, "Registration is closed")
}
} else {
h[regexp.MustCompile(`^/users/register$`)] = register
}

h[regexp.MustCompile(`^/users/inbox/[0-9]{4}-[0-9]{2}-[0-9]{2}$`)] = withUserMenu(byDate)
h[regexp.MustCompile(`^/users/inbox/today$`)] = withUserMenu(today)
Expand Down
87 changes: 79 additions & 8 deletions test/register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func TestRegister_Redirect(t *testing.T) {
_, err = tlsReader.Write([]byte("https://localhost.localdomain/users\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down Expand Up @@ -233,7 +233,7 @@ func TestRegister_HappyFlow(t *testing.T) {
_, err = tlsReader.Write([]byte("https://localhost.localdomain/users/register\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand All @@ -243,6 +243,77 @@ func TestRegister_HappyFlow(t *testing.T) {
assert.Equal("30 /users\r\n", string(resp))
}

func TestRegister_HappyFlowRegistrationClosed(t *testing.T) {
assert := assert.New(t)

dbPath := fmt.Sprintf("/tmp/%s.sqlite3?_journal_mode=WAL", t.Name())
db, err := sql.Open("sqlite3", dbPath)
assert.NoError(err)
defer os.Remove(dbPath)

assert.NoError(migrations.Run(context.Background(), slog.Default(), db))

serverKeyPair, err := tls.X509KeyPair([]byte(serverCert), []byte(serverKey))
assert.NoError(err)

serverCfg := tls.Config{
Certificates: []tls.Certificate{serverKeyPair},
MinVersion: tls.VersionTLS12,
ClientAuth: tls.RequestClientCert,
}

erinKeyPair, err := tls.X509KeyPair([]byte(erinCert), []byte(erinKey))
assert.NoError(err)

clientCfg := tls.Config{
Certificates: []tls.Certificate{erinKeyPair},
InsecureSkipVerify: true,
}

socketPath := fmt.Sprintf("/tmp/%s.socket", t.Name())

localListener, err := net.Listen("unix", socketPath)
assert.NoError(err)
defer os.Remove(socketPath)

tlsListener := tls.NewListener(localListener, &serverCfg)
defer tlsListener.Close()

unixReader, err := net.Dial("unix", socketPath)
assert.NoError(err)
defer unixReader.Close()

tlsWriter, err := tlsListener.Accept()
assert.NoError(err)

tlsReader := tls.Client(unixReader, &clientCfg)
defer tlsReader.Close()

var wg sync.WaitGroup
wg.Add(2)
go func() {
assert.NoError(tlsReader.Handshake())
wg.Done()
}()
go func() {
assert.NoError(tlsWriter.(*tls.Conn).Handshake())
wg.Done()
}()
wg.Wait()

_, err = tlsReader.Write([]byte("https://localhost.localdomain/users/register\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(true), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

resp, err := io.ReadAll(tlsReader)
assert.NoError(err)

assert.Equal("40 Registration is closed\r\n", string(resp))
}

func TestRegister_AlreadyRegistered(t *testing.T) {
assert := assert.New(t)

Expand Down Expand Up @@ -307,7 +378,7 @@ func TestRegister_AlreadyRegistered(t *testing.T) {
_, err = user.Create(context.Background(), db, "https://localhost.localdomain/user/erin", "erin", "e")
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down Expand Up @@ -383,7 +454,7 @@ func TestRegister_Twice(t *testing.T) {
_, err = tlsReader.Write([]byte("https://localhost.localdomain/users/register\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down Expand Up @@ -470,7 +541,7 @@ func TestRegister_Throttling(t *testing.T) {
_, err = tlsReader.Write([]byte("https://localhost.localdomain/users/register\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down Expand Up @@ -557,7 +628,7 @@ func TestRegister_Throttling30Minutes(t *testing.T) {
_, err = tlsReader.Write([]byte("https://localhost.localdomain/users/register\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down Expand Up @@ -647,7 +718,7 @@ func TestRegister_Throttling1Hour(t *testing.T) {
_, err = tlsReader.Write([]byte("https://localhost.localdomain/users/register\r\n"))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down Expand Up @@ -731,7 +802,7 @@ func TestRegister_RedirectTwice(t *testing.T) {
_, err = tlsReader.Write([]byte(data.url))
assert.NoError(err)

gemini.Handle(context.Background(), front.NewHandler(), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())
gemini.Handle(context.Background(), front.NewHandler(false), tlsWriter, db, fed.NewResolver(nil), &wg, slog.Default())

tlsWriter.Close()

Expand Down
2 changes: 1 addition & 1 deletion test/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func newTestServer() *server {
return &server{
dbPath: path,
db: db,
handler: front.NewHandler(),
handler: front.NewHandler(false),
Alice: alice,
Bob: bob,
Carol: carol,
Expand Down

0 comments on commit 9537395

Please sign in to comment.