From 57391aa6484082d6f0656d4491fdc04801ebc0e7 Mon Sep 17 00:00:00 2001 From: Kenneth Shaw Date: Mon, 27 Mar 2017 21:20:44 +0700 Subject: [PATCH] Initial work on password support --- drivers/drivers.go | 32 +++++++++++++++++++++++++++++-- drivers/mymysql.go | 9 +++++++++ drivers/odbc.go | 16 ++++++++++++++-- drivers/oracle.go | 9 ++++++++- drivers/pgx.go | 8 ++++++++ handler/handler.go | 47 ++++++++++++++++++++++++++++++++++++---------- main.go | 2 +- text/text.go | 2 +- 8 files changed, 108 insertions(+), 17 deletions(-) diff --git a/drivers/drivers.go b/drivers/drivers.go index 350cba2ed2c..378e2005977 100644 --- a/drivers/drivers.go +++ b/drivers/drivers.go @@ -2,15 +2,16 @@ package drivers import ( "runtime" + "strings" // mssql driver _ "github.com/denisenkom/go-mssqldb" // mysql driver - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" // postgres driver - _ "github.com/lib/pq" + "github.com/lib/pq" // sqlite3 driver _ "github.com/mattn/go-sqlite3" @@ -61,3 +62,30 @@ func init() { } } } + +var pwErr = map[string]func(error) bool{ + "mssql": func(err error) bool { + return strings.Contains(err.Error(), "Login failed for") + }, + "mysql": func(err error) bool { + if e, ok := err.(*mysql.MySQLError); ok { + return e.Number == 1045 + } + return false + }, + "postgres": func(err error) bool { + if e, ok := err.(*pq.Error); ok { + return e.Code.Name() == "invalid_password" + } + return false + }, +} + +// IsPasswordErr +func IsPasswordErr(name string, err error) bool { + if f, ok := pwErr[name]; ok { + return f(err) + } + + return false +} diff --git a/drivers/mymysql.go b/drivers/mymysql.go index 5a8eefdf096..cfaa8c13439 100644 --- a/drivers/mymysql.go +++ b/drivers/mymysql.go @@ -4,9 +4,18 @@ package drivers import ( // mymysql driver + _ "github.com/ziutek/mymysql/godrv" + "github.com/ziutek/mymysql/mysql" ) func init() { Drivers["mymysql"] = "mymysql" + + pwErr["mymysql"] = func(err error) bool { + if e, ok := err.(*mysql.Error); ok { + return e.Code == 1045 + } + return false + } } diff --git a/drivers/odbc.go b/drivers/odbc.go index 317de846c77..ecc01269c56 100644 --- a/drivers/odbc.go +++ b/drivers/odbc.go @@ -3,10 +3,22 @@ package drivers import ( - // odbc driver - _ "github.com/alexbrainman/odbc" + "strings" + + "github.com/alexbrainman/odbc" ) func init() { Drivers["odbc"] = "odbc" + + pwErr["odbc"] = func(err error) bool { + if e, ok := err.(*odbc.Error); ok { + msg := strings.ToLower(e.Error()) + return strings.Contains(msg, "failed") && + (strings.Contains(msg, "login") || + strings.Contains(msg, "authentication") || + strings.Contains(msg, "password")) + } + return false + } } diff --git a/drivers/oracle.go b/drivers/oracle.go index 40f3eb6a23d..a79ab91e35c 100644 --- a/drivers/oracle.go +++ b/drivers/oracle.go @@ -3,10 +3,17 @@ package drivers import ( - // oracle driver _ "gopkg.in/rana/ora.v4" ) func init() { Drivers["ora"] = "oracle" + pwErr["ora"] = func(err error) bool { + if e, ok := err.(interface { + Code() int + }); ok { + return e.Code() == 1017 + } + return false + } } diff --git a/drivers/pgx.go b/drivers/pgx.go index d9d0141d7ed..7ba912a99cd 100644 --- a/drivers/pgx.go +++ b/drivers/pgx.go @@ -4,6 +4,7 @@ package drivers import ( // pgx driver + "github.com/jackc/pgx/stdlib" "database/sql" @@ -14,6 +15,13 @@ import ( func init() { Drivers["pgx"] = "pgx" + + pwErr["pgx"] = func(err error) bool { + if e, ok := err.(pgx.PgError); ok { + return e.Code == "28P01" + } + return false + } } const ( diff --git a/handler/handler.go b/handler/handler.go index 349a822ae16..d5ad86f6065 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "io" + "net/url" "os" "os/user" "path/filepath" @@ -130,7 +131,7 @@ func (h *Handler) Run() error { // run res, err = r.Run(h) - if err != nil { + if err != nil && err != rline.ErrInterrupt { fmt.Fprintf(stderr, "error: %v", err) fmt.Fprintln(stderr) continue @@ -293,24 +294,50 @@ func (h *Handler) Open(params ...string) error { f = drivers.PgxOpen(h.u) } - // connect - h.db, err = f(h.u.Driver, h.u.DSN) - if err != nil { - return err - } - // force statement parse settings isPG := h.u.Driver == "postgres" || h.u.Driver == "pgx" stmt.AllowDollar(isPG)(h.buf) stmt.AllowMultilineComments(isPG)(h.buf) - // do ping to force an error (if any) - err = h.WrapError(h.db.Ping()) + // connect + h.db, err = f(h.u.Driver, h.u.DSN) + if err != nil && !drivers.IsPasswordErr(h.u.Driver, err) { + return err + } else if err == nil { + // do ping to force an error (if any) + err = h.db.Ping() + if err == nil { + return nil + } + } + + // close connection if err != nil { h.Close() } - return err + // bail without getting password + if !drivers.IsPasswordErr(h.u.Driver, err) || len(params) > 1 || !h.l.Interactive() { + return h.WrapError(err) + } + + // print the error + fmt.Fprintf(h.l.Stderr(), "error: %v", h.WrapError(err)) + fmt.Fprintln(h.l.Stderr()) + + // otherwise, try to collect a password ... + user := h.user.Username + if h.u.User != nil { + user = h.u.User.Username() + } + pass, err := h.l.Password() + if err != nil { + return err + } + + // reconnect using the user/pass ... + h.u.User = url.UserPassword(user, pass) + return h.Open(h.u.String()) } // Close closes the database connection if it is open. diff --git a/main.go b/main.go index 44a8843d2c8..87a584061a8 100644 --- a/main.go +++ b/main.go @@ -54,7 +54,7 @@ func main() { // run err = run(args, cur) - if err != nil && err != io.EOF { + if err != nil && err != io.EOF && err != rline.ErrInterrupt { fmt.Fprintf(os.Stderr, "error: %v\n", err) // extra output for when the oracle driver is not available diff --git a/text/text.go b/text/text.go index 3c6a3515590..a556369dad0 100644 --- a/text/text.go +++ b/text/text.go @@ -33,7 +33,7 @@ var ( ConnInfo = `You are connected with driver %s (%s)` - EnterPassword = `Enter Password` + EnterPassword = `Password: ` HelpDesc string