Skip to content

Commit

Permalink
enforce required flags
Browse files Browse the repository at this point in the history
  • Loading branch information
dixudx committed Jul 27, 2017
1 parent 34594c7 commit a219f11
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
22 changes: 22 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,9 @@ func (c *Command) execute(a []string) (err error) {
c.PreRun(c, argWoFlags)
}

if err := c.ValidateRequiredFlags(); err != nil {
return err
}
if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
return err
Expand Down Expand Up @@ -756,6 +759,25 @@ func (c *Command) ValidateArgs(args []string) error {
return c.Args(c, args)
}

func (c *Command) ValidateRequiredFlags() error {
flags := c.Flags()
missingFlagNames := []string{}
flags.VisitAll(func(pflag *flag.Flag) {
requiredAnnotation, found := pflag.Annotations[BashCompOneRequiredFlag]
if !found {
return
}
if (requiredAnnotation[0] == "true") && !pflag.Changed {
missingFlagNames = append(missingFlagNames, pflag.Name)
}
})

if len(missingFlagNames) > 0 {
return fmt.Errorf("Required flags \"%s\" have/has not been set", strings.Join(missingFlagNames, "\", \""))
}
return nil
}

// InitDefaultHelpFlag adds default help flag to c.
// It is called automatically by executing the c or by calling help and usage.
// If c already has help flag, it will do nothing.
Expand Down
50 changes: 50 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,53 @@ func TestSetHelpCommand(t *testing.T) {
t.Errorf("Expected to contain %q message, but got %q", correctMessage, output.String())
}
}

func TestRequiredFlags(t *testing.T) {
c := &Command{Use: "c", Run: func(*Command, []string) {}}
output := new(bytes.Buffer)
c.SetOutput(output)
c.Flags().String("foo1", "", "required foo1")
c.MarkFlagRequired("foo1")
c.Flags().String("foo2", "", "required foo2")
c.MarkFlagRequired("foo2")
c.Flags().String("bar", "", "optional bar")

expectedFmt := "Required flags %s have/has not been set"
expected := fmt.Sprintf(expectedFmt, fmt.Sprintf("%q, %q", "foo1", "foo2"))

if err := c.Execute(); err != nil {
if err.Error() != expected {
t.Errorf("expected %v, got %v", expected, err.Error())
}
}
}

func TestPersistentRequiredFlags(t *testing.T) {
parent := &Command{Use: "parent", Run: func(*Command, []string) {}}
output := new(bytes.Buffer)
parent.SetOutput(output)
parent.PersistentFlags().String("foo1", "", "required foo1")
parent.MarkPersistentFlagRequired("foo1")
parent.PersistentFlags().String("foo2", "", "required foo2")
parent.MarkPersistentFlagRequired("foo2")
parent.Flags().String("foo3", "", "optional foo3")

child := &Command{Use: "child", Run: func(*Command, []string) {}}
child.Flags().String("bar1", "", "required bar1")
child.MarkFlagRequired("bar1")
child.Flags().String("bar2", "", "required bar2")
child.MarkFlagRequired("bar2")
child.Flags().String("bar3", "", "optional bar3")

parent.AddCommand(child)
parent.SetArgs([]string{"child"})

expectedFmt := "Required flags %s have/has not been set"
expected := fmt.Sprintf(expectedFmt, fmt.Sprintf("%q, %q, %q, %q", "bar1", "bar2", "foo1", "foo2"))

if err := parent.Execute(); err != nil {
if err.Error() != expected {
t.Errorf("expected %v, got %v", expected, err.Error())
}
}
}

0 comments on commit a219f11

Please sign in to comment.