Skip to content

Commit

Permalink
improve args API
Browse files Browse the repository at this point in the history
  • Loading branch information
posener committed Nov 8, 2019
1 parent 7276cbc commit c489499
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 152 deletions.
69 changes: 69 additions & 0 deletions args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package subcmd

import (
"fmt"
"strconv"
)

// ArgsStr are string positional arguments. If it is created with cap > 0, it will be used to define
// the number of required arguments.
//
// Usage
//
// To get a list of arbitrary number of arguments:
//
// root := subcmd.Root()
//
// var subcmd.ArgsStr args
// root.ArgsVar(&args, "[arg...]", "list of arguments")
//
// To get a list of specific number of arguments:
//
// root := subcmd.Root()
//
// args := make(subcmd.ArgsStr, 3)
// root.ArgsVar(&args, "[arg1] [arg2] [arg3]", "list of 3 arguments")
type ArgsStr []string

func (a *ArgsStr) Set(args []string) error {
if cap(*a) > 0 && len(args) != cap(*a) {
return fmt.Errorf("required %d positional args, got %v", cap(*a), args)
}
*a = args
return nil
}

// ArgsInt are int positional arguments. If it is created with cap > 0, it will be used to define
// the number of required arguments.
//
// Usage
//
// To get a list of arbitrary number of integers:
//
// root := subcmd.Root()
//
// var subcmd.ArgsInt args
// root.ArgsVar(&args, "[int...]", "list of integer args")
//
// To get a list of specific number of integers:
//
// root := subcmd.Root()
//
// args := make(subcmd.ArgsInt, 3)
// root.ArgsVar(&args, "[int1] [int2] [int3]", "list of 3 integers")
type ArgsInt []int

func (a *ArgsInt) Set(args []string) error {
if cap(*a) > 0 && len(args) != cap(*a) {
return fmt.Errorf("required %d positional args, got %v", cap(*a), args)
}
*a = (*a)[:0] // Reset length to 0.
for i, arg := range args {
v, err := strconv.Atoi(arg)
if err != nil {
return fmt.Errorf("invalid int positional argument at position %d with value %v", i, arg)
}
*a = append(*a, v)
}
return nil
}
74 changes: 74 additions & 0 deletions args_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package subcmd

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestArgs(t *testing.T) {
t.Parallel()

t.Run("no cap", func(t *testing.T) {
var args ArgsStr
err := args.Set([]string{"a", "b"})
require.NoError(t, err)
assert.Equal(t, []string{"a", "b"}, []string(args))
})

t.Run("with cap", func(t *testing.T) {
var args = make(ArgsStr, 2)
err := args.Set([]string{"a", "b"})
require.NoError(t, err)
assert.Equal(t, []string{"a", "b"}, []string(args))
})

t.Run("too many args", func(t *testing.T) {
var args = make(ArgsStr, 2)
err := args.Set([]string{"a", "b", "c"})
require.Error(t, err)
})

t.Run("not enough args", func(t *testing.T) {
var args = make(ArgsStr, 2)
err := args.Set([]string{"a"})
require.Error(t, err)
})
}

func TestArgsInt(t *testing.T) {
t.Parallel()

t.Run("no cap", func(t *testing.T) {
var args ArgsInt
err := args.Set([]string{"1", "2"})
require.NoError(t, err)
assert.Equal(t, []int{1, 2}, []int(args))
})

t.Run("with cap", func(t *testing.T) {
var args = make(ArgsInt, 2)
err := args.Set([]string{"1", "2"})
require.NoError(t, err)
assert.Equal(t, []int{1, 2}, []int(args))
})

t.Run("bad value", func(t *testing.T) {
var args ArgsInt
err := args.Set([]string{"a"})
require.Error(t, err)
})

t.Run("too many args", func(t *testing.T) {
var args = make(ArgsInt, 2)
err := args.Set([]string{"1", "2", "3"})
require.Error(t, err)
})

t.Run("not enough args", func(t *testing.T) {
var args = make(ArgsInt, 2)
err := args.Set([]string{"1"})
require.Error(t, err)
})
}
51 changes: 0 additions & 51 deletions example/main.go

This file was deleted.

38 changes: 38 additions & 0 deletions example_1_subcommand_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package subcmd_test

import (
"fmt"

"github.com/posener/subcmd"
)

var (
// Define a root command. Some options can be set using the `Opt*` functions. It returns a
// `*Cmd` object.
root = subcmd.Root()
// The `*Cmd` object can be used as the standard library `flag.FlagSet`.
flag0 = root.String("flag0", "", "root stringflag")
// From each command object, a sub command can be created. This can be done recursively.
sub1 = root.SubCommand("sub1", "first sub command")
// Each sub command can have flags attached.
flag1 = sub1.String("flag1", "", "sub1 string flag")
sub2 = root.SubCommand("sub2", "second sub command")
flag2 = sub1.Int("flag2", 0, "sub2 int flag")
)

