Skip to content

Commit

Permalink
Add support for alter_table set check operations in multi-operati…
Browse files Browse the repository at this point in the history
…on migrations (#622)

Ensure that multi-operation migrations combining `alter_column`
operations setting `CHECK` constraints work in combination with other
operations.

Add testcases for:

* rename table, set `CHECK` constraint
* rename table, rename column, set `CHECK` constraint

Previously these migrations would fail as the `alter_column` operation
was unaware of the changes made by the preceding operation.

Part of #239
  • Loading branch information
andrew-farries authored Jan 23, 2025
1 parent db21651 commit c09619a
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pkg/migrations/op_set_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, latestSche
table := s.GetTable(o.Table)

// Add the check constraint to the new column as NOT VALID.
if err := o.addCheckConstraint(ctx, conn); err != nil {
if err := o.addCheckConstraint(ctx, conn, s); err != nil {
return nil, fmt.Errorf("failed to add check constraint: %w", err)
}

Expand Down Expand Up @@ -82,9 +82,11 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
return nil
}

func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error {
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB, s *schema.Schema) error {
table := s.GetTable(o.Table)

_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(table.Name),
pq.QuoteIdentifier(o.Check.Name),
rewriteCheckExpression(o.Check.Constraint, o.Column),
))
Expand Down
216 changes: 216 additions & 0 deletions pkg/migrations/op_set_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,222 @@ func TestSetCheckConstraint(t *testing.T) {
})
}

func TestSetCheckInMultiOperationMigrations(t *testing.T) {
t.Parallel()

ExecuteTests(t, TestCases{
{
name: "rename table, set not null",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "items",
Columns: []migrations.Column{
{
Name: "id",
Type: "int",
Pk: true,
},
{
Name: "name",
Type: "varchar(255)",
Nullable: true,
},
},
},
},
},
{
Name: "02_multi_operation",
Operations: migrations.Operations{
&migrations.OpRenameTable{
From: "items",
To: "products",
},
&migrations.OpAlterColumn{
Table: "products",
Column: "name",
Check: &migrations.CheckConstraint{
Name: "check_name_length",
Constraint: "LENGTH(name) > 2",
},
Up: "SELECT CASE WHEN length(name) > 2 THEN name ELSE name || '---' END",
Down: "name",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Can insert a row into the new view that meets the check constraint
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "1",
"name": "abc",
})

// Can't insert a row into the new view that violates the check constraint
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "2",
"name": "x",
}, testutils.CheckViolationErrorCode)

// Can insert a row into the old view that violates the check constraint
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
"id": "3",
"name": "x",
})

// The new view has the expected rows
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "abc"},
{"id": 3, "name": "x---"},
}, rows)

// The old view has the expected rows
rows = MustSelect(t, db, schema, "01_create_table", "items")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "abc"},
{"id": 3, "name": "x"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The table has been cleaned up
TableMustBeCleanedUp(t, db, schema, "items", "name")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Can insert a row into the new view that meets the check constraint
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "4",
"name": "def",
})

// Can't insert a row into the new view that violates the check constraint
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "5",
"name": "x",
}, testutils.CheckViolationErrorCode)

// The new view has the expected rows
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "abc"},
{"id": 3, "name": "x---"},
{"id": 4, "name": "def"},
}, rows)
},
},
{
name: "rename table, rename column set not null",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "items",
Columns: []migrations.Column{
{
Name: "id",
Type: "int",
Pk: true,
},
{
Name: "name",
Type: "varchar(255)",
Nullable: true,
},
},
},
},
},
{
Name: "02_multi_operation",
Operations: migrations.Operations{
&migrations.OpRenameTable{
From: "items",
To: "products",
},
&migrations.OpRenameColumn{
Table: "products",
From: "name",
To: "item_name",
},
&migrations.OpAlterColumn{
Table: "products",
Column: "item_name",
Check: &migrations.CheckConstraint{
Name: "check_name_length",
Constraint: "LENGTH(item_name) > 2",
},
Up: "SELECT CASE WHEN length(item_name) > 2 THEN item_name ELSE item_name || '---' END",
Down: "item_name",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Can insert a row into the new view that meets the check constraint
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "1",
"item_name": "abc",
})

// Can't insert a row into the new view that violates the check constraint
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "2",
"item_name": "x",
}, testutils.CheckViolationErrorCode)

// Can insert a row into the old view that violates the check constraint
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
"id": "3",
"name": "x",
})

// The new view has the expected rows
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
assert.Equal(t, []map[string]any{
{"id": 1, "item_name": "abc"},
{"id": 3, "item_name": "x---"},
}, rows)

// The old view has the expected rows
rows = MustSelect(t, db, schema, "01_create_table", "items")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "abc"},
{"id": 3, "name": "x"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The table has been cleaned up
TableMustBeCleanedUp(t, db, schema, "items", "name")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Can insert a row into the new view that meets the check constraint
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "4",
"item_name": "def",
})

// Can't insert a row into the new view that violates the check constraint
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
"id": "5",
"item_name": "x",
}, testutils.CheckViolationErrorCode)

// The new view has the expected rows
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
assert.Equal(t, []map[string]any{
{"id": 1, "item_name": "abc"},
{"id": 3, "item_name": "x---"},
{"id": 4, "item_name": "def"},
}, rows)
},
},
})
}

func TestSetCheckConstraintValidation(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit c09619a

Please sign in to comment.