From b3e70356db1aa4891115a10902316090fccbc8bf Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 27 Oct 2021 09:40:41 +0300 Subject: [PATCH] feat: accept columns in WherePK --- internal/dbtest/query_test.go | 14 +++ .../testdata/snapshots/TestQuery-mysql5-100 | 1 + .../testdata/snapshots/TestQuery-mysql5-99 | 1 + .../testdata/snapshots/TestQuery-mysql8-100 | 1 + .../testdata/snapshots/TestQuery-mysql8-99 | 1 + .../testdata/snapshots/TestQuery-pg-100 | 1 + .../dbtest/testdata/snapshots/TestQuery-pg-99 | 1 + .../testdata/snapshots/TestQuery-pgx-100 | 1 + .../testdata/snapshots/TestQuery-pgx-99 | 1 + .../testdata/snapshots/TestQuery-sqlite-100 | 1 + .../testdata/snapshots/TestQuery-sqlite-99 | 1 + query_base.go | 99 +++++++++++++------ query_delete.go | 4 +- query_select.go | 4 +- query_update.go | 4 +- 15 files changed, 100 insertions(+), 35 deletions(-) create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql5-100 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql5-99 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql8-100 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql8-99 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pg-100 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pg-99 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pgx-100 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pgx-99 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-sqlite-100 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-sqlite-99 diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 6921e02b2..636927c2c 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -606,6 +606,20 @@ func TestQuery(t *testing.T) { IfNotExists(). ColumnExpr("column_name VARCHAR(123)") }, + func(db *bun.DB) schema.QueryAppender { + models := []Model{ + {ID: 1}, + {ID: 2}, + } + return db.NewSelect().Model(&models).WherePK() + }, + func(db *bun.DB) schema.QueryAppender { + models := []Model{ + {ID: 1, Str: "hello"}, + {ID: 2, Str: "world"}, + } + return db.NewSelect().Model(&models).WherePK("id", "str") + }, } timeRE := regexp.MustCompile(`'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-100 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-100 new file mode 100644 index 000000000..41c0af48d --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-100 @@ -0,0 +1 @@ +SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE (`model`.`id`, `model`.`str`) IN ((1, 'hello'), (2, 'world')) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-99 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-99 new file mode 100644 index 000000000..43df2042b --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-99 @@ -0,0 +1 @@ +SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE `model`.`id` IN (1, 2) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-100 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-100 new file mode 100644 index 000000000..41c0af48d --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-100 @@ -0,0 +1 @@ +SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE (`model`.`id`, `model`.`str`) IN ((1, 'hello'), (2, 'world')) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-99 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-99 new file mode 100644 index 000000000..43df2042b --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-99 @@ -0,0 +1 @@ +SELECT `model`.`id`, `model`.`str` FROM `models` AS `model` WHERE `model`.`id` IN (1, 2) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-100 b/internal/dbtest/testdata/snapshots/TestQuery-pg-100 new file mode 100644 index 000000000..8981da58e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-100 @@ -0,0 +1 @@ +SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE ("model"."id", "model"."str") IN ((1, 'hello'), (2, 'world')) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-99 b/internal/dbtest/testdata/snapshots/TestQuery-pg-99 new file mode 100644 index 000000000..109c2e23c --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-99 @@ -0,0 +1 @@ +SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE "model"."id" IN (1, 2) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-100 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-100 new file mode 100644 index 000000000..8981da58e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-100 @@ -0,0 +1 @@ +SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE ("model"."id", "model"."str") IN ((1, 'hello'), (2, 'world')) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-99 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-99 new file mode 100644 index 000000000..109c2e23c --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-99 @@ -0,0 +1 @@ +SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE "model"."id" IN (1, 2) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-100 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-100 new file mode 100644 index 000000000..8981da58e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-100 @@ -0,0 +1 @@ +SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE ("model"."id", "model"."str") IN ((1, 'hello'), (2, 'world')) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-99 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-99 new file mode 100644 index 000000000..109c2e23c --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-99 @@ -0,0 +1 @@ +SELECT "model"."id", "model"."str" FROM "models" AS "model" WHERE "model"."id" IN (1, 2) diff --git a/query_base.go b/query_base.go index 9739d93e2..4cf31d04e 100644 --- a/query_base.go +++ b/query_base.go @@ -14,8 +14,7 @@ import ( ) const ( - wherePKFlag internal.Flag = 1 << iota - forceDeleteFlag + forceDeleteFlag internal.Flag = 1 << iota deletedFlag allWithDeletedFlag ) @@ -580,7 +579,8 @@ func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) s type whereBaseQuery struct { baseQuery - where []schema.QueryWithSep + where []schema.QueryWithSep + whereFields []*schema.Field } func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { @@ -601,10 +601,46 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) q.addWhere(schema.SafeQueryWithSep("", nil, ")")) } +func (q *whereBaseQuery) addWhereCols(cols []string) { + if q.table == nil { + err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) + q.setErr(err) + return + } + + var fields []*schema.Field + + if cols == nil { + if err := q.table.CheckPKs(); err != nil { + q.setErr(err) + return + } + fields = q.table.PKs + } else { + fields = make([]*schema.Field, len(cols)) + for i, col := range cols { + field, err := q.table.Field(col) + if err != nil { + q.setErr(err) + return + } + fields[i] = field + } + } + + if q.whereFields != nil { + err := errors.New("bun: WherePK can only be called once") + q.setErr(err) + return + } + + q.whereFields = fields +} + func (q *whereBaseQuery) mustAppendWhere( fmter schema.Formatter, b []byte, withAlias bool, ) ([]byte, error) { - if len(q.where) == 0 && !q.flags.Has(wherePKFlag) { + if len(q.where) == 0 && q.whereFields == nil { err := errors.New("bun: Update and Delete queries require at least one Where") return nil, err } @@ -614,7 +650,7 @@ func (q *whereBaseQuery) mustAppendWhere( func (q *whereBaseQuery) appendWhere( fmter schema.Formatter, b []byte, withAlias bool, ) (_ []byte, err error) { - if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) { + if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() { return b, nil } @@ -656,11 +692,11 @@ func (q *whereBaseQuery) appendWhere( } } - if q.flags.Has(wherePKFlag) { + if q.whereFields != nil { if len(b) > startLen { b = append(b, " AND "...) } - b, err = q.appendWherePK(fmter, b, withAlias) + b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias) if err != nil { return nil, err } @@ -691,29 +727,30 @@ func appendWhere( return b, nil } -func (q *whereBaseQuery) appendWherePK( - fmter schema.Formatter, b []byte, withAlias bool, +func (q *whereBaseQuery) appendWhereFields( + fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool, ) (_ []byte, err error) { if q.table == nil { - err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) - return nil, err - } - if err := q.table.CheckPKs(); err != nil { + err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model) return nil, err } switch model := q.tableModel.(type) { case *structTableModel: - return q.appendWherePKStruct(fmter, b, model, withAlias) + return q.appendWhereStructFields(fmter, b, model, fields, withAlias) case *sliceTableModel: - return q.appendWherePKSlice(fmter, b, model, withAlias) + return q.appendWhereSliceFields(fmter, b, model, fields, withAlias) + default: + return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel) } - - return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel) } -func (q *whereBaseQuery) appendWherePKStruct( - fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool, +func (q *whereBaseQuery) appendWhereStructFields( + fmter schema.Formatter, + b []byte, + model *structTableModel, + fields []*schema.Field, + withAlias bool, ) (_ []byte, err error) { if !model.strct.IsValid() { return nil, errNilModel @@ -721,7 +758,7 @@ func (q *whereBaseQuery) appendWherePKStruct( isTemplate := fmter.IsNop() b = append(b, '(') - for i, f := range q.table.PKs { + for i, f := range fields { if i > 0 { b = append(b, " AND "...) } @@ -741,18 +778,22 @@ func (q *whereBaseQuery) appendWherePKStruct( return b, nil } -func (q *whereBaseQuery) appendWherePKSlice( - fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool, +func (q *whereBaseQuery) appendWhereSliceFields( + fmter schema.Formatter, + b []byte, + model *sliceTableModel, + fields []*schema.Field, + withAlias bool, ) (_ []byte, err error) { - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, '(') } if withAlias { - b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + b = appendColumns(b, q.table.SQLAlias, fields) } else { - b = appendColumns(b, "", q.table.PKs) + b = appendColumns(b, "", fields) } - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, ')') } @@ -771,10 +812,10 @@ func (q *whereBaseQuery) appendWherePKSlice( el := indirect(slice.Index(i)) - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, '(') } - for i, f := range q.table.PKs { + for i, f := range fields { if i > 0 { b = append(b, ", "...) } @@ -784,7 +825,7 @@ func (q *whereBaseQuery) appendWherePKSlice( b = f.AppendValue(fmter, b, el) } } - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, ')') } } diff --git a/query_delete.go b/query_delete.go index a8e14f896..6af3dbd2e 100644 --- a/query_delete.go +++ b/query_delete.go @@ -66,8 +66,8 @@ func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQ //------------------------------------------------------------------------------ -func (q *DeleteQuery) WherePK() *DeleteQuery { - q.flags = q.flags.Set(wherePKFlag) +func (q *DeleteQuery) WherePK(cols ...string) *DeleteQuery { + q.addWhereCols(cols) return q } diff --git a/query_select.go b/query_select.go index 399e3eb0f..c69fed024 100644 --- a/query_select.go +++ b/query_select.go @@ -116,8 +116,8 @@ func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery { //------------------------------------------------------------------------------ -func (q *SelectQuery) WherePK() *SelectQuery { - q.flags = q.flags.Set(wherePKFlag) +func (q *SelectQuery) WherePK(cols ...string) *SelectQuery { + q.addWhereCols(cols) return q } diff --git a/query_update.go b/query_update.go index df583db25..2de4ace9a 100644 --- a/query_update.go +++ b/query_update.go @@ -107,8 +107,8 @@ func (q *UpdateQuery) OmitZero() *UpdateQuery { //------------------------------------------------------------------------------ -func (q *UpdateQuery) WherePK() *UpdateQuery { - q.flags = q.flags.Set(wherePKFlag) +func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery { + q.addWhereCols(cols) return q }