Skip to content

Commit

Permalink
Revert "Further code cleanup"
Browse files Browse the repository at this point in the history
This reverts commit 9896187.
  • Loading branch information
kenshaw committed Jan 17, 2021
1 parent 0400204 commit 1061148
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 32 deletions.
22 changes: 19 additions & 3 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"

"github.com/alecthomas/kingpin"

"github.com/xo/usql/text"
)

Expand All @@ -18,21 +19,24 @@ type CommandOrFile struct {

// Args are the command line arguments.
type Args struct {
DSN string
DSN string

CommandOrFiles []CommandOrFile
Out string
ForcePassword bool
NoPassword bool
NoRC bool
SingleTransaction bool
Variables []string
PVariables []string

Variables []string
PVariables []string
}

func (args *Args) Next() (string, bool, error) {
if len(args.CommandOrFiles) == 0 {
return "", false, io.EOF
}

cmd := args.CommandOrFiles[0]
args.CommandOrFiles = args.CommandOrFiles[1:]
return cmd.Value, cmd.Command, nil
Expand Down Expand Up @@ -81,25 +85,32 @@ func (p pset) IsCumulative() bool {

func NewArgs() *Args {
args := &Args{}

// set usage template
kingpin.UsageTemplate(text.UsageTemplate())

kingpin.Arg("dsn", "database url").StringVar(&args.DSN)

// command / file flags
kingpin.Flag("command", "run only single command (SQL or internal) and exit").Short('c').SetValue(commandOrFile{args, true})
kingpin.Flag("file", "execute commands from file and exit").Short('f').SetValue(commandOrFile{args, false})

// general flags
kingpin.Flag("no-password", "never prompt for password").Short('w').BoolVar(&args.NoPassword)
kingpin.Flag("no-rc", "do not read start up file").Short('X').BoolVar(&args.NoRC)
kingpin.Flag("out", "output file").Short('o').StringVar(&args.Out)
kingpin.Flag("password", "force password prompt (should happen automatically)").Short('W').BoolVar(&args.ForcePassword)
kingpin.Flag("single-transaction", "execute as a single transaction (if non-interactive)").Short('1').BoolVar(&args.SingleTransaction)
kingpin.Flag("set", "set variable NAME to VALUE").Short('v').PlaceHolder(", --variable=NAME=VALUE").StringsVar(&args.Variables)

// pset
kingpin.Flag("pset", `set printing option VAR to ARG (see \pset command)`).Short('P').PlaceHolder("VAR[=ARG]").StringsVar(&args.PVariables)

// pset flags
kingpin.Flag("field-separator", `field separator for unaligned output (default, "|")`).Short('F').SetValue(pset{args, []string{"fieldsep=%q", "fieldsep_zero=off"}})
kingpin.Flag("record-separator", `record separator for unaligned output (default, \n)`).Short('R').SetValue(pset{args, []string{"recordsep=%q", "recordsep_zero=off"}})
kingpin.Flag("table-attr", "set HTML table tag attributes (e.g., width, border)").Short('T').SetValue(pset{args, []string{"tableattr=%q"}})

type psetconfig struct {
long string
short rune
Expand Down Expand Up @@ -127,17 +138,22 @@ func NewArgs() *Args {
return nil
}).Bool()
}

// add --set as a hidden alias for --variable
kingpin.Flag("variable", "set variable NAME to VALUE").Hidden().StringsVar(&args.Variables)

// add --version flag
kingpin.Flag("version", "display version and exit").PreAction(func(*kingpin.ParseContext) error {
fmt.Fprintln(os.Stdout, text.CommandName, text.CommandVersion)
os.Exit(0)
return nil
}).Short('V').Bool()

// hide help flag
kingpin.HelpFlag.Short('h').Hidden()

// parse
kingpin.Parse()

return args
}
4 changes: 2 additions & 2 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ func ForceParams(u *dburl.URL) {

// Open opens a sql.DB connection for the registered driver.
func Open(u *dburl.URL) (*sql.DB, error) {
var err error
d, ok := drivers[u.Driver]
if !ok {
return nil, WrapErr(u.Driver, text.ErrDriverNotAvailable)
}
f := sql.Open
if d.Open != nil {
var err error
f, err = d.Open(u)
if err != nil {
return nil, WrapErr(u.Driver, err)
Expand Down Expand Up @@ -250,8 +250,8 @@ func CanChangePassword(u *dburl.URL) error {
// from User.
func ChangePassword(u *dburl.URL, db DB, user, new, old string) (string, error) {
if d, ok := drivers[u.Driver]; ok && d.ChangePassword != nil {
var err error
if user == "" {
var err error
user, err = User(u, db)
if err != nil {
return "", err
Expand Down
16 changes: 11 additions & 5 deletions env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ func Getvar(s string) (bool, string, error) {
// OpenFile opens a file for reading, returning the full, expanded path of the
// file. All callers are responsible for closing the returned file.
func OpenFile(u *user.User, path string, relative bool) (string, *os.File, error) {
path, err := filepath.EvalSymlinks(expand(u, path))
var err error
path, err = filepath.EvalSymlinks(expand(u, path))
switch {
case err != nil && os.IsNotExist(err):
return "", nil, text.ErrNoSuchFileOrDirectory
Expand All @@ -93,22 +94,26 @@ func OpenFile(u *user.User, path string, relative bool) (string, *os.File, error

// EditFile edits a file. If path is empty, then a temporary file will be created.
func EditFile(u *user.User, path, line, s string) ([]rune, error) {
var err error
ed := Getenv(text.CommandUpper()+"_EDITOR", "EDITOR", "VISUAL")
if ed == "" {
return nil, text.ErrNoEditorDefined
}
if path != "" {
path = expand(u, path)
} else {
f, err := temp.File("", text.CommandLower(), "sql")
var f *os.File
f, err = temp.File("", text.CommandLower(), "sql")
if err != nil {
return nil, err
}
if err = f.Close(); err != nil {
err = f.Close()
if err != nil {
return nil, err
}
path = f.Name()
if err := ioutil.WriteFile(path, []byte(strings.TrimSuffix(s, "\n")+"\n"), 0o644); err != nil {
err = ioutil.WriteFile(path, []byte(strings.TrimSuffix(s, "\n")+"\n"), 0o644)
if err != nil {
return nil, err
}
}
Expand All @@ -127,7 +132,8 @@ func EditFile(u *user.User, path, line, s string) ([]rune, error) {
c.Stdout = os.Stdout
c.Stderr = os.Stderr
// run
if err := c.Run(); err != nil {
err = c.Run()
if err != nil {
return nil, err
}
// read
Expand Down
46 changes: 27 additions & 19 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,9 @@ func (h *Handler) outputHighlighter(s string) string {
var cmd, final string
for {
cmd, _, err = st.Next()
switch {
case err != nil && err != io.EOF:
if err != nil && err != io.EOF {
return s + endl
case err == io.EOF:
} else if err == io.EOF {
break
}
if st.Ready() || cmd != "" {
Expand Down Expand Up @@ -198,6 +197,7 @@ func (h *Handler) Run() error {
fmt.Fprintln(h.l.Stdout())
}
for {
var err error
var execute bool
// set prompt
if iactive {
Expand Down Expand Up @@ -527,6 +527,7 @@ func (h *Handler) forceParams(u *dburl.URL) {
// Password collects a password from input, and returns a modified DSN
// including the collected password.
func (h *Handler) Password(dsn string) (string, error) {
var err error
if dsn == "" {
return "", text.ErrMissingDSN
}
Expand Down Expand Up @@ -613,20 +614,20 @@ func (h *Handler) ChangePassword(user string) (string, error) {
if !h.l.Interactive() {
return "", text.ErrNotInteractive
}
if err := drivers.CanChangePassword(h.u); err != nil {
var err error
if err = drivers.CanChangePassword(h.u); err != nil {
return "", err
}
var newpw, newpw2, oldpw string
// ask for previous password
if user == "" && drivers.RequirePreviousPassword(h.u) {
var err error
if oldpw, err = h.l.Password(text.EnterPreviousPassword); err != nil {
oldpw, err = h.l.Password(text.EnterPreviousPassword)
if err != nil {
return "", err
}
}
// attempt to get passwords
for i := 0; i < 3; i++ {
var err error
if newpw, err = h.l.Password(text.NewPassword); err != nil {
return "", err
}
Expand Down Expand Up @@ -749,6 +750,7 @@ func (h *Handler) execExec(w io.Writer, prefix, qstr string, qtyp bool, _ string

// query executes a query against the database.
func (h *Handler) query(w io.Writer, _, qstr string) error {
var err error
// run query
q, err := h.DB().Query(qstr)
if err != nil {
Expand All @@ -761,6 +763,7 @@ func (h *Handler) query(w io.Writer, _, qstr string) error {

// execRows executes all the columns in the row.
func (h *Handler) execRows(w io.Writer, q *sql.Rows) error {
var err error
// get columns
cols, err := drivers.Columns(h.u, q)
if err != nil {
Expand Down Expand Up @@ -788,24 +791,26 @@ func (h *Handler) execRows(w io.Writer, q *sql.Rows) error {

// scan scans a row.
func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) {
var err error
// scan to []interface{}
r := make([]interface{}, clen)
for i := range r {
r[i] = new(interface{})
}
if err := q.Scan(r...); err != nil {
if err = q.Scan(r...); err != nil {
return nil, err
}
// get conversion funcs
cb, cm, cs, cd := drivers.ConvertBytes(h.u), drivers.ConvertMap(h.u), drivers.ConvertSlice(h.u), drivers.ConvertDefault(h.u)
cb, cm, cs, cd := drivers.ConvertBytes(h.u), drivers.ConvertMap(h.u),
drivers.ConvertSlice(h.u), drivers.ConvertDefault(h.u)
row := make([]string, clen)
for n, z := range r {
j := z.(*interface{})
switch x := (*j).(type) {
case []byte:
if x != nil {
var err error
if row[n], err = cb(x, tfmt); err != nil {
row[n], err = cb(x, tfmt)
if err != nil {
return nil, err
}
}
Expand All @@ -817,32 +822,33 @@ func (h *Handler) scan(q *sql.Rows, clen int, tfmt string) ([]string, error) {
row[n] = x.String()
case map[string]interface{}:
if x != nil {
var err error
if row[n], err = cm(x); err != nil {
row[n], err = cm(x)
if err != nil {
return nil, err
}
}
case []interface{}:
if x != nil {
var err error
if row[n], err = cs(x); err != nil {
row[n], err = cs(x)
if err != nil {
return nil, err
}
}
default:
if x != nil {
var err error
if row[n], err = cd(x); err != nil {
row[n], err = cd(x)
if err != nil {
return nil, err
}
}
}
}
return row, nil
return row, err
}

// exec does a database exec.
func (h *Handler) exec(w io.Writer, typ, qstr string) error {
var err error
res, err := h.DB().Exec(qstr)
if err != nil {
return err
Expand Down Expand Up @@ -871,7 +877,8 @@ func (h *Handler) Begin() error {
return text.ErrPreviousTransactionExists
}
var err error
if h.tx, err = h.db.Begin(); err != nil {
h.tx, err = h.db.Begin()
if err != nil {
return drivers.WrapErr(h.u.Driver, err)
}
return nil
Expand Down Expand Up @@ -913,6 +920,7 @@ func (h *Handler) Rollback() error {

// Include includes the specified path.
func (h *Handler) Include(path string, relative bool) error {
var err error
if relative && !filepath.IsAbs(path) {
path = filepath.Join(h.wd, path)
}
Expand Down
Loading

0 comments on commit 1061148

Please sign in to comment.