Skip to content

Commit

Permalink
sampling data type trading
Browse files Browse the repository at this point in the history
  • Loading branch information
martianzhang committed Dec 24, 2018
1 parent 86da258 commit 62f0b8e
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 125 deletions.
3 changes: 3 additions & 0 deletions common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type Configuration struct {
OnlySyntaxCheck bool `yaml:"only-syntax-check"` // 只做语法检查不输出优化建议
SamplingStatisticTarget int `yaml:"sampling-statistic-target"` // 数据采样因子,对应 PostgreSQL 的 default_statistics_target
Sampling bool `yaml:"sampling"` // 数据采样开关
SamplingCondition string `yaml:"sampling-condition"` // 指定采样条件,如:WHERE xxx LIMIT xxx;
Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile
Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace
Explain bool `yaml:"explain"` // Explain开关
Expand Down Expand Up @@ -506,6 +507,7 @@ func readCmdFlags() error {
explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析")
sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关")
samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target")
samplingCondition := flag.String("sampling-condition", Config.SamplingCondition, "SamplingCondition, 数据采样条件,如: WHERE xxx LIMIT xxx")
delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符")
// +++++++++++++++日志相关+++++++++++++++++
logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]")
Expand Down Expand Up @@ -585,6 +587,7 @@ func readCmdFlags() error {
Config.Explain = *explain
Config.Sampling = *sampling
Config.SamplingStatisticTarget = *samplingStatisticTarget
Config.SamplingCondition = *samplingCondition

Config.LogLevel = *logLevel
if strings.HasPrefix(*logOutput, "/") {
Expand Down
137 changes: 75 additions & 62 deletions database/sampling.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
package database

import (
"database/sql"
"fmt"
"time"

"github.com/XiaoMi/soar/common"
"strings"

"database/sql"
"github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql"
)

/*--------------------
Expand All @@ -44,99 +47,109 @@ import (
*--------------------
*/

// SamplingData 将数据从Remote拉取到 db 中
func (db *Connector) SamplingData(remote *Connector, tables ...string) error {
// SamplingData 将数据从 onlineConn 拉取到 db 中
func (db *Connector) SamplingData(onlineConn *Connector, database string, tables ...string) error {
var err error
if database == db.Database {
return fmt.Errorf("SamplingData the same database, From: %s/%s, To: %s/%s", onlineConn.Addr, database, db.Addr, db.Database)
}

// 计算需要泵取的数据量
wantRowsCount := 300 * common.Config.SamplingStatisticTarget

// 设置数据采样单条 SQL 中 value 的数量
// 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快
maxValCount := 200

for _, table := range tables {
// 表类型检查
if remote.IsView(table) {
if onlineConn.IsView(table) {
return nil
}

tableStatus, err := remote.ShowTableStatus(table)
if err != nil {
return err
}

if len(tableStatus.Rows) == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
// generate where condition
var where string
if common.Config.SamplingCondition == "" {
tableStatus, err := onlineConn.ShowTableStatus(table)
if err != nil {
return err
}

if len(tableStatus.Rows) == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}

tableRows := tableStatus.Rows[0].Rows
if tableRows == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}

factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
where = fmt.Sprintf("WHERE RAND() <= %f LIMIT %d", factor, wantRowsCount)
if factor >= 1 {
where = ""
}
} else {
where = common.Config.SamplingCondition
}

tableRows := tableStatus.Rows[0].Rows
if tableRows == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}

factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)

err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount)
if err != nil {
common.Log.Error("(db *Connector) SamplingData Error : %v", err)
}
err = db.startSampling(onlineConn.Conn, database, table, where)
}
return nil
return err
}

// startSampling sampling data from OnlineDSN to TestDSN
// 因为涉及到的数据量问题,所以泵取与插入时同时进行的
// TODO: 加 ref link
func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error {
// generate where condition
where := fmt.Sprintf("WHERE RAND() <= %f", factor)
if factor >= 1 {
where = ""
}

res, err := conn.Query(fmt.Sprintf("SELECT * FROM `%s`.`%s` %s LIMIT %d;", database, table, where, wants))
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("SELECT * FROM `%s`.`%s` %s", database, table, where)
common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery)
if err != nil {
return err
}

