Skip to content

Commit

Permalink
string escape no_backslash_escapes
Browse files Browse the repository at this point in the history
  • Loading branch information
martianzhang committed Dec 28, 2018
1 parent 84a6702 commit 59094bf
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 23 deletions.
45 changes: 39 additions & 6 deletions database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package database

import (
"bytes"
"database/sql"
"errors"
"fmt"
"io"
"regexp"
"strconv"
"strings"
Expand All @@ -39,7 +41,6 @@ type Connector struct {
Pass string
Database string
Charset string
Net string
Conn *sql.DB
}

Expand Down Expand Up @@ -202,7 +203,11 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
}

// 计算该列散粒度
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb)))
db.Conn.Stats()
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`",
Escape(col, false),
Escape(db.Database, false),
Escape(tb, false)))
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
Expand Down Expand Up @@ -313,19 +318,37 @@ func NullString(buf []byte) string {
return string(buf)
}

// StringEscape like C API mysql_escape_string()
// quoteEscape sql_mode=no_backslash_escapes
func quoteEscape(source string) string {
var buf bytes.Buffer
last := 0
for ii, bb := range source {
if bb == '\'' {
_, err := io.WriteString(&buf, source[last:ii])
common.LogIfWarn(err, "")
_, err = io.WriteString(&buf, `''`)
common.LogIfWarn(err, "")
last = ii + 1
}
}
_, err := io.WriteString(&buf, source[last:])
common.LogIfWarn(err, "")
return buf.String()
}

// stringEscape mysql_escape_string
// https://github.com/liule/golang_escape
func StringEscape(source string) string {
func stringEscape(source string) string {
var j int
if source == "" {
return source
}
tempStr := source[:]
desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ {
for i, b := range tempStr {
flag := false
var escape byte
switch tempStr[i] {
switch b {
case '\000':
flag = true
escape = '\000'
Expand Down Expand Up @@ -360,3 +383,13 @@ func StringEscape(source string) string {
}
return string(desc[0:j])
}

// Escape like C API mysql_escape_string()
func Escape(source string, NoBackslashEscapes bool) string {
// NoBackslashEscapes https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_backslash_escapes
// TODO: NoBackslashEscapes always false
if NoBackslashEscapes {
return quoteEscape(source)
}
return stringEscape(source)
}
5 changes: 3 additions & 2 deletions database/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func TestNullString(t *testing.T) {
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}

func TestStringEscaple(t *testing.T) {
func TestEscape(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
cases := []string{
"",
Expand All @@ -214,7 +214,8 @@ func TestStringEscaple(t *testing.T) {
}
err := common.GoldenDiff(func() {
for _, str := range cases {
fmt.Println(StringEscape(str))
fmt.Println(Escape(str, false))
fmt.Println(Escape(str, true))
}
}, t.Name(), update)
if err != nil {
Expand Down
17 changes: 13 additions & 4 deletions database/sampling.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error

// startSampling sampling data from OnlineDSN to TestDSN
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where))
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s",
Escape(database, false),
Escape(table, false),
Escape(where, false))
common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery)
if err != nil {
Expand Down Expand Up @@ -136,8 +139,11 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
switch columnTypes[i].DatabaseTypeName() {
case "TIMESTAMP", "DATETIME":
t, err := time.Parse(time.RFC3339, string(val))
common.LogIfWarn(err, "")
values = append(values, fmt.Sprintf(`"%s"`, TimeString(t)))
if err != nil {
values = append(values, fmt.Sprintf(`"%s"`, string(val)))
} else {
values = append(values, fmt.Sprintf(`"%s"`, TimeString(t)))
}
default:
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
}
Expand Down Expand Up @@ -167,7 +173,10 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values)
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;",
Escape(db.Database, false),
Escape(table, false),
Escape(colDef, false), values)
res, err := db.Query(query)
if res.Rows != nil {
res.Rows.Close()
Expand Down
22 changes: 11 additions & 11 deletions database/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
tbStatus := newTableStat(tableName)

// 执行 show table status
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name)))
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", Escape(tbStatus.Name, false)))
if err != nil {
return tbStatus, err
}
Expand Down Expand Up @@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
}

// 执行 show create table
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false)))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName)

// 执行 show create table
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false)))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
// SHOW CREATE TABLE tbl_name
// SHOW CREATE TRIGGER trigger_name
// SHOW CREATE VIEW view_name
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, StringEscape(name)))
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, Escape(name, false)))
if err != nil {
return "", err
}
Expand Down Expand Up @@ -500,16 +500,16 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var columns []*common.Column
sql := fmt.Sprintf("SELECT "+
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name))
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", Escape(name, false))

if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName))
sql += fmt.Sprintf(" and c.table_schema = '%s'", Escape(dbName, false))
}

if len(tables) > 0 {
var tmp []string
for _, table := range tables {
tmp = append(tmp, "'"+StringEscape(table)+"'")
tmp = append(tmp, "'"+Escape(table, false)+"'")
}
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
}
Expand Down Expand Up @@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
// 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character

sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB))
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", Escape(col.Table, false), Escape(col.DB, false))

common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult
Expand Down Expand Up @@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column))
" COLUMN_NAME='%s'", Escape(tbName, false), Escape(dbName, false), Escape(column, false))

common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql)
Expand Down Expand Up @@ -604,11 +604,11 @@ type ReferenceValue struct {
func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) {
var referenceValues []ReferenceValue
sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName))
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, Escape(dbName, false))

var tables []string
for _, tb := range tbName {
tables = append(tables, "'"+StringEscape(tb)+"'")
tables = append(tables, "'"+Escape(tb, false)+"'")
}
if len(tbName) > 0 {
extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tables, ","))
Expand Down
Binary file not shown.

0 comments on commit 59094bf

Please sign in to comment.