Skip to content

Commit

Permalink
Numeric tower improvements + test (#193)
Browse files Browse the repository at this point in the history
* Improvements for numeric tower edge cases

* Fix test for Python <3.10
  • Loading branch information
brentyi authored Nov 5, 2024
1 parent 24bbac0 commit 42a7f80
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 15 deletions.
38 changes: 37 additions & 1 deletion src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -473,7 +474,8 @@ def narrow_union_type(typ: TypeOrCallable, default_instance: Any) -> TypeOrCalla

try:
if default_instance not in MISSING_SINGLETONS and not any(
isinstance(default_instance, o) for o in options_unwrapped
isinstance_with_fuzzy_numeric_tower(default_instance, o) is not False
for o in options_unwrapped
):
warnings.warn(
f"{type(default_instance)} does not match any type in Union:"
Expand All @@ -486,6 +488,40 @@ def narrow_union_type(typ: TypeOrCallable, default_instance: Any) -> TypeOrCalla
return typ


def isinstance_with_fuzzy_numeric_tower(
obj: Any, classinfo: Type
) -> Union[bool, Literal["~"]]:
"""
Enhanced version of isinstance() that returns:
- True: if object is exactly of the specified type
- "~": if object follows numeric tower rules but isn't exact type
- False: if object is not of the specified type or numeric tower rules don't apply
Examples:
>>> enhanced_isinstance(3, int) # Returns True
>>> enhanced_isinstance(3, float) # Returns "~"
>>> enhanced_isinstance(True, int) # Returns "~"
>>> enhanced_isinstance(3, bool) # Returns False
>>> enhanced_isinstance(True, bool) # Returns True
"""
# Handle exact match first
if isinstance(obj, classinfo):
return True

# Handle numeric tower cases
if isinstance(obj, bool):
if classinfo in (int, float, complex):
return "~"
elif isinstance(obj, int) and not isinstance(obj, bool): # explicit bool check
if classinfo in (float, complex):
return "~"
elif isinstance(obj, float):
if classinfo is complex:
return "~"

return False


NoneType = type(None)


Expand Down
42 changes: 29 additions & 13 deletions src/tyro/constructors/_primitive_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,15 @@ class PrimitiveConstructorSpec(Generic[T]):
instance_from_str: Callable[[list[str]], T]
"""Given a list of string arguments, construct an instance of the type. The
length of the list will match the value of nargs."""
is_instance: Callable[[Any], bool]
is_instance: Callable[[Any], bool | Literal["~"]]
"""Given an object instance, does it match this primitive type? This is
used for specific help messages when both a union type is present and a
default is provided."""
default is provided.
Can return "~" to signify that an instance is a "fuzzy" match, and should
only be used if there are no other matches. This is used for numeric tower
support.
"""
str_from_instance: Callable[[T], list[str]]
"""Convert an instance to a list of string arguments that would construct
the instance. This is used for help messages when a default is provided."""
Expand All @@ -124,11 +129,12 @@ def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
return None
raise UnsupportedTypeAnnotationError("`Any` is not a parsable type.")

# HACK: this is for code that uses `tyro.conf.arg(constructor=json.loads)`.
# We're going to deprecate this syntax (the constructor= argument in
# tyro.conf.arg), but there is code that lives in the wild that relies
# on the behavior so we'll do our best not to break it.
vanilla_types = (int, str, float, bytes, json.loads)
# HACK (json.loads): this is for code that uses
# `tyro.conf.arg(constructor=json.loads)`. We're going to deprecate this
# syntax (the constructor= argument in tyro.conf.arg), but there is code
# that lives in the wild that relies on the behavior so we'll do our best
# not to break it.
vanilla_types = (int, str, float, complex, bytes, bytearray, json.loads)

@registry.primitive_rule
def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
Expand All @@ -142,10 +148,11 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None
if type_info.type is bytes
else type_info.type(args[0])
),
# Numeric tower in Python is weird...
is_instance=lambda x: isinstance(x, (int, float))
if type_info.type is float
else isinstance(x, type_info.type),
# issubclass(type(x), y) here is preferable over isinstance(x, y)
# due to quirks in the numeric tower.
is_instance=lambda x: _resolver.isinstance_with_fuzzy_numeric_tower(
x, type_info.type
),
str_from_instance=lambda instance: [str(instance)],
)

Expand Down Expand Up @@ -582,7 +589,7 @@ def union_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:

# General unions, eg Union[int, bool]. We'll try to convert these from left to
# right.
option_specs = []
option_specs: list[PrimitiveConstructorSpec] = []
choices: tuple[str, ...] | None = ()
nargs: int | Literal["*"] = 1
first = True
Expand Down Expand Up @@ -646,9 +653,18 @@ def union_instantiator(strings: List[str]) -> Any:
)

def str_from_instance(instance: Any) -> List[str]:
fuzzy_match = None
for option_spec in option_specs:
if option_spec.is_instance(instance):
is_instance = option_spec.is_instance(instance)
if is_instance is True:
return option_spec.str_from_instance(instance)
elif is_instance == "~":
fuzzy_match = option_spec

# If we get here, we have a fuzzy match.
if fuzzy_match is not None:
return fuzzy_match.str_from_instance(instance)

assert False, f"could not match default value {instance} with any types in union {options}"

return PrimitiveConstructorSpec(
Expand Down
21 changes: 21 additions & 0 deletions tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,3 +942,24 @@ def main(dt: datetime.time) -> datetime.time:
# Invalid hour value.
with pytest.raises(SystemExit):
tyro.cli(main, args=["--dt", "25:00:00"])


def test_numeric_tower() -> None:
@dataclasses.dataclass(frozen=True)
class NumericTower:
a: Union[complex, str] = 3.0
b: Union[bytearray, str] = dataclasses.field(
default_factory=lambda: bytearray(b"123")
)
c: Union[complex, str] = True
d: Union[int, complex] = False
e: Union[float, str] = 3

assert tyro.cli(NumericTower, args=[]) == NumericTower(3.0)
assert tyro.cli(NumericTower, args="--a 1+3j".split(" ")) == NumericTower(1 + 3j)
assert tyro.cli(NumericTower, args="--c False".split(" ")) == NumericTower(
c="False"
)
assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2)
with pytest.raises(SystemExit):
tyro.cli(NumericTower, args="--d False".split(" "))
2 changes: 1 addition & 1 deletion tests/test_py311_generated/test_collections_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def main(x: Dict = {"int": 5, "str": "5"}):

def test_dict_optional() -> None:
# In this case, the `None` is ignored.
def main(x: Optional[Dict[str, int]] = {"three": 3, "five": 5}):
def main(x: Optional[Dict[str, float]] = {"three": 3, "five": 5}):
return x

assert tyro.cli(main, args=[]) == {"three": 3, "five": 5}
Expand Down
21 changes: 21 additions & 0 deletions tests/test_py311_generated/test_dcargs_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,24 @@ def main(dt: datetime.time) -> datetime.time:
# Invalid hour value.
with pytest.raises(SystemExit):
tyro.cli(main, args=["--dt", "25:00:00"])


def test_numeric_tower() -> None:
@dataclasses.dataclass(frozen=True)
class NumericTower:
a: complex | str = 3.0
b: bytearray | str = dataclasses.field(
default_factory=lambda: bytearray(b"123")
)
c: complex | str = True
d: int | complex = False
e: float | str = 3

assert tyro.cli(NumericTower, args=[]) == NumericTower(3.0)
assert tyro.cli(NumericTower, args="--a 1+3j".split(" ")) == NumericTower(1 + 3j)
assert tyro.cli(NumericTower, args="--c False".split(" ")) == NumericTower(
c="False"
)
assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2)
with pytest.raises(SystemExit):
tyro.cli(NumericTower, args="--d False".split(" "))

0 comments on commit 42a7f80

Please sign in to comment.