Skip to content

Commit

Permalink
TYP: preserve shape-type in ndarray.astype()
Browse files Browse the repository at this point in the history
This patch changes the return type in astype() from NDArray to ndarray
so that shape information is preserved and adds tests for it.

Similar changes are added to np.astype() for consistency.
  • Loading branch information
ntrrgc committed Jan 16, 2025
1 parent bbf4836 commit b866d68
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2496,7 +2496,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
casting: _CastingKind = ...,
subok: builtins.bool = ...,
copy: builtins.bool | _CopyMode = ...,
) -> NDArray[_SCT]: ...
) -> ndarray[_ShapeT_co, dtype[_SCT]]: ...
@overload
def astype(
self,
Expand All @@ -2505,7 +2505,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
casting: _CastingKind = ...,
subok: builtins.bool = ...,
copy: builtins.bool | _CopyMode = ...,
) -> NDArray[Any]: ...
) -> ndarray[_ShapeT_co, dtype[Any]]: ...

@overload
def view(self) -> Self: ...
Expand Down
8 changes: 4 additions & 4 deletions numpy/_core/numeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -872,15 +872,15 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> bool: ...

@overload
def astype(
x: NDArray[Any],
x: ndarray[_ShapeType, dtype[Any]],
dtype: _DTypeLike[_SCT],
copy: bool = ...,
device: None | L["cpu"] = ...,
) -> NDArray[_SCT]: ...
) -> ndarray[_ShapeType, dtype[_SCT]]: ...
@overload
def astype(
x: NDArray[Any],
x: ndarray[_ShapeType, dtype[Any]],
dtype: DTypeLike,
copy: bool = ...,
device: None | L["cpu"] = ...,
) -> NDArray[Any]: ...
) -> ndarray[_ShapeType, dtype[Any]]: ...
8 changes: 8 additions & 0 deletions numpy/typing/tests/data/reveal/ndarray_conversion.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ i4_2d: np.ndarray[tuple[int, int], np.dtype[np.int32]]
f8_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float64]]
cG_4d: np.ndarray[tuple[int, int, int, int], np.dtype[np.clongdouble]]
i0_nd: npt.NDArray[np.int_]
uncertain_dtype: np.int32 | np.float64 | np.str_

# item
assert_type(i0_nd.item(), int)
Expand Down Expand Up @@ -50,6 +51,13 @@ assert_type(i0_nd.astype(np.float64, "K", "unsafe", True, True), npt.NDArray[np.

assert_type(np.astype(i0_nd, np.float64), npt.NDArray[np.float64])

assert_type(i4_2d.astype(np.uint16), np.ndarray[tuple[int, int], np.dtype[np.uint16]])
assert_type(np.astype(i4_2d, np.uint16), np.ndarray[tuple[int, int], np.dtype[np.uint16]])
assert_type(f8_3d.astype(np.int16), np.ndarray[tuple[int, int, int], np.dtype[np.int16]])
assert_type(np.astype(f8_3d, np.int16), np.ndarray[tuple[int, int, int], np.dtype[np.int16]])
assert_type(i4_2d.astype(uncertain_dtype), np.ndarray[tuple[int, int], np.dtype[np.generic[Any]]])
assert_type(np.astype(i4_2d, uncertain_dtype), np.ndarray[tuple[int, int], np.dtype[Any]])

# byteswap
assert_type(i0_nd.byteswap(), npt.NDArray[np.int_])
assert_type(i0_nd.byteswap(True), npt.NDArray[np.int_])
Expand Down

0 comments on commit b866d68

Please sign in to comment.