Skip to content

Commit

Permalink
use go-sql-driver's DSN format
Browse files Browse the repository at this point in the history
  compatibly with old version(<0.10.0)
  • Loading branch information
martianzhang committed Dec 28, 2018
1 parent ba091b0 commit 84a6702
Show file tree
Hide file tree
Showing 7 changed files with 791 additions and 84 deletions.
165 changes: 121 additions & 44 deletions common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ import (
"regexp"
"runtime"
"strings"
"time"

"github.com/go-sql-driver/mysql"
"gopkg.in/yaml.v2"
)

Expand Down Expand Up @@ -138,20 +140,8 @@ type Configuration struct {

// Config 默认设置
var Config = &Configuration{
OnlineDSN: &Dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
Version: 99999,
},
TestDSN: &Dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
Version: 99999,
},
OnlineDSN: newDSN(nil),
TestDSN: newDSN(nil),
AllowOnlineAsTest: false,
DropTestTemporary: true,
CleanupTestDatabase: false,
Expand Down Expand Up @@ -236,25 +226,104 @@ var Config = &Configuration{

// Dsn Data source name
type Dsn struct {
Net string `yaml:"net"`
Addr string `yaml:"addr"`
Schema string `yaml:"schema"`

// 数据库用户名和密码可以通过系统环境变量的形式赋值
User string `yaml:"user"`
Password string `yaml:"password"`
Charset string `yaml:"charset"`
Disable bool `yaml:"disable"`
User string `yaml:"user"` // Usernames
Password string `yaml:"password"` // Password (requires User)
Net string `yaml:"net"` // Network type
Addr string `yaml:"addr"` // Network address (requires Net)
Schema string `yaml:"schema"` // Database name
Charset string `yaml:"charset"` // SET NAMES charset
Collation string `yaml:"collation"` // Connection collation
Loc string `yaml:"loc"` // Location for time.Time values
TLS string `yaml:"tls"` // TLS configuration name
ServerPubKey string `yaml:"server-public-key"` // Server public key name
MaxAllowedPacket int `ymal:"max-allowed-packet"` // Max packet size allowed
Params map[string]string `yaml:"params"` // Other Connection parameters, `SET param=val`, `SET NAMES charset`
Timeout int `yaml:"timeout"` // Dial timeout
ReadTimeout int `yaml:"read-timeout"` // I/O read timeout
WriteTimeout int `yaml:"write-timeout"` // I/O write timeout

AllowNativePasswords bool `yaml:"allow-native-passwords"` // Allows the native password authentication method
AllowOldPasswords bool `yaml:"allow-old-passwords"` // Allows the old insecure password method

Disable bool `yaml:"disable"`
Version int `yaml:"-"` // 版本自动检查,不可配置
}

Timeout int `yaml:"timeout"`
ReadTimeout int `yaml:"read-timeout"`
WriteTimeout int `yaml:"write-timeout"`
// newDSN create default Dsn struct
func newDSN(cfg *mysql.Config) *Dsn {
dsn := &Dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8",
AllowNativePasswords: true,
Params: make(map[string]string),
MaxAllowedPacket: 4 << 20, // 4 MiB

// Disable: true,
Version: 99999,
}
if cfg == nil {
return dsn
}
dsn.User = cfg.User
dsn.Password = cfg.Passwd
dsn.Net = cfg.Net
dsn.Addr = cfg.Addr
dsn.Schema = cfg.DBName
dsn.Params = make(map[string]string)
for k, v := range cfg.Params {
dsn.Params[k] = v
}
if _, ok := cfg.Params["charset"]; ok {
dsn.Charset = cfg.Params["charset"]
}
dsn.Collation = cfg.Collation
dsn.Loc = cfg.Loc.String()
dsn.MaxAllowedPacket = cfg.MaxAllowedPacket
dsn.ServerPubKey = cfg.ServerPubKey
dsn.TLS = cfg.TLSConfig
dsn.Timeout = int(cfg.Timeout / time.Second)
dsn.ReadTimeout = int(cfg.ReadTimeout / time.Second)
dsn.WriteTimeout = int(cfg.WriteTimeout / time.Second)
dsn.AllowNativePasswords = cfg.AllowNativePasswords
dsn.AllowOldPasswords = cfg.AllowOldPasswords
return dsn
}

