Skip to content

Commit

Permalink
escape mysql database, table, column name
Browse files Browse the repository at this point in the history
  • Loading branch information
martianzhang committed Dec 27, 2018
1 parent 431027e commit fd01635
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 95 deletions.
35 changes: 0 additions & 35 deletions ast/pretty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,38 +169,3 @@ func TestRemoveComments(t *testing.T) {
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}

func TestMysqlEscapeString(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
var strs = []map[string]string{
{
"input": "abc",
"output": "abc",
},
{
"input": "'abc",
"output": "\\'abc",
},
{
"input": `
abc`,
"output": `\
abc`,
},
{
"input": "\"abc",
"output": "\\\"abc",
},
}
for _, str := range strs {
output, err := MysqlEscapeString(str["input"])
if err != nil {
t.Error("TestMysqlEscapeString", err)
} else {
if output != str["output"] {
t.Error("TestMysqlEscapeString", output, str["output"])
}
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
46 changes: 0 additions & 46 deletions ast/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package ast

import (
"errors"
"fmt"
"regexp"
"strings"
Expand Down Expand Up @@ -614,51 +613,6 @@ func Tokenizer(sql string) []Token {
return tokens
}

// MysqlEscapeString mysql_real_escape_string
// https://github.com/liule/golang_escape
func MysqlEscapeString(source string) (string, error) {
var j = 0
if len(source) == 0 {
return "", errors.New("source is null")
}
tempStr := source[:]
desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ {
flag := false
var escape byte
switch tempStr[i] {
case '\r':
flag = true
escape = '\r'
case '\n':
flag = true
escape = '\n'
case '\\':
flag = true
escape = '\\'
case '\'':
flag = true
escape = '\''
case '"':
flag = true
escape = '"'
case '\032':
flag = true
escape = 'Z'
default:
}
if flag {
desc[j] = '\\'
desc[j+1] = escape
j = j + 2
} else {
desc[j] = tempStr[i]
j = j + 1
}
}
return string(desc[0:j]), nil
}

// IsMysqlKeyword 判断是否是关键字
func IsMysqlKeyword(name string) bool {
_, ok := mySQLKeywords[strings.ToLower(strings.TrimSpace(name))]
Expand Down
50 changes: 49 additions & 1 deletion database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
}

// 计算该列散粒度
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb))
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb)))
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
Expand Down Expand Up @@ -319,3 +319,51 @@ func NullString(buf []byte) string {
}
return string(buf)
}

// StringEscape like C API mysql_escape_string()
// https://github.com/liule/golang_escape
func StringEscape(source string) string {
var j int
if source == "" {
return source
}
tempStr := source[:]
desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ {
flag := false
var escape byte
switch tempStr[i] {
case '\000':
flag = true
escape = '\000'
case '\r':
flag = true
escape = '\r'
case '\n':
flag = true
escape = '\n'
case '\\':
flag = true
escape = '\\'
case '\'':
flag = true
escape = '\''
case '"':
flag = true
escape = '"'
case '\032':
flag = true
escape = 'Z'
default:
}
if flag {
desc[j] = '\\'
desc[j+1] = escape
j = j + 2
} else {
desc[j] = tempStr[i]
j = j + 1
}
}
return string(desc[0:j])
}
24 changes: 24 additions & 0 deletions database/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,27 @@ func TestNullString(t *testing.T) {
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}

func TestStringEscaple(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
cases := []string{
"",
"hello world",
"hello' world",
`hello" world`,
"hello\000world",
`hello\ world`,
"hello\032world",
"hello\rworld",
"hello\nworld",
}
err := common.GoldenDiff(func() {
for _, str := range cases {
fmt.Println(StringEscape(str))
}
}, t.Name(), update)
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
4 changes: 2 additions & 2 deletions database/sampling.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error

// startSampling sampling data from OnlineDSN to TestDSN
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", database, table, where)
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where))
common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery)
if err != nil {
Expand Down Expand Up @@ -167,7 +167,7 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", db.Database, table, colDef, values)
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values)
res, err := db.Query(query)
if res.Rows != nil {
res.Rows.Close()
Expand Down
22 changes: 11 additions & 11 deletions database/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
tbStatus := newTableStat(tableName)

// 执行 show table status
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name))
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name)))
if err != nil {
return tbStatus, err
}
Expand Down Expand Up @@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
}

// 执行 show create table
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName))
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName)

// 执行 show create table
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName))
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
// SHOW CREATE TABLE tbl_name
// SHOW CREATE TRIGGER trigger_name
// SHOW CREATE VIEW view_name
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, name))
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", StringEscape(createType), StringEscape(name)))
if err != nil {
return "", err
}
Expand Down Expand Up @@ -500,18 +500,18 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var columns []*common.Column
sql := fmt.Sprintf("SELECT "+
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name)
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name))

if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName))
}

if len(tables) > 0 {
var tmp []string
for _, table := range tables {
tmp = append(tmp, "'"+table+"'")
}
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
sql += fmt.Sprintf(" and c.table_name in (%s)", StringEscape(strings.Join(tmp, ",")))
}

common.Log.Debug("FindColumn, execute SQL: %s", sql)
Expand All @@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
// 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character

sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB)
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB))

common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult
Expand Down Expand Up @@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", tbName, dbName, column)
" COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column))

common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql)
Expand Down Expand Up @@ -604,10 +604,10 @@ type ReferenceValue struct {
func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) {
var referenceValues []ReferenceValue
sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, dbName)
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName))

if len(tbName) > 0 {
extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tbName, `","`))
extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, StringEscape(strings.Join(tbName, `","`)))
sql = sql + extra
}

Expand Down
Binary file added database/testdata/TestStringEscaple.golden
Binary file not shown.

0 comments on commit fd01635

Please sign in to comment.