Skip to content

Commit

Permalink
TYP: preserve shape in ndarray.astype()
Browse files Browse the repository at this point in the history
This patch changes the return type in astype() from NDArray to ndarray
so that shape information is preserved and adds tests for it.
  • Loading branch information
ntrrgc committed Jan 16, 2025
1 parent bbf4836 commit 754252c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2496,7 +2496,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
casting: _CastingKind = ...,
subok: builtins.bool = ...,
copy: builtins.bool | _CopyMode = ...,
) -> NDArray[_SCT]: ...
) -> ndarray[_ShapeT_co, dtype[_SCT]]: ...
@overload
def astype(
self,
Expand All @@ -2505,7 +2505,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
casting: _CastingKind = ...,
subok: builtins.bool = ...,
copy: builtins.bool | _CopyMode = ...,
) -> NDArray[Any]: ...
) -> ndarray[_ShapeT_co, np.dtype[Any]]: ...

@overload
def view(self) -> Self: ...
Expand Down
5 changes: 5 additions & 0 deletions numpy/typing/tests/data/reveal/ndarray_conversion.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ i4_2d: np.ndarray[tuple[int, int], np.dtype[np.int32]]
f8_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float64]]
cG_4d: np.ndarray[tuple[int, int, int, int], np.dtype[np.clongdouble]]
i0_nd: npt.NDArray[np.int_]
uncertain_dtype: np.int32 | np.float64 | np.str_

# item
assert_type(i0_nd.item(), int)
Expand Down Expand Up @@ -50,6 +51,10 @@ assert_type(i0_nd.astype(np.float64, "K", "unsafe", True, True), npt.NDArray[np.

assert_type(np.astype(i0_nd, np.float64), npt.NDArray[np.float64])

assert_type(i4_2d.astype(np.uint16), np.ndarray[tuple[int, int], np.dtype[np.uint16]])
assert_type(f8_3d.astype(np.int16), np.ndarray[tuple[int, int, int], np.dtype[np.int16]])
assert_type(i4_2d.astype(uncertain_dtype), np.ndarray[tuple[int, int], np.dtype[np.generic[Any]]])

# byteswap
assert_type(i0_nd.byteswap(), npt.NDArray[np.int_])
assert_type(i0_nd.byteswap(True), npt.NDArray[np.int_])
Expand Down

0 comments on commit 754252c

Please sign in to comment.