diff --git a/integration/basic_test.go b/integration/basic_test.go index e86a4cad95fc..c6d5a47dcead 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -446,6 +446,46 @@ func TestDatabaseName(t *testing.T) { t.Parallel() + t.Run("NoErr", func(t *testing.T) { + ctx, collection := setup.Setup(t) + for name, tc := range map[string]struct { + db string // database name, defaults to empty string + skip string // optional, skip test with a specified reason + }{ + "Dash": { + db: "--", + }, + "Underscore": { + db: "__", + }, + "Sqlite": { + db: "sqlite_", + }, + "Number": { + db: "0prefix", + }, + "63ok": { + db: strings.Repeat("a", 63), + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + + t.Parallel() + + // there is no explicit command to create database, so create collection instead + err := collection.Database().Client().Database(tc.db).CreateCollection(ctx, collection.Name()) + require.NoError(t, err) + + err = collection.Database().Client().Database(tc.db).Drop(ctx) + require.NoError(t, err) + }) + } + }) + t.Run("Err", func(t *testing.T) { ctx, collection := setup.Setup(t) @@ -471,6 +511,25 @@ func TestDatabaseName(t *testing.T) { }, altMessage: fmt.Sprintf("Invalid namespace: %s.%s", dbName64, "TestDatabaseName-Err"), }, + "WithASlash": { + db: "/", + err: &mongo.CommandError{ + Name: "InvalidNamespace", + Code: 73, + Message: `Invalid namespace specified '/.TestDatabaseName-Err'`, + }, + altMessage: `Invalid namespace: /.TestDatabaseName-Err`, + }, + + "WithABackslash": { + db: "\\", + err: &mongo.CommandError{ + Name: "InvalidNamespace", + Code: 73, + Message: `Invalid namespace specified '\.TestDatabaseName-Err'`, + }, + altMessage: `Invalid namespace: \.TestDatabaseName-Err`, + }, "WithADollarSign": { db: "name_with_a-$", err: &mongo.CommandError{ @@ -516,15 +575,6 @@ func TestDatabaseName(t *testing.T) { }) } }) - - t.Run("63ok", func(t *testing.T) { - ctx, collection := setup.Setup(t) - - dbName63 := strings.Repeat("a", 63) - err := collection.Database().Client().Database(dbName63).CreateCollection(ctx, collection.Name()) - require.NoError(t, err) - collection.Database().Client().Database(dbName63).Drop(ctx) - }) } func TestDebugError(t *testing.T) { diff --git a/internal/handlers/pg/pgdb/databases.go b/internal/handlers/pg/pgdb/databases.go index aabcf59692d2..ba30acf5e516 100644 --- a/internal/handlers/pg/pgdb/databases.go +++ b/internal/handlers/pg/pgdb/databases.go @@ -29,7 +29,7 @@ import ( ) // validateDatabaseNameRe validates FerretDB database / PostgreSQL schema names. -var validateDatabaseNameRe = regexp.MustCompile("^[a-zA-Z_-][a-zA-Z0-9_-]{0,62}$") +var validateDatabaseNameRe = regexp.MustCompile("^[a-zA-Z0-9_-]{1,63}$") // Databases returns a sorted list of FerretDB databases / PostgreSQL schemas. func Databases(ctx context.Context, tx pgx.Tx) ([]string, error) { diff --git a/internal/handlers/sqlite/msg_create.go b/internal/handlers/sqlite/msg_create.go index a7c509842d78..64e42528a721 100644 --- a/internal/handlers/sqlite/msg_create.go +++ b/internal/handlers/sqlite/msg_create.go @@ -17,6 +17,7 @@ package sqlite import ( "context" "fmt" + "regexp" "github.com/FerretDB/FerretDB/internal/backends" "github.com/FerretDB/FerretDB/internal/handlers/common" @@ -27,6 +28,9 @@ import ( "github.com/FerretDB/FerretDB/internal/wire" ) +// validateDatabaseNameRe validates FerretDB database name. +var validateDatabaseNameRe = regexp.MustCompile("^[a-zA-Z0-9_-]{1,63}$") + // MsgCreate implements HandlerInterface. func (h *Handler) MsgCreate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) { document, err := msg.Document() @@ -78,6 +82,11 @@ func (h *Handler) MsgCreate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, return nil, err } + if !validateDatabaseNameRe.MatchString(dbName) { + msg := fmt.Sprintf("Invalid namespace: %s.%s", dbName, collectionName) + return nil, commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrInvalidNamespace, msg, "create") + } + db := h.b.Database(dbName) defer db.Close() @@ -98,11 +107,11 @@ func (h *Handler) MsgCreate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, case backends.ErrorCodeIs(err, backends.ErrorCodeCollectionAlreadyExists): msg := fmt.Sprintf("Collection %s.%s already exists.", dbName, collectionName) - return nil, commonerrors.NewCommandErrorMsg(commonerrors.ErrNamespaceExists, msg) + return nil, commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrNamespaceExists, msg, "create") case backends.ErrorCodeIs(err, backends.ErrorCodeCollectionNameIsInvalid): msg := fmt.Sprintf("Invalid namespace: %s.%s", dbName, collectionName) - return nil, commonerrors.NewCommandErrorMsg(commonerrors.ErrInvalidNamespace, msg) + return nil, commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrInvalidNamespace, msg, "create") default: return nil, lazyerrors.Error(err)