Skip to content

Commit

Permalink
More complete wrappers in tyro.extras (for torchrun, etc) (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi authored Oct 11, 2024
1 parent 34f3b14 commit 8828943
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/tyro/extras/_subcommand_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def cli(
description: Optional[str] = None,
args: Optional[Sequence[str]] = None,
use_underscores: bool = False,
console_outputs: bool = True,
config: Optional[Sequence[Any]] = None,
sort_subcommands: bool = False,
) -> Any:
"""Run the command-line interface.
Expand All @@ -110,6 +112,9 @@ def cli(
This primarily impacts helptext; underscores and hyphens are treated equivalently
when parsing happens. We default helptext to hyphens to follow the GNU style guide.
https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html
console_outputs: If set to `False`, parsing errors and help messages will be
suppressed.
config: Sequence of config marker objects, from `tyro.conf`.
sort_subcommands: If True, sort the subcommands alphabetically by name.
"""
assert self._subcommands is not None
Expand All @@ -129,6 +134,8 @@ def cli(
description=description,
args=args,
use_underscores=use_underscores,
console_outputs=console_outputs,
config=config,
)
else:
return tyro.extras.subcommand_cli_from_dict(
Expand All @@ -137,4 +144,6 @@ def cli(
description=description,
args=args,
use_underscores=use_underscores,
console_outputs=console_outputs,
config=config,
)
9 changes: 8 additions & 1 deletion src/tyro/extras/_subcommand_cli_from_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing_extensions import Annotated

from tyro.conf._markers import Marker

from .._cli import cli
from ..conf import subcommand

Expand All @@ -17,6 +19,7 @@ def subcommand_cli_from_dict(
args: Optional[Sequence[str]] = None,
use_underscores: bool = False,
console_outputs: bool = True,
config: Optional[Sequence[Marker]] = None,
) -> T: ...


Expand All @@ -31,6 +34,7 @@ def subcommand_cli_from_dict(
args: Optional[Sequence[str]] = None,
use_underscores: bool = False,
console_outputs: bool = True,
config: Optional[Sequence[Marker]] = None,
) -> Any: ...


Expand All @@ -42,6 +46,7 @@ def subcommand_cli_from_dict(
args: Optional[Sequence[str]] = None,
use_underscores: bool = False,
console_outputs: bool = True,
config: Optional[Sequence[Marker]] = None,
) -> Any:
"""Generate a subcommand CLI from a dictionary of functions.
Expand Down Expand Up @@ -93,6 +98,7 @@ def subcommand_cli_from_dict(
supressed. This can be useful for distributed settings, where `tyro.cli()`
is called from multiple workers but we only want console outputs from the
main one.
config: Sequence of config marker objects, from `tyro.conf`.
"""
# We need to form a union type, which requires at least two elements.
assert len(subcommands) >= 2, "At least two subcommands are required."
Expand All @@ -115,5 +121,6 @@ def subcommand_cli_from_dict(
description=description,
args=args,
use_underscores=use_underscores,
console_outputs=console_outputs
console_outputs=console_outputs,
config=config,
)

0 comments on commit 8828943

Please sign in to comment.