diff --git a/.gitignore b/.gitignore index 1561473..7867414 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ _testmain.go # VS Code debug debug_test +.vscode/ # Mac .DS_Store diff --git a/cond.go b/cond.go index 895bb30..2396b6d 100644 --- a/cond.go +++ b/cond.go @@ -25,6 +25,10 @@ func NewCond() *Cond { // Equal is used to construct the expression "field = value". func (c *Cond) Equal(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -46,6 +50,10 @@ func (c *Cond) EQ(field string, value interface{}) string { // NotEqual is used to construct the expression "field <> value". func (c *Cond) NotEqual(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -67,6 +75,10 @@ func (c *Cond) NEQ(field string, value interface{}) string { // GreaterThan is used to construct the expression "field > value". func (c *Cond) GreaterThan(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -88,6 +100,10 @@ func (c *Cond) GT(field string, value interface{}) string { // GreaterEqualThan is used to construct the expression "field >= value". func (c *Cond) GreaterEqualThan(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -109,6 +125,10 @@ func (c *Cond) GTE(field string, value interface{}) string { // LessThan is used to construct the expression "field < value". func (c *Cond) LessThan(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -130,6 +150,9 @@ func (c *Cond) LT(field string, value interface{}) string { // LessEqualThan is used to construct the expression "field <= value". func (c *Cond) LessEqualThan(field string, value interface{}) string { + if len(field) == 0 { + return "" + } return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -151,6 +174,10 @@ func (c *Cond) LTE(field string, value interface{}) string { // In is used to construct the expression "field IN (value...)". func (c *Cond) In(field string, values ...interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -163,6 +190,10 @@ func (c *Cond) In(field string, values ...interface{}) string { // NotIn is used to construct the expression "field NOT IN (value...)". func (c *Cond) NotIn(field string, values ...interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -175,6 +206,10 @@ func (c *Cond) NotIn(field string, values ...interface{}) string { // Like is used to construct the expression "field LIKE value". func (c *Cond) Like(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -190,6 +225,10 @@ func (c *Cond) Like(field string, value interface{}) string { // the ILike method will return "LOWER(field) LIKE LOWER(value)" // to simulate the behavior of the ILIKE operator. func (c *Cond) ILike(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { switch ctx.Flavor { @@ -212,6 +251,10 @@ func (c *Cond) ILike(field string, value interface{}) string { // NotLike is used to construct the expression "field NOT LIKE value". func (c *Cond) NotLike(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -227,6 +270,10 @@ func (c *Cond) NotLike(field string, value interface{}) string { // the NotILike method will return "LOWER(field) NOT LIKE LOWER(value)" // to simulate the behavior of the ILIKE operator. func (c *Cond) NotILike(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { switch ctx.Flavor { @@ -249,6 +296,10 @@ func (c *Cond) NotILike(field string, value interface{}) string { // IsNull is used to construct the expression "field IS NULL". func (c *Cond) IsNull(field string) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -259,6 +310,9 @@ func (c *Cond) IsNull(field string) string { // IsNotNull is used to construct the expression "field IS NOT NULL". func (c *Cond) IsNotNull(field string) string { + if len(field) == 0 { + return "" + } return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -269,6 +323,10 @@ func (c *Cond) IsNotNull(field string) string { // Between is used to construct the expression "field BETWEEN lower AND upper". func (c *Cond) Between(field string, lower, upper interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -282,6 +340,10 @@ func (c *Cond) Between(field string, lower, upper interface{}) string { // NotBetween is used to construct the expression "field NOT BETWEEN lower AND upper". func (c *Cond) NotBetween(field string, lower, upper interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -299,10 +361,15 @@ func (c *Cond) Or(orExpr ...string) string { return "" } + exprByteLen := estimateStringsBytes(orExpr) + if exprByteLen == 0 { + return "" + } + buf := newStringBuilder() // Ensure that there is only 1 memory allocation. - size := len(lparen) + len(rparen) + (len(orExpr)-1)*len(opOR) + estimateStringsBytes(orExpr) + size := len(lparen) + len(rparen) + (len(orExpr)-1)*len(opOR) + exprByteLen buf.Grow(size) buf.WriteString(lparen) @@ -317,10 +384,15 @@ func (c *Cond) And(andExpr ...string) string { return "" } + exprByteLen := estimateStringsBytes(andExpr) + if exprByteLen == 0 { + return "" + } + buf := newStringBuilder() // Ensure that there is only 1 memory allocation. - size := len(lparen) + len(rparen) + (len(andExpr)-1)*len(opAND) + estimateStringsBytes(andExpr) + size := len(lparen) + len(rparen) + (len(andExpr)-1)*len(opAND) + exprByteLen buf.Grow(size) buf.WriteString(lparen) @@ -331,6 +403,9 @@ func (c *Cond) And(andExpr ...string) string { // Not is used to construct the expression "NOT expr". func (c *Cond) Not(notExpr string) string { + if len(notExpr) == 0 { + return "" + } buf := newStringBuilder() // Ensure that there is only 1 memory allocation. @@ -366,6 +441,10 @@ func (c *Cond) NotExists(subquery interface{}) string { // Any is used to construct the expression "field op ANY (value...)". func (c *Cond) Any(field, op string, values ...interface{}) string { + if len(field) == 0 || len(op) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -380,6 +459,10 @@ func (c *Cond) Any(field, op string, values ...interface{}) string { // All is used to construct the expression "field op ALL (value...)". func (c *Cond) All(field, op string, values ...interface{}) string { + if len(field) == 0 || len(op) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -394,6 +477,10 @@ func (c *Cond) All(field, op string, values ...interface{}) string { // Some is used to construct the expression "field op SOME (value...)". func (c *Cond) Some(field, op string, values ...interface{}) string { + if len(field) == 0 || len(op) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -413,6 +500,10 @@ func (c *Cond) Some(field, op string, values ...interface{}) string { // "CASE ... WHEN ... ELSE ... END" expression to simulate the behavior of // the IS DISTINCT FROM operator. func (c *Cond) IsDistinctFrom(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { switch ctx.Flavor { @@ -458,6 +549,10 @@ func (c *Cond) IsDistinctFrom(field string, value interface{}) string { // "CASE ... WHEN ... ELSE ... END" expression to simulate the behavior of // the IS NOT DISTINCT FROM operator. func (c *Cond) IsNotDistinctFrom(field string, value interface{}) string { + if len(field) == 0 { + return "" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { switch ctx.Flavor { diff --git a/cond_test.go b/cond_test.go index 59b0b0a..2587855 100644 --- a/cond_test.go +++ b/cond_test.go @@ -10,48 +10,51 @@ import ( "github.com/huandu/go-assert" ) +type TestPair struct { + Expected string + Actual func(cond *Cond) string +} + func TestCond(t *testing.T) { a := assert.New(t) cases := map[string]func(cond *Cond) string{ - "$a = $1": func(cond *Cond) string { return cond.Equal("$a", 123) }, - "$b = $1": func(cond *Cond) string { return cond.E("$b", 123) }, - "$c = $1": func(cond *Cond) string { return cond.EQ("$c", 123) }, - "$a <> $1": func(cond *Cond) string { return cond.NotEqual("$a", 123) }, - "$b <> $1": func(cond *Cond) string { return cond.NE("$b", 123) }, - "$c <> $1": func(cond *Cond) string { return cond.NEQ("$c", 123) }, - "$a > $1": func(cond *Cond) string { return cond.GreaterThan("$a", 123) }, - "$b > $1": func(cond *Cond) string { return cond.G("$b", 123) }, - "$c > $1": func(cond *Cond) string { return cond.GT("$c", 123) }, - "$a >= $1": func(cond *Cond) string { return cond.GreaterEqualThan("$a", 123) }, - "$b >= $1": func(cond *Cond) string { return cond.GE("$b", 123) }, - "$c >= $1": func(cond *Cond) string { return cond.GTE("$c", 123) }, - "$a < $1": func(cond *Cond) string { return cond.LessThan("$a", 123) }, - "$b < $1": func(cond *Cond) string { return cond.L("$b", 123) }, - "$c < $1": func(cond *Cond) string { return cond.LT("$c", 123) }, - "$a <= $1": func(cond *Cond) string { return cond.LessEqualThan("$a", 123) }, - "$b <= $1": func(cond *Cond) string { return cond.LE("$b", 123) }, - "$c <= $1": func(cond *Cond) string { return cond.LTE("$c", 123) }, - "$a IN ($1, $2, $3)": func(cond *Cond) string { return cond.In("$a", 1, 2, 3) }, - "$a NOT IN ($1, $2, $3)": func(cond *Cond) string { return cond.NotIn("$a", 1, 2, 3) }, - "$a LIKE $1": func(cond *Cond) string { return cond.Like("$a", "%Huan%") }, - "$a ILIKE $1": func(cond *Cond) string { return cond.ILike("$a", "%Huan%") }, - "$a NOT LIKE $1": func(cond *Cond) string { return cond.NotLike("$a", "%Huan%") }, - "$a NOT ILIKE $1": func(cond *Cond) string { return cond.NotILike("$a", "%Huan%") }, - "$a IS NULL": func(cond *Cond) string { return cond.IsNull("$a") }, - "$a IS NOT NULL": func(cond *Cond) string { return cond.IsNotNull("$a") }, - "$a BETWEEN $1 AND $2": func(cond *Cond) string { return cond.Between("$a", 123, 456) }, - "$a NOT BETWEEN $1 AND $2": func(cond *Cond) string { return cond.NotBetween("$a", 123, 456) }, - "(1 = 1 OR 2 = 2 OR 3 = 3)": func(cond *Cond) string { return cond.Or("1 = 1", "2 = 2", "3 = 3") }, - "(1 = 1 AND 2 = 2 AND 3 = 3)": func(cond *Cond) string { return cond.And("1 = 1", "2 = 2", "3 = 3") }, - "NOT 1 = 1": func(cond *Cond) string { return cond.Not("1 = 1") }, - "EXISTS ($1)": func(cond *Cond) string { return cond.Exists(1) }, - "NOT EXISTS ($1)": func(cond *Cond) string { return cond.NotExists(1) }, - "$a > ANY ($1, $2)": func(cond *Cond) string { return cond.Any("$a", ">", 1, 2) }, - "$a < ALL ($1)": func(cond *Cond) string { return cond.All("$a", "<", 1) }, - "$a > SOME ($1, $2, $3)": func(cond *Cond) string { return cond.Some("$a", ">", 1, 2, 3) }, - "$a IS DISTINCT FROM $1": func(cond *Cond) string { return cond.IsDistinctFrom("$a", 1) }, - "$a IS NOT DISTINCT FROM $1": func(cond *Cond) string { return cond.IsNotDistinctFrom("$a", 1) }, - "$1": func(cond *Cond) string { return cond.Var(123) }, + "$a = $1": func(cond *Cond) string { return cond.Equal("$a", 123) }, + "$b = $1": func(cond *Cond) string { return cond.E("$b", 123) }, + "$c = $1": func(cond *Cond) string { return cond.EQ("$c", 123) }, + "$a <> $1": func(cond *Cond) string { return cond.NotEqual("$a", 123) }, + "$b <> $1": func(cond *Cond) string { return cond.NE("$b", 123) }, + "$c <> $1": func(cond *Cond) string { return cond.NEQ("$c", 123) }, + "$a > $1": func(cond *Cond) string { return cond.GreaterThan("$a", 123) }, + "$b > $1": func(cond *Cond) string { return cond.G("$b", 123) }, + "$c > $1": func(cond *Cond) string { return cond.GT("$c", 123) }, + "$a >= $1": func(cond *Cond) string { return cond.GreaterEqualThan("$a", 123) }, + "$b >= $1": func(cond *Cond) string { return cond.GE("$b", 123) }, + "$c >= $1": func(cond *Cond) string { return cond.GTE("$c", 123) }, + "$a < $1": func(cond *Cond) string { return cond.LessThan("$a", 123) }, + "$b < $1": func(cond *Cond) string { return cond.L("$b", 123) }, + "$c < $1": func(cond *Cond) string { return cond.LT("$c", 123) }, + "$a <= $1": func(cond *Cond) string { return cond.LessEqualThan("$a", 123) }, + "$b <= $1": func(cond *Cond) string { return cond.LE("$b", 123) }, + "$c <= $1": func(cond *Cond) string { return cond.LTE("$c", 123) }, + "$a IN ($1, $2, $3)": func(cond *Cond) string { return cond.In("$a", 1, 2, 3) }, + "$a NOT IN ($1, $2, $3)": func(cond *Cond) string { return cond.NotIn("$a", 1, 2, 3) }, + "$a LIKE $1": func(cond *Cond) string { return cond.Like("$a", "%Huan%") }, + "$a ILIKE $1": func(cond *Cond) string { return cond.ILike("$a", "%Huan%") }, + "$a NOT LIKE $1": func(cond *Cond) string { return cond.NotLike("$a", "%Huan%") }, + "$a NOT ILIKE $1": func(cond *Cond) string { return cond.NotILike("$a", "%Huan%") }, + "$a IS NULL": func(cond *Cond) string { return cond.IsNull("$a") }, + "$a IS NOT NULL": func(cond *Cond) string { return cond.IsNotNull("$a") }, + "$a BETWEEN $1 AND $2": func(cond *Cond) string { return cond.Between("$a", 123, 456) }, + "$a NOT BETWEEN $1 AND $2": func(cond *Cond) string { return cond.NotBetween("$a", 123, 456) }, + "NOT 1 = 1": func(cond *Cond) string { return cond.Not("1 = 1") }, + "EXISTS ($1)": func(cond *Cond) string { return cond.Exists(1) }, + "NOT EXISTS ($1)": func(cond *Cond) string { return cond.NotExists(1) }, + "$a > ANY ($1, $2)": func(cond *Cond) string { return cond.Any("$a", ">", 1, 2) }, + "$a < ALL ($1)": func(cond *Cond) string { return cond.All("$a", "<", 1) }, + "$a > SOME ($1, $2, $3)": func(cond *Cond) string { return cond.Some("$a", ">", 1, 2, 3) }, + "$a IS DISTINCT FROM $1": func(cond *Cond) string { return cond.IsDistinctFrom("$a", 1) }, + "$a IS NOT DISTINCT FROM $1": func(cond *Cond) string { return cond.IsNotDistinctFrom("$a", 1) }, + "$1": func(cond *Cond) string { return cond.Var(123) }, } for expected, f := range cases { @@ -60,6 +63,100 @@ func TestCond(t *testing.T) { } } +func TestOrCond(t *testing.T) { + a := assert.New(t) + cases := []TestPair{ + {Expected: "(1 = 1 OR 2 = 2 OR 3 = 3)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "2 = 2", "3 = 3") }}, + + {Expected: "(1 = 1 OR 2 = 2)", Actual: func(cond *Cond) string { return cond.Or("", "1 = 1", "2 = 2") }}, + {Expected: "(1 = 1 OR 2 = 2)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "2 = 2", "") }}, + {Expected: "(1 = 1 OR 2 = 2)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "", "2 = 2") }}, + + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "", "") }}, + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("", "1 = 1", "") }}, + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("", "", "1 = 1") }}, + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("1 = 1") }}, + + {Expected: "", Actual: func(cond *Cond) string { return cond.Or("") }}, + {Expected: "", Actual: func(cond *Cond) string { return cond.Or() }}, + {Expected: "", Actual: func(cond *Cond) string { return cond.Or("", "", "") }}, + } + + for _, f := range cases { + actual := callCond(f.Actual) + a.Equal(actual, f.Expected) + } +} + +func TestAndCond(t *testing.T) { + a := assert.New(t) + cases := []TestPair{ + {Expected: "(1 = 1 AND 2 = 2 AND 3 = 3)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "2 = 2", "3 = 3") }}, + + {Expected: "(1 = 1 AND 2 = 2)", Actual: func(cond *Cond) string { return cond.And("", "1 = 1", "2 = 2") }}, + {Expected: "(1 = 1 AND 2 = 2)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "2 = 2", "") }}, + {Expected: "(1 = 1 AND 2 = 2)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "", "2 = 2") }}, + + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "", "") }}, + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("", "1 = 1", "") }}, + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("", "", "1 = 1") }}, + {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("1 = 1") }}, + + {Expected: "", Actual: func(cond *Cond) string { return cond.And("") }}, + {Expected: "", Actual: func(cond *Cond) string { return cond.And() }}, + {Expected: "", Actual: func(cond *Cond) string { return cond.And("", "", "") }}, + } + + for _, f := range cases { + actual := callCond(f.Actual) + a.Equal(actual, f.Expected) + } +} + +func TestEmptyCond(t *testing.T) { + a := assert.New(t) + cases := []func(cond *Cond) string{ + func(cond *Cond) string { return cond.Equal("", 123) }, + func(cond *Cond) string { return cond.NotEqual("", 123) }, + func(cond *Cond) string { return cond.GreaterThan("", 123) }, + func(cond *Cond) string { return cond.GreaterEqualThan("", 123) }, + func(cond *Cond) string { return cond.LessThan("", 123) }, + func(cond *Cond) string { return cond.LessEqualThan("", 123) }, + func(cond *Cond) string { return cond.In("", 1, 2, 3) }, + func(cond *Cond) string { return cond.NotIn("", 1, 2, 3) }, + func(cond *Cond) string { return cond.Like("", "%Huan%") }, + func(cond *Cond) string { return cond.ILike("", "%Huan%") }, + func(cond *Cond) string { return cond.NotLike("", "%Huan%") }, + func(cond *Cond) string { return cond.NotILike("", "%Huan%") }, + func(cond *Cond) string { return cond.IsNull("") }, + func(cond *Cond) string { return cond.IsNotNull("") }, + func(cond *Cond) string { return cond.Between("", 123, 456) }, + func(cond *Cond) string { return cond.NotBetween("", 123, 456) }, + func(cond *Cond) string { return cond.Not("") }, + + func(cond *Cond) string { return cond.Any("", "", 1, 2) }, + func(cond *Cond) string { return cond.Any("", ">", 1, 2) }, + func(cond *Cond) string { return cond.Any("$a", "", 1, 2) }, + + func(cond *Cond) string { return cond.All("", "", 1) }, + func(cond *Cond) string { return cond.All("", ">", 1) }, + func(cond *Cond) string { return cond.All("$a", "", 1) }, + + func(cond *Cond) string { return cond.Some("", "", 1, 2, 3) }, + func(cond *Cond) string { return cond.Some("", ">", 1, 2, 3) }, + func(cond *Cond) string { return cond.Some("$a", "", 1, 2, 3) }, + + func(cond *Cond) string { return cond.IsDistinctFrom("", 1) }, + func(cond *Cond) string { return cond.IsNotDistinctFrom("", 1) }, + } + + expected := "" + for _, f := range cases { + actual := callCond(f) + a.Equal(actual, expected) + } +} + func callCond(fn func(cond *Cond) string) (actual string) { cond := &Cond{ Args: &Args{}, diff --git a/delete.go b/delete.go index b49c4d2..352f8e4 100644 --- a/delete.go +++ b/delete.go @@ -80,7 +80,7 @@ func (db *DeleteBuilder) DeleteFrom(table string) *DeleteBuilder { // Where sets expressions of WHERE in DELETE. func (db *DeleteBuilder) Where(andExpr ...string) *DeleteBuilder { - if len(andExpr) == 0 { + if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 { return db } diff --git a/select.go b/select.go index 3817584..00e633b 100644 --- a/select.go +++ b/select.go @@ -189,7 +189,7 @@ func (sb *SelectBuilder) JoinWithOption(option JoinOption, table string, onExpr // Where sets expressions of WHERE in SELECT. func (sb *SelectBuilder) Where(andExpr ...string) *SelectBuilder { - if len(andExpr) == 0 { + if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 { return sb } diff --git a/stringbuilder.go b/stringbuilder.go index cec204a..4c2c7a2 100644 --- a/stringbuilder.go +++ b/stringbuilder.go @@ -39,11 +39,20 @@ func (sb *stringBuilder) WriteStrings(ss []string, sep string) { return } - sb.WriteString(ss[0]) + firstAdded := false + if len(ss[0]) != 0 { + sb.WriteString(ss[0]) + firstAdded = true + } for _, s := range ss[1:] { - sb.WriteString(sep) - sb.WriteString(s) + if len(s) != 0 { + if firstAdded { + sb.WriteString(sep) + } + sb.WriteString(s) + firstAdded = true + } } } diff --git a/update.go b/update.go index ec17e8b..56389ad 100644 --- a/update.go +++ b/update.go @@ -97,7 +97,7 @@ func (ub *UpdateBuilder) SetMore(assignment ...string) *UpdateBuilder { // Where sets expressions of WHERE in UPDATE. func (ub *UpdateBuilder) Where(andExpr ...string) *UpdateBuilder { - if len(andExpr) == 0 { + if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 { return ub } diff --git a/whereclause.go b/whereclause.go index 70451a9..a93c802 100644 --- a/whereclause.go +++ b/whereclause.go @@ -92,6 +92,12 @@ func (wc *WhereClause) AddWhereExpr(args *Args, andExpr ...string) *WhereClause return wc } + andExprsBytesLen := estimateStringsBytes(andExpr) + + if andExprsBytesLen == 0 { + return wc + } + // Merge with last clause if possible. if len(wc.clauses) > 0 { lastClause := &wc.clauses[len(wc.clauses)-1] diff --git a/whereclause_test.go b/whereclause_test.go index 1a64054..17b7d05 100644 --- a/whereclause_test.go +++ b/whereclause_test.go @@ -254,3 +254,60 @@ func TestEmptyWhereExpr(t *testing.T) { a.Equal(ub.String(), "UPDATE t SET foo = 1") a.Equal(db.String(), "DELETE FROM t") } + +func TestEmptyStringsWhere(t *testing.T) { + a := assert.New(t) + emptyExpr := []string{"", "", ""} + + sb := Select("*").From("t").Where(emptyExpr...) + ub := Update("t").Set("foo = 1").Where(emptyExpr...) + db := DeleteFrom("t").Where(emptyExpr...) + + a.Equal(sb.String(), "SELECT * FROM t") + a.Equal(ub.String(), "UPDATE t SET foo = 1") + a.Equal(db.String(), "DELETE FROM t") +} + +func TestEmptyAddWhereExpr(t *testing.T) { + a := assert.New(t) + var emptyExpr []string + sb := Select("*").From("t") + ub := Update("t").Set("foo = 1") + db := DeleteFrom("t") + + cond := NewCond() + whereClause := NewWhereClause().AddWhereExpr( + cond.Args, + emptyExpr..., + ) + + sb.AddWhereClause(whereClause) + ub.AddWhereClause(whereClause) + db.AddWhereClause(whereClause) + + a.Equal(sb.String(), "SELECT * FROM t ") + a.Equal(ub.String(), "UPDATE t SET foo = 1 ") + a.Equal(db.String(), "DELETE FROM t ") +} + +func TestEmptyStringsWhereAddWhereExpr(t *testing.T) { + a := assert.New(t) + emptyExpr := []string{"", "", ""} + sb := Select("*").From("t") + ub := Update("t").Set("foo = 1") + db := DeleteFrom("t") + + cond := NewCond() + whereClause := NewWhereClause().AddWhereExpr( + cond.Args, + emptyExpr..., + ) + + sb.AddWhereClause(whereClause) + ub.AddWhereClause(whereClause) + db.AddWhereClause(whereClause) + + a.Equal(sb.String(), "SELECT * FROM t ") + a.Equal(ub.String(), "UPDATE t SET foo = 1 ") + a.Equal(db.String(), "DELETE FROM t ") +}