Skip to content

Commit

Permalink
refs #64 add SQLServer flavor
Browse files Browse the repository at this point in the history
  • Loading branch information
huandu committed Jul 24, 2021
1 parent 02da50f commit f1b7ac4
Show file tree
Hide file tree
Showing 8 changed files with 427 additions and 168 deletions.
4 changes: 3 additions & 1 deletion args.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interfac
case MySQL, SQLite:
buf.WriteRune('?')
case PostgreSQL:
fmt.Fprintf(buf, "$%v", len(values)+1)
fmt.Fprintf(buf, "$%d", len(values)+1)
case SQLServer:
fmt.Fprintf(buf, "@p%d", len(values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
}
Expand Down
33 changes: 32 additions & 1 deletion args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ func TestArgs(t *testing.T) {
}

old := DefaultFlavor
DefaultFlavor = PostgreSQL
defer func() {
DefaultFlavor = old
}()

DefaultFlavor = PostgreSQL

// PostgreSQL flavor compiled sql.
for expected, c := range cases {
args := new(Args)
Expand All @@ -59,6 +60,23 @@ func TestArgs(t *testing.T) {

a.Equal(actual, expected)
}

DefaultFlavor = SQLServer

// SQLServer flavor compiled sql.
for expected, c := range cases {
args := new(Args)

for i := 1; i < len(c); i++ {
args.Add(c[i])
}

sql, values := args.Compile(c[0].(string))
actual := fmt.Sprintf("%v\n%v", sql, values)
expected = toSQLServerSQL(expected)

a.Equal(actual, expected)
}
}

func toPostgreSQL(sql string) string {
Expand All @@ -74,6 +92,19 @@ func toPostgreSQL(sql string) string {
return buf.String()
}

func toSQLServerSQL(sql string) string {
parts := strings.Split(sql, "?")
buf := &bytes.Buffer{}
buf.WriteString(parts[0])

for i, p := range parts[1:] {
fmt.Fprintf(buf, "@p%v", i+1)
buf.WriteString(p)
}

return buf.String()
}

func TestArgsAdd(t *testing.T) {
a := assert.New(t)
args := &Args{}
Expand Down
11 changes: 8 additions & 3 deletions flavor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
MySQL
PostgreSQL
SQLite
SQLServer
)

var (
Expand Down Expand Up @@ -46,6 +47,8 @@ func (f Flavor) String() string {
return "PostgreSQL"
case SQLite:
return "SQLite"
case SQLServer:
return "SQLServer"
}

return "<invalid>"
Expand All @@ -64,6 +67,8 @@ func (f Flavor) Interpolate(sql string, args []interface{}) (string, error) {
return postgresqlInterpolate(sql, args...)
case SQLite:
return sqliteInterpolate(sql, args...)
case SQLServer:
return sqlserverInterpolate(sql, args...)
}

return "", ErrInterpolateNotImplemented
Expand Down Expand Up @@ -114,13 +119,13 @@ func (f Flavor) NewUnionBuilder() *UnionBuilder {
// Quote adds quote for name to make sure the name can be used safely
// as table name or field name.
//
// * For MySQL, use back quote (`) to quote name;
// * For PostgreSQL and SQLite, use double quote (") to quote name.
// * For MySQL, use back quote (`) to quote name;
// * For PostgreSQL, SQL Server and SQLite, use double quote (") to quote name.
func (f Flavor) Quote(name string) string {
switch f {
case MySQL:
return fmt.Sprintf("`%s`", name)
case PostgreSQL, SQLite:
case PostgreSQL, SQLServer, SQLite:
return fmt.Sprintf(`"%s"`, name)
}

Expand Down
135 changes: 19 additions & 116 deletions flavor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ package sqlbuilder

import (
"fmt"
"strconv"
"testing"
"time"

"github.com/huandu/go-assert"
)
Expand All @@ -19,6 +17,7 @@ func TestFlavor(t *testing.T) {
MySQL: "MySQL",
PostgreSQL: "PostgreSQL",
SQLite: "SQLite",
SQLServer: "SQLServer",
}

for f, expected := range cases {
Expand All @@ -27,120 +26,6 @@ func TestFlavor(t *testing.T) {
}
}

func TestFlavorInterpolate(t *testing.T) {
a := assert.New(t)
dt := time.Date(2019, 4, 24, 12, 23, 34, 123456789, time.FixedZone("CST", 8*60*60)) // 2019-04-24 12:23:34.987654321 CST
_, errOutOfRange := strconv.ParseInt("12345678901234567890", 10, 32)
cases := []struct {
flavor Flavor
sql string
args []interface{}
query string
err error
}{
{
MySQL,
"SELECT * FROM a WHERE name = ? AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)},
"SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil,
},
{
MySQL,
"SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"},
"SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil,
},
{
MySQL,
"SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil},
"SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, _binary'I\\'m bytes', '2019-04-24 12:23:34.123457', '0000-00-00', NULL", nil,
},
{
MySQL,
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{MySQL},
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\'MySQL'", nil,
},
{
MySQL,
"SELECT ?", nil,
"", ErrInterpolateMissingArgs,
},
{
MySQL,
"SELECT ?", []interface{}{complex(1, 2)},
"", ErrInterpolateUnsupportedArgs,
},

{
PostgreSQL,
"SELECT * FROM a WHERE name = $3 AND state IN ($2, $4, $1, $6, $5)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)},
"SELECT * FROM a WHERE name = 8 AND state IN (42, -16, E'I\\'m fine', 64, 32)", nil,
},
{
PostgreSQL,
"SELECT * FROM $abc$$1$abc$1$1 WHERE name = \"$1\" AND state IN ($2, '$1', $3, $6, $5, $4, $2) $3", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"},
"SELECT * FROM $abc$$1$abc$1E'\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'' WHERE name = \"$1\" AND state IN (42, '$1', 8, 64, 32, 16, 42) 8", nil,
},
{
PostgreSQL,
"SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $11, $a", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil, 10, 11, 12},
"SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, E'\\\\x49276D206279746573'::bytea, '2019-04-24 12:23:34.123457 CST', '0000-00-00', NULL, 11, $a", nil,
},
{
PostgreSQL,
"SELECT '\\'$1', \"\\\"$1\", `$1`, \\$1a, $$1$$, $a $b$ $a $ $1$b$1$1 $a$ $", []interface{}{MySQL},
"SELECT '\\'$1', \"\\\"$1\", `E'MySQL'`, \\E'MySQL'a, $$1$$, $a $b$ $a $ $1$b$1E'MySQL' $a$ $", nil,
},
{
PostgreSQL,
"SELECT * FROM a WHERE name = 'Huan''Du''$1' AND desc = $1", []interface{}{"c'mon"},
"SELECT * FROM a WHERE name = 'Huan''Du''$1' AND desc = E'c\\'mon'", nil,
},
{
PostgreSQL,
"SELECT $1", nil,
"", ErrInterpolateMissingArgs,
},
{
PostgreSQL,
"SELECT $1", []interface{}{complex(1, 2)},
"", ErrInterpolateUnsupportedArgs,
},
{
PostgreSQL,
"SELECT $12345678901234567890", nil,
"", errOutOfRange,
},

{
SQLite,
"SELECT * FROM a WHERE name = ? AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)},
"SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil,
},
{
SQLite,
"SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"},
"SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil,
},
{
SQLite,
"SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil},
"SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, X'49276D206279746573', '2019-04-24 12:23:34.123', '0000-00-00', NULL", nil,
},
{
SQLite,
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{SQLite},
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\'SQLite'", nil,
},
}

for idx, c := range cases {
a.Use(&idx, &c)
query, err := c.flavor.Interpolate(c.sql, c.args)

a.Equal(query, c.query)
a.Assert(err == c.err || err.Error() == c.err.Error())
}
}

func ExampleFlavor() {
// Create a flavored builder.
sb := PostgreSQL.NewSelectBuilder()
Expand Down Expand Up @@ -218,3 +103,21 @@ func ExampleFlavor_Interpolate_sqlite() {
// SELECT name FROM user WHERE id <> 1234 AND name = 'Charmy Liu' AND desc LIKE '%mother\'s day%'
// <nil>
}

func ExampleFlavor_Interpolate_sqlServer() {
sb := SQLServer.NewSelectBuilder()
sb.Select("name").From("user").Where(
sb.NE("id", 1234),
sb.E("name", "Charmy Liu"),
sb.Like("desc", "%mother's day%"),
)
sql, args := sb.Build()
query, err := SQLServer.Interpolate(sql, args)

fmt.Println(query)
fmt.Println(err)

// Output:
// SELECT name FROM user WHERE id <> 1234 AND name = N'Charmy Liu' AND desc LIKE N'%mother\'s day%'
// <nil>
}
Loading

0 comments on commit f1b7ac4

Please sign in to comment.