// Definition and usage of sub commands and sub commands flags.
func Example() {
// In the example we use `Parse()` for a given list of command line arguments. This is useful
// for testing, but should be replaced with `root.ParseArgs()` in `main()`
root.Parse([]string{"cmd", "sub1", "-flag1", "value"})

// Usually the program should switch over the sub commands. The chosen sub command will return
// true for the `Parsed()` method.
switch {
case sub1.Parsed():
fmt.Printf("Called sub1 with flag: %s", *flag1)
case sub2.Parsed():
fmt.Printf("Called sub2 with flag: %d", *flag2)
}
// Output: Called sub1 with flag: value
}
71 changes: 48 additions & 23 deletions flagset.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package subcmd

import (
"flag"
"io"
"time"
)

Expand All @@ -11,58 +10,84 @@ type flagSet struct {
*flag.FlagSet
}

func (f *flagSet) Bool(name string, value bool, usage string) *bool {
return f.FlagSet.Bool(name, value, usage)
func (f *flagSet) Parsed() bool { return f.FlagSet.Parsed() }

func (f *flagSet) Set(name, value string) error {
return f.FlagSet.Set(name, value)
}

func (f *flagSet) Duration(name string, value time.Duration, usage string) *time.Duration {
return f.FlagSet.Duration(name, value, usage)
func (f *flagSet) Var(value flag.Value, name string, usage string) {
f.FlagSet.Var(value, name, usage)
}

func (f *flagSet) Float64(name string, value float64, usage string) *float64 {
return f.FlagSet.Float64(name, value, usage)
func (f *flagSet) Visit(fn func(*flag.Flag)) {
f.FlagSet.Visit(fn)
}

func (f *flagSet) VisitAll(fn func(*flag.Flag)) {
f.FlagSet.VisitAll(fn)
}

func (f *flagSet) String(name string, value string, usage string) *string {
return f.FlagSet.String(name, value, usage)
}

func (f *flagSet) StringVar(p *string, name string, value string, usage string) {
f.FlagSet.StringVar(p, name, value, usage)
}

func (f *flagSet) Bool(name string, value bool, usage string) *bool {
return f.FlagSet.Bool(name, value, usage)
}

func (f *flagSet) BoolVar(p *bool, name string, value bool, usage string) {
f.FlagSet.BoolVar(p, name, value, usage)
}

func (f *flagSet) Int(name string, value int, usage string) *int {
return f.FlagSet.Int(name, value, usage)
}

func (f *flagSet) Int64(name string, value int64, usage string) *int64 {
return f.FlagSet.Int64(name, value, usage)
func (f *flagSet) IntVar(p *int, name string, value int, usage string) {
f.FlagSet.IntVar(p, name, value, usage)
}

func (f *flagSet) Output() io.Writer {
return f.FlagSet.Output()
func (f *flagSet) Int64(name string, value int64, usage string) *int64 {
return f.FlagSet.Int64(name, value, usage)
}

func (f *flagSet) Parsed() bool {
return f.FlagSet.Parsed()
func (f *flagSet) Int64Var(p *int64, name string, value int64, usage string) {
f.FlagSet.Int64Var(p, name, value, usage)
}

func (f *flagSet) Set(name, value string) error {
return f.FlagSet.Set(name, value)
func (f *flagSet) Float64(name string, value float64, usage string) *float64 {
return f.FlagSet.Float64(name, value, usage)
}

func (f *flagSet) String(name string, value string, usage string) *string {
return f.FlagSet.String(name, value, usage)
func (f *flagSet) Float64Var(p *float64, name string, value float64, usage string) {
f.FlagSet.Float64Var(p, name, value, usage)
}

func (f *flagSet) Uint(name string, value uint, usage string) *uint {
return f.FlagSet.Uint(name, value, usage)
}

func (f *flagSet) UintVar(p *uint, name string, value uint, usage string) {
f.FlagSet.UintVar(p, name, value, usage)
}

func (f *flagSet) Uint64(name string, value uint64, usage string) *uint64 {
return f.FlagSet.Uint64(name, value, usage)
}

func (f *flagSet) Var(value flag.Value, name string, usage string) {
f.FlagSet.Var(value, name, usage)
func (f *flagSet) UintVar64(p *uint64, name string, value uint64, usage string) {
f.FlagSet.Uint64Var(p, name, value, usage)
}

func (f *flagSet) Visit(fn func(*flag.Flag)) {
f.FlagSet.Visit(fn)
func (f *flagSet) Duration(name string, value time.Duration, usage string) *time.Duration {
return f.FlagSet.Duration(name, value, usage)
}

func (f *flagSet) VisitAll(fn func(*flag.Flag)) {
f.FlagSet.VisitAll(fn)
func (f *flagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
f.FlagSet.DurationVar(p, name, value, usage)
}
1 change: 1 addition & 0 deletions goreadme.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"recursive": false,
"badges": {
"travis_ci": true,
"code_cov": true,
Expand Down
Loading

0 comments on commit c489499

Please sign in to comment.