From 754252c17aee67152bb08dcfb012718ed244e712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alicia=20Boya=20Garc=C3=ADa?= Date: Thu, 16 Jan 2025 22:20:02 +0100 Subject: [PATCH] TYP: preserve shape in ndarray.astype() This patch changes the return type in astype() from NDArray to ndarray so that shape information is preserved and adds tests for it. --- numpy/__init__.pyi | 4 ++-- numpy/typing/tests/data/reveal/ndarray_conversion.pyi | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index ce06ef8a587f..d76c173e52e5 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -2496,7 +2496,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]): casting: _CastingKind = ..., subok: builtins.bool = ..., copy: builtins.bool | _CopyMode = ..., - ) -> NDArray[_SCT]: ... + ) -> ndarray[_ShapeT_co, dtype[_SCT]]: ... @overload def astype( self, @@ -2505,7 +2505,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]): casting: _CastingKind = ..., subok: builtins.bool = ..., copy: builtins.bool | _CopyMode = ..., - ) -> NDArray[Any]: ... + ) -> ndarray[_ShapeT_co, np.dtype[Any]]: ... @overload def view(self) -> Self: ... diff --git a/numpy/typing/tests/data/reveal/ndarray_conversion.pyi b/numpy/typing/tests/data/reveal/ndarray_conversion.pyi index 789585ec963b..01c2c546dd8c 100644 --- a/numpy/typing/tests/data/reveal/ndarray_conversion.pyi +++ b/numpy/typing/tests/data/reveal/ndarray_conversion.pyi @@ -11,6 +11,7 @@ i4_2d: np.ndarray[tuple[int, int], np.dtype[np.int32]] f8_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float64]] cG_4d: np.ndarray[tuple[int, int, int, int], np.dtype[np.clongdouble]] i0_nd: npt.NDArray[np.int_] +uncertain_dtype: np.int32 | np.float64 | np.str_ # item assert_type(i0_nd.item(), int) @@ -50,6 +51,10 @@ assert_type(i0_nd.astype(np.float64, "K", "unsafe", True, True), npt.NDArray[np. assert_type(np.astype(i0_nd, np.float64), npt.NDArray[np.float64]) +assert_type(i4_2d.astype(np.uint16), np.ndarray[tuple[int, int], np.dtype[np.uint16]]) +assert_type(f8_3d.astype(np.int16), np.ndarray[tuple[int, int, int], np.dtype[np.int16]]) +assert_type(i4_2d.astype(uncertain_dtype), np.ndarray[tuple[int, int], np.dtype[np.generic[Any]]]) + # byteswap assert_type(i0_nd.byteswap(), npt.NDArray[np.int_]) assert_type(i0_nd.byteswap(True), npt.NDArray[np.int_])