Skip to content

Commit

Permalink
Initial work on password support
Browse files Browse the repository at this point in the history
  • Loading branch information
Kenneth Shaw committed Mar 27, 2017
1 parent 9b377a9 commit 57391aa
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 17 deletions.
32 changes: 30 additions & 2 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package drivers

import (
"runtime"
"strings"

// mssql driver
_ "github.com/denisenkom/go-mssqldb"

// mysql driver
_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"

// postgres driver
_ "github.com/lib/pq"
"github.com/lib/pq"

// sqlite3 driver
_ "github.com/mattn/go-sqlite3"
Expand Down Expand Up @@ -61,3 +62,30 @@ func init() {
}
}
}

var pwErr = map[string]func(error) bool{
"mssql": func(err error) bool {
return strings.Contains(err.Error(), "Login failed for")
},
"mysql": func(err error) bool {
if e, ok := err.(*mysql.MySQLError); ok {
return e.Number == 1045
}
return false
},
"postgres": func(err error) bool {
if e, ok := err.(*pq.Error); ok {
return e.Code.Name() == "invalid_password"
}
return false
},
}

// IsPasswordErr
func IsPasswordErr(name string, err error) bool {
if f, ok := pwErr[name]; ok {
return f(err)
}

return false
}
9 changes: 9 additions & 0 deletions drivers/mymysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@ package drivers

import (
// mymysql driver

_ "github.com/ziutek/mymysql/godrv"
"github.com/ziutek/mymysql/mysql"
)

func init() {
Drivers["mymysql"] = "mymysql"

pwErr["mymysql"] = func(err error) bool {
if e, ok := err.(*mysql.Error); ok {
return e.Code == 1045
}
return false
}
}
16 changes: 14 additions & 2 deletions drivers/odbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@
package drivers

import (
// odbc driver
_ "github.com/alexbrainman/odbc"
"strings"

"github.com/alexbrainman/odbc"
)

func init() {
Drivers["odbc"] = "odbc"

pwErr["odbc"] = func(err error) bool {
if e, ok := err.(*odbc.Error); ok {
msg := strings.ToLower(e.Error())
return strings.Contains(msg, "failed") &&
(strings.Contains(msg, "login") ||
strings.Contains(msg, "authentication") ||
strings.Contains(msg, "password"))
}
return false
}
}
9 changes: 8 additions & 1 deletion drivers/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
package drivers

import (
// oracle driver
_ "gopkg.in/rana/ora.v4"
)

func init() {
Drivers["ora"] = "oracle"
pwErr["ora"] = func(err error) bool {
if e, ok := err.(interface {
Code() int
}); ok {
return e.Code() == 1017
}
return false
}
}
8 changes: 8 additions & 0 deletions drivers/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package drivers

import (
// pgx driver

"github.com/jackc/pgx/stdlib"

"database/sql"
Expand All @@ -14,6 +15,13 @@ import (

func init() {
Drivers["pgx"] = "pgx"

pwErr["pgx"] = func(err error) bool {
if e, ok := err.(pgx.PgError); ok {
return e.Code == "28P01"
}
return false
}
}

const (
Expand Down
47 changes: 37 additions & 10 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"io"
"net/url"
"os"
"os/user"
"path/filepath"
Expand Down Expand Up @@ -130,7 +131,7 @@ func (h *Handler) Run() error {

// run
res, err = r.Run(h)
if err != nil {
if err != nil && err != rline.ErrInterrupt {
fmt.Fprintf(stderr, "error: %v", err)
fmt.Fprintln(stderr)
continue
Expand Down Expand Up @@ -293,24 +294,50 @@ func (h *Handler) Open(params ...string) error {
f = drivers.PgxOpen(h.u)
}

// connect
h.db, err = f(h.u.Driver, h.u.DSN)
if err != nil {
return err
}

// force statement parse settings
isPG := h.u.Driver == "postgres" || h.u.Driver == "pgx"
stmt.AllowDollar(isPG)(h.buf)
stmt.AllowMultilineComments(isPG)(h.buf)

// do ping to force an error (if any)
err = h.WrapError(h.db.Ping())
// connect
h.db, err = f(h.u.Driver, h.u.DSN)
if err != nil && !drivers.IsPasswordErr(h.u.Driver, err) {
return err
} else if err == nil {
// do ping to force an error (if any)
err = h.db.Ping()
if err == nil {
return nil
}
}

// close connection
if err != nil {
h.Close()
}

return err
// bail without getting password
if !drivers.IsPasswordErr(h.u.Driver, err) || len(params) > 1 || !h.l.Interactive() {
return h.WrapError(err)
}

// print the error
fmt.Fprintf(h.l.Stderr(), "error: %v", h.WrapError(err))
fmt.Fprintln(h.l.Stderr())

// otherwise, try to collect a password ...
user := h.user.Username
if h.u.User != nil {
user = h.u.User.Username()
}
pass, err := h.l.Password()
if err != nil {
return err
}

// reconnect using the user/pass ...
h.u.User = url.UserPassword(user, pass)
return h.Open(h.u.String())
}

// Close closes the database connection if it is open.
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func main() {

// run
err = run(args, cur)
if err != nil && err != io.EOF {
if err != nil && err != io.EOF && err != rline.ErrInterrupt {
fmt.Fprintf(os.Stderr, "error: %v\n", err)

// extra output for when the oracle driver is not available
Expand Down
2 changes: 1 addition & 1 deletion text/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var (

ConnInfo = `You are connected with driver %s (%s)`

EnterPassword = `Enter Password`
EnterPassword = `Password: `

HelpDesc string

Expand Down

0 comments on commit 57391aa

Please sign in to comment.