Skip to content

Commit

Permalink
fix #6. support Postgre-flavor markers.
Browse files Browse the repository at this point in the history
  • Loading branch information
huandu committed Feb 4, 2018
1 parent 0bc780f commit 4414789
Show file tree
Hide file tree
Showing 13 changed files with 341 additions and 40 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,20 @@ Following builders are implemented right now. API document and examples are prov
* [Build](https://godoc.org/github.com/huandu/go-sqlbuilder#Build): Advanced freestyle builder using special syntax defined in [Args#Compile](https://godoc.org/github.com/huandu/go-sqlbuilder#Args.Compile).
* [BuildNamed](https://godoc.org/github.com/huandu/go-sqlbuilder#BuildNamed): Advanced freestyle builder using `${key}` to refer the value of a map by key.

### Build SQL for MySQL or PostgreSQL ###

Parameter markers are different in MySQL and PostgreSQL. This package provides some methods to set the type of markers (we call it "flavor") in all builders.

By default, all builders uses `DefaultFlavor` to build SQL. The default value is `MySQL`.

There is a `BuildWithFlavor` method in `Builder` interface. We can use it to build a SQL with provided flavor.

We can wrap any `Builder` with a default flavor through `WithFlavor`.

To be more verbose, we can use `PostgreSQL.NewSelectBuilder()` to create a `SelectBuilder` with the `PostgreSQL` flavor. All builders can be created in this way.

Right now, there are only two flavors, `MySQL` and `PostgreSQL`. Open new issue to me to ask for a new flavor if you find it necessary.

### Using `Struct` as a light weight ORM ###

`Struct` stores type information and struct fields of a struct. It's a factory of builders. We can use `Struct` methods to create initialized SELECT/INSERT/UPDATE/DELETE builders to work with the struct. It can help us to save time and avoid human-error on writing column names.
Expand Down
52 changes: 37 additions & 15 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (

// Args stores arguments associated with a SQL.
type Args struct {
// The default flavor used by `Args#Compile`
Flavor Flavor

args []interface{}
namedArgs map[string]int
sqlNamedArgs map[string]int
Expand Down Expand Up @@ -60,7 +63,7 @@ func (args *Args) add(arg interface{}) int {
return idx
}

// Compile analyzes builder's format to standard sql and returns associated args.
// Compile compiles builder's format to standard sql and returns associated args.
//
// The format string uses a special syntax to represent arguments.
//
Expand All @@ -69,10 +72,21 @@ func (args *Args) add(arg interface{}) int {
// ${name} refers a named argument created by `Named` with `name`.
// $$ is a "$" string.
func (args *Args) Compile(format string) (query string, values []interface{}) {
return args.CompileWithFlavor(format, args.Flavor)
}

// CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args.
//
// See doc for `Compile` to learn details.
func (args *Args) CompileWithFlavor(format string, flavor Flavor) (query string, values []interface{}) {
buf := &bytes.Buffer{}
idx := strings.IndexRune(format, '$')
offset := 0

if flavor == invalidFlavor {
flavor = DefaultFlavor
}

for idx >= 0 && len(format) > 0 {
if idx > 0 {
buf.WriteString(format[:idx])
Expand All @@ -89,11 +103,11 @@ func (args *Args) Compile(format string) (query string, values []interface{}) {
buf.WriteRune('$')
format = format[1:]
} else if format[0] == '{' {
format, values = args.compileNamed(buf, format, values)
format, values = args.compileNamed(buf, flavor, format, values)
} else if !args.onlyNamed && '0' <= format[0] && format[0] <= '9' {
format, values, offset = args.compileDigits(buf, format, values, offset)
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
} else if !args.onlyNamed && format[0] == '?' {
format, values, offset = args.compileSuccessive(buf, format[1:], values, offset)
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
}

idx = strings.IndexRune(format, '$')
Expand Down Expand Up @@ -123,7 +137,7 @@ func (args *Args) Compile(format string) (query string, values []interface{}) {
return
}

func (args *Args) compileNamed(buf *bytes.Buffer, format string, values []interface{}) (string, []interface{}) {
func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
i := 1

for ; i < len(format) && format[i] != '}'; i++ {
Expand All @@ -139,13 +153,13 @@ func (args *Args) compileNamed(buf *bytes.Buffer, format string, values []interf
format = format[i+1:]

if p, ok := args.namedArgs[name]; ok {
format, values, _ = args.compileSuccessive(buf, format, values, p)
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
}

return format, values
}

func (args *Args) compileDigits(buf *bytes.Buffer, format string, values []interface{}, offset int) (string, []interface{}, int) {
func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
i := 1

for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
Expand All @@ -156,27 +170,27 @@ func (args *Args) compileDigits(buf *bytes.Buffer, format string, values []inter
format = format[i:]

if pointer, err := strconv.Atoi(digits); err == nil {
return args.compileSuccessive(buf, format, values, pointer)
return args.compileSuccessive(buf, flavor, format, values, pointer)
}

return format, values, offset
}

func (args *Args) compileSuccessive(buf *bytes.Buffer, format string, values []interface{}, offset int) (string, []interface{}, int) {
func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
if offset >= len(args.args) {
return format, values, offset
}

arg := args.args[offset]
values = args.compileArg(buf, values, arg)
values = args.compileArg(buf, flavor, values, arg)

return format, values, offset + 1
}

func (args *Args) compileArg(buf *bytes.Buffer, values []interface{}, arg interface{}) []interface{} {
func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
switch a := arg.(type) {
case Builder:
s, nestedArgs := a.Build()
s, nestedArgs := a.BuildWithFlavor(flavor)
buf.WriteString(s)
values = append(values, nestedArgs...)
case sql.NamedArg:
Expand All @@ -186,15 +200,23 @@ func (args *Args) compileArg(buf *bytes.Buffer, values []interface{}, arg interf
buf.WriteString(a.expr)
case listArgs:
if len(a.args) > 0 {
values = args.compileArg(buf, values, a.args[0])
values = args.compileArg(buf, flavor, values, a.args[0])
}

for i := 1; i < len(a.args); i++ {
buf.WriteString(", ")
values = args.compileArg(buf, values, a.args[i])
values = args.compileArg(buf, flavor, values, a.args[i])
}
default:
buf.WriteRune('?')
switch flavor {
case MySQL:
buf.WriteRune('?')
case PostgreSQL:
fmt.Fprintf(buf, "$%v", len(values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
}

values = append(values, arg)
}

Expand Down
38 changes: 38 additions & 0 deletions args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package sqlbuilder

import (
"bytes"
"fmt"
"strings"
"testing"
)

Expand Down Expand Up @@ -35,4 +37,40 @@ func TestArgs(t *testing.T) {
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
}
}

old := DefaultFlavor
DefaultFlavor = PostgreSQL
defer func() {
DefaultFlavor = old
}()

// PostgreSQL flavor compiled sql.
for expected, c := range cases {
args := new(Args)

for i := 1; i < len(c); i++ {
args.Add(c[i])
}

sql, values := args.Compile(c[0].(string))
actual := fmt.Sprintf("%v\n%v", sql, values)
expected = toPostgreSQL(expected)

if actual != expected {
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
}
}
}

func toPostgreSQL(sql string) string {
parts := strings.Split(sql, "?")
buf := &bytes.Buffer{}
buf.WriteString(parts[0])

for i, p := range parts[1:] {
fmt.Fprintf(buf, "$%v", i+1)
buf.WriteString(p)
}

return buf.String()
}
59 changes: 42 additions & 17 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,94 @@ import (
// `SELECT * FROM t1 WHERE id IN (SELECT id FROM t2)`.
type Builder interface {
Build() (sql string, args []interface{})
BuildWithFlavor(flavor Flavor) (sql string, args []interface{})
}

type compiledBuilder struct {
sql string
args []interface{}
args *Args
format string
}

func (cb *compiledBuilder) Build() (sql string, args []interface{}) {
return cb.sql, cb.args
return cb.args.Compile(cb.format)
}

func (cb *compiledBuilder) BuildWithFlavor(flavor Flavor) (sql string, args []interface{}) {
return cb.args.CompileWithFlavor(cb.format, flavor)
}

type flavoredBuilder struct {
builder Builder
flavor Flavor
}

func (fb *flavoredBuilder) Build() (sql string, args []interface{}) {
return fb.builder.BuildWithFlavor(fb.flavor)
}

func (fb *flavoredBuilder) BuildWithFlavor(flavor Flavor) (sql string, args []interface{}) {
return fb.builder.BuildWithFlavor(flavor)
}

// WithFlavor creates a new Builder based on builder with a default flavor.
func WithFlavor(builder Builder, flavor Flavor) Builder {
return &flavoredBuilder{
builder: builder,
flavor: flavor,
}
}

// Buildf creates a Builder from a format string using `fmt.Sprintf`-like syntax.
// As all arguments will be converted to a string internally, e.g. "$0",
// only `%v` and `%s` are valid.
func Buildf(format string, arg ...interface{}) Builder {
args := &Args{}
args := &Args{
Flavor: DefaultFlavor,
}
vars := make([]interface{}, 0, len(arg))

for _, a := range arg {
vars = append(vars, args.Add(a))
}

format = Escape(format)
str := fmt.Sprintf(format, vars...)
sql, values := args.Compile(str)

return &compiledBuilder{
sql: sql,
args: values,
args: args,
format: fmt.Sprintf(Escape(format), vars...),
}
}

// Build creates a Builder from a format string.
// The format string uses special syntax to represent arguments.
// See doc in `Args#Compile` for syntax details.
func Build(format string, arg ...interface{}) Builder {
args := &Args{}
args := &Args{
Flavor: DefaultFlavor,
}

for _, a := range arg {
args.Add(a)
}

sql, values := args.Compile(format)
return &compiledBuilder{
sql: sql,
args: values,
args: args,
format: format,
}
}

// BuildNamed creates a Builder from a format string.
// The format string uses `${key}` to refer the value of named by key.
func BuildNamed(format string, named map[string]interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
onlyNamed: true,
}

for n, v := range named {
args.Add(Named(n, v))
}

sql, values := args.Compile(format)
return &compiledBuilder{
sql: sql,
args: values,
args: args,
format: format,
}
}
11 changes: 11 additions & 0 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ func ExampleBuildNamed() {
// SELECT * FROM user WHERE status IN (?, ?, ?) AND name LIKE ? AND created_at > @start AND modified_at < @start + 86400
// [1 2 5 Huan% {{} start 1234567890}]
}

func ExampleWithFlavor() {
sql, args := WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).Build()

fmt.Println(sql)
fmt.Println(args)

// Output:
// SELECT * FROM foo WHERE id = $1
// [1234]
}
19 changes: 18 additions & 1 deletion delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (

// NewDeleteBuilder creates a new DELETE builder.
func NewDeleteBuilder() *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder()
}

func newDeleteBuilder() *DeleteBuilder {
args := &Args{}
return &DeleteBuilder{
Cond: Cond{
Expand Down Expand Up @@ -50,6 +54,12 @@ func (db *DeleteBuilder) String() string {
// Build returns compiled DELETE string and args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (db *DeleteBuilder) Build() (sql string, args []interface{}) {
return db.BuildWithFlavor(db.args.Flavor)
}

// BuildWithFlavor returns compiled DELETE string and args with flavor.
// They can be used in `DB#Query` of package `database/sql` directly.
func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor) (sql string, args []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ")
buf.WriteString(db.table)
Expand All @@ -59,5 +69,12 @@ func (db *DeleteBuilder) Build() (sql string, args []interface{}) {
buf.WriteString(strings.Join(db.whereExprs, " AND "))
}

return db.args.Compile(buf.String())
return db.args.CompileWithFlavor(buf.String(), flavor)
}

// SetFlavor sets the flavor of compiled sql.
func (db *DeleteBuilder) SetFlavor(flavor Flavor) (old Flavor) {
old = db.args.Flavor
db.args.Flavor = flavor
return
}
Loading

0 comments on commit 4414789

Please sign in to comment.