// column info
// columns list
columns, err := res.Columns()
if err != nil {
return err
}
row := make(map[string][]byte, len(columns))
row := make([][]byte, len(columns))
tableFields := make([]interface{}, 0)
for _, col := range columns {
if _, ok := row[col]; ok {
tableFields = append(tableFields, row[col])
}
for i := range columns {
tableFields = append(tableFields, &row[i])
}
columnTypes, err := res.ColumnTypes()
if err != nil {
return err
}

// sampling data
var valuesStr string
var values []string
columnsStr := "`" + strings.Join(columns, "`,`") + "`"
for res.Next() {
res.Scan(tableFields...)
for _, val := range row {
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
for i, val := range row {
if val == nil {
values = append(values, "NULL")
} else {
switch columnTypes[i].DatabaseTypeName() {
case "TIMESTAMP", "DATETIME":
t, err := time.Parse(time.RFC3339, string(val))
common.LogIfWarn(err, "")
values = append(values, fmt.Sprintf(`"%s"`, mysql.TimeString(t)))
default:
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
}
}
}
valuesStr = fmt.Sprintf(`(%s)`, strings.Join(values, `,`))
doSampling(localConn, database, table, columnsStr, valuesStr)
err = db.doSampling(table, columnsStr, strings.Join(values, `,`))
}
res.Close()
return nil
return err
}

// 将泵取的数据转换成Insert语句并在数据库中执行
func doSampling(conn *sql.DB, dbName, table, colDef, values string) {
query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", dbName, table,
colDef, values)

_, err := conn.Exec(query)
if err != nil {
common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err)
}
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name
query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES (%s);", db.Database, table, colDef, values)
_, err := db.Query(query)
return err
}
39 changes: 0 additions & 39 deletions database/sampling_test.go

This file was deleted.

34 changes: 17 additions & 17 deletions database/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,27 +459,27 @@ func (db *Connector) ShowCreateTable(tableName string) (string, error) {
ddl, err := db.showCreate("table", tableName)

// 去除外键关联条件
var noConstraint []string
relationReg, _ := regexp.Compile("CONSTRAINT")
for _, line := range strings.Split(ddl, "\n") {

if relationReg.Match([]byte(line)) {
continue
}

// 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除
if strings.Index(line, ")") == 0 {
lineWrongSyntax := noConstraint[len(noConstraint)-1]
// 如果')'前一句的末尾是',' 删除 ',' 保证语法正确性
if strings.Index(lineWrongSyntax, ",") == len(lineWrongSyntax)-1 {
noConstraint[len(noConstraint)-1] = lineWrongSyntax[:len(lineWrongSyntax)-1]
lines := strings.Split(ddl, "\n")
// CREATE VIEW ONLY 1 LINE
if len(lines) > 2 {
var noConstraint []string
relationReg, _ := regexp.Compile("CONSTRAINT")
for _, line := range lines[1 : len(lines)-1] {
if relationReg.Match([]byte(line)) {
continue
}
line = strings.TrimSuffix(line, ",")
noConstraint = append(noConstraint, line)
}

noConstraint = append(noConstraint, line)
// 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除
ddl = fmt.Sprint(
lines[0], "\n",
strings.Join(noConstraint, ",\n"), "\n",
lines[len(lines)-1],
)
}

return strings.Join(noConstraint, "\n"), err
return ddl, err
}

// FindColumn find column
Expand Down
2 changes: 2 additions & 0 deletions database/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ func TestShowCreateTable(t *testing.T) {
connTest.Database = "sakila"
tables := []string{
"film",
"category",
"customer_list",
"inventory",
}
err := common.GoldenDiff(func() {
for _, table := range tables {
Expand Down
15 changes: 15 additions & 0 deletions database/testdata/TestShowCreateTable.golden
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,19 @@ CREATE TABLE `film` (
KEY `idx_fk_language_id` (`language_id`),
KEY `idx_fk_original_language_id` (`original_language_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8
CREATE TABLE `category` (
`category_id` tinyint(3) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(25) NOT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`category_id`)
) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8
CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`)))
CREATE TABLE `inventory` (
`inventory_id` mediumint(8) unsigned NOT NULL AUTO_INCREMENT,
`film_id` smallint(5) unsigned NOT NULL,
`store_id` tinyint(3) unsigned NOT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`inventory_id`),
KEY `idx_fk_film_id` (`film_id`),
KEY `idx_store_id_film_id` (`store_id`,`film_id`)
) ENGINE=InnoDB AUTO_INCREMENT=4582 DEFAULT CHARSET=utf8
10 changes: 3 additions & 7 deletions env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,21 +453,17 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string
res, err := ve.Query(ddl)
if err != nil {
// 有可能是用户新建表,因此线上环境查不到
common.Log.Error("createTable, %s Error : %v", tbName, err)
common.Log.Error("createTable: %s Error : %v", tbName, err)
return err
}
res.Rows.Close()

// 泵取数据
if common.Config.Sampling {
common.Log.Debug("createTable, Start Sampling data from %s.%s to %s.%s ...", dbName, tbName, ve.DBRef[dbName], tbName)
err := ve.SamplingData(rEnv, tbName)
if err != nil {
common.Log.Error(" (ve VirtualEnv) createTable SamplingData Error: %v", err)
return err
}
err = ve.SamplingData(rEnv, dbName, tbName)
}
return nil
return err
}

// GenTableColumns 为 Rewrite 提供的结构体初始化
Expand Down
Loading

0 comments on commit 62f0b8e

Please sign in to comment.