Version int `yaml:"-"` // 版本自动检查,不可配置
// newMySQLConfig convert Dsn to go-sql-drive Config
func (env *Dsn) newMySQLConifg() (*mysql.Config, error) {
var err error
dsn := mysql.NewConfig()

dsn.User = env.User
dsn.Passwd = env.Password
dsn.Net = env.Net
dsn.Addr = env.Addr
dsn.DBName = env.Schema
dsn.Params = make(map[string]string)
for k, v := range env.Params {
dsn.Params[k] = v
}
dsn.Params["charset"] = env.Charset
dsn.Collation = env.Collation
dsn.Loc, err = time.LoadLocation(env.Loc)
if err != nil {
return nil, err
}
dsn.MaxAllowedPacket = env.MaxAllowedPacket
dsn.ServerPubKey = env.ServerPubKey
dsn.TLSConfig = env.TLS
dsn.Timeout = time.Duration(env.Timeout) * time.Second
dsn.ReadTimeout = time.Duration(env.ReadTimeout) * time.Second
dsn.WriteTimeout = time.Duration(env.WriteTimeout) * time.Second
dsn.AllowNativePasswords = env.AllowNativePasswords
dsn.AllowOldPasswords = env.AllowOldPasswords
return dsn, err
}

// 解析命令行DSN输入
func parseDSN(odbc string, d *Dsn) *Dsn {
dsn := newDSN(nil)
var addr, user, password, schema, charset string
if odbc == FormatDSN(d) {
return d
Expand Down Expand Up @@ -340,30 +409,38 @@ func parseDSN(odbc string, d *Dsn) *Dsn {
schema = "information_schema"
}

// 默认utf8mb4使用字符集
// 默认 utf8 使用字符集
if charset == "" {
charset = "utf8mb4"
charset = "utf8"
}

dsn := &Dsn{
Addr: addr,
User: user,
Password: password,
Schema: schema,
Charset: charset,
Disable: false,
Version: 999,
}
dsn.Addr = addr
dsn.User = user
dsn.Password = password
dsn.Schema = schema
dsn.Charset = charset
return dsn
}

// ParseDSN compatible with old version soar < 0.11.0
func ParseDSN(odbc string, d *Dsn) *Dsn {
cfg, err := mysql.ParseDSN(odbc)
if err != nil {
return parseDSN(odbc, d)
}
return newDSN(cfg)
}

// FormatDSN 格式化打印DSN
func FormatDSN(env *Dsn) string {
if env == nil || env.Disable {
return ""
}
// username:password@ip:port/schema?charset=xxx
return fmt.Sprintf("%s:%s@%s/%s?charset=%s", env.User, env.Password, env.Addr, env.Schema, env.Charset)
dsn, err := env.newMySQLConifg()
if err != nil {
return ""
}
return dsn.FormatDSN()
}

// SoarVersion soar version information
Expand Down Expand Up @@ -485,8 +562,8 @@ func readCmdFlags() error {

_ = flag.String("config", "", "Config file path")
// +++++++++++++++测试环境+++++++++++++++++
onlineDSN := flag.String("online-dsn", FormatDSN(Config.OnlineDSN), "OnlineDSN, 线上环境数据库配置, username:password@ip:port/schema")
testDSN := flag.String("test-dsn", FormatDSN(Config.TestDSN), "TestDSN, 测试环境数据库配置, username:password@ip:port/schema")
onlineDSN := flag.String("online-dsn", FormatDSN(Config.OnlineDSN), "OnlineDSN, 线上环境数据库配置, username:password@tcp(ip:port)/schema")
testDSN := flag.String("test-dsn", FormatDSN(Config.TestDSN), "TestDSN, 测试环境数据库配置, username:password@tcp(ip:port)/schema")
allowOnlineAsTest := flag.Bool("allow-online-as-test", Config.AllowOnlineAsTest, "AllowOnlineAsTest, 允许线上环境也可以当作测试环境")
dropTestTemporary := flag.Bool("drop-test-temporary", Config.DropTestTemporary, "DropTestTemporary, 是否清理测试环境产生的临时库表")
cleanupTestDatabase := flag.Bool("cleanup-test-database", Config.CleanupTestDatabase, "单次运行清理历史1小时前残余的测试库。")
Expand Down Expand Up @@ -569,8 +646,8 @@ func readCmdFlags() error {
}
flag.Parse()

Config.OnlineDSN = parseDSN(*onlineDSN, Config.OnlineDSN)
Config.TestDSN = parseDSN(*testDSN, Config.TestDSN)
Config.OnlineDSN = ParseDSN(*onlineDSN, Config.OnlineDSN)
Config.TestDSN = ParseDSN(*testDSN, Config.TestDSN)
Config.AllowOnlineAsTest = *allowOnlineAsTest
Config.DropTestTemporary = *dropTestTemporary
Config.CleanupTestDatabase = *cleanupTestDatabase
Expand Down
16 changes: 15 additions & 1 deletion common/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func TestReadConfigFile(t *testing.T) {
func TestParseDSN(t *testing.T) {
Log.Debug("Entering function: %s", GetFunctionName())
var dsns = []string{
// version < 0.11.0
"",
"user:password@hostname:3307/database",
"user:password@hostname:3307/database?charset=utf8",
Expand All @@ -81,11 +82,24 @@ func TestParseDSN(t *testing.T) {
"@:3307/database",
":3307/database",
"/database",
// go-sql-driver dsn
"user@unix(/path/to/socket)/dbname",
"root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local",
"user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true",
"user:password@/dbname?sql_mode=TRADITIONAL",
"user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci",
"id:password@tcp(your-amazonaws-uri.com:3306)/dbname",
"user@cloudsql(project-id:instance-name)/dbname",
"user@cloudsql(project-id:regionname:instance-name)/dbname",
"user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped",
"user:password@/dbname",
"user:password@/",
}

err := GoldenDiff(func() {
for _, dsn := range dsns {
pretty.Println(parseDSN(dsn, nil))
pretty.Println(dsn)
pretty.Println(ParseDSN(dsn, nil))
}
}, t.Name(), update)
if nil != err {
Expand Down
17 changes: 8 additions & 9 deletions common/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,18 @@ import "fmt"

func ExampleFormatDSN() {
Log.Debug("Entering function: %s", GetFunctionName())
dsxExp := &Dsn{
Addr: "127.0.0.1:3306",
Schema: "mysql",
User: "root",
Password: "1t'sB1g3rt",
Charset: "utf8mb4",
Disable: false,
}
dsxExp := newDSN(nil)
dsxExp.Addr = "127.0.0.1:3306"
dsxExp.Schema = "mysql"
dsxExp.User = "root"
dsxExp.Password = "1t'sB1g3rt"
dsxExp.Charset = "utf8mb4"
dsxExp.Disable = false

// 根据 &dsn 生成 dsnStr
fmt.Println(FormatDSN(dsxExp))

// Output: root:1t'sB1g3rt@127.0.0.1:3306/mysql?charset=utf8mb4
// Output: root:1t'sB1g3rt@tcp(127.0.0.1:3306)/mysql?charset=utf8mb4
Log.Debug("Exiting function: %s", GetFunctionName())
}

Expand Down
Loading

0 comments on commit 84a6702

Please sign in to comment.