diff --git a/.circleci/cimodel/data/binary_build_data.py b/.circleci/cimodel/data/binary_build_data.py index 51c2bb4312235b..1e1863cae53a53 100644 --- a/.circleci/cimodel/data/binary_build_data.py +++ b/.circleci/cimodel/data/binary_build_data.py @@ -126,6 +126,10 @@ def __init__(self, parent, package_format, python_versions): self.props["python_versions"] = python_versions self.props["package_format"] = package_format + # XXX Disabling conda for 11.3 as there's currently no appropriate cudatoolkit available + if package_format == "conda": + self.props["gpu_versions"] = filter(lambda x: x != "cuda113", self.find_prop("gpu_versions")) + def get_children(self): if self.find_prop("os_name") == "linux": return [LinuxGccConfigNode(self, v) for v in LINUX_GCC_CONFIG_VARIANTS[self.find_prop("package_format")]] diff --git a/.circleci/cimodel/data/dimensions.py b/.circleci/cimodel/data/dimensions.py index e4a86dc4e96f02..bdcc454c204520 100644 --- a/.circleci/cimodel/data/dimensions.py +++ b/.circleci/cimodel/data/dimensions.py @@ -3,6 +3,7 @@ CUDA_VERSIONS = [ "102", "111", + "113", ] ROCM_VERSIONS = [ diff --git a/.circleci/cimodel/data/windows_build_definitions.py b/.circleci/cimodel/data/windows_build_definitions.py index c93d968b6f0e7b..c6728d500efe33 100644 --- a/.circleci/cimodel/data/windows_build_definitions.py +++ b/.circleci/cimodel/data/windows_build_definitions.py @@ -147,16 +147,13 @@ def render(self): WindowsJob(None, _VC2019, CudaVersion(10, 1), master_only=True), WindowsJob(1, _VC2019, CudaVersion(10, 1), master_only=True), WindowsJob(2, _VC2019, CudaVersion(10, 1), master_only=True), + # VS2019 CUDA-10.1 force on cpu + WindowsJob(1, _VC2019, CudaVersion(10, 1), force_on_cpu=True, master_only=True), # VS2019 CUDA-11.1 WindowsJob(None, _VC2019, CudaVersion(11, 1)), WindowsJob(1, _VC2019, CudaVersion(11, 1), master_only=True), WindowsJob(2, _VC2019, CudaVersion(11, 1), master_only=True), WindowsJob('_azure_multi_gpu', _VC2019, CudaVersion(11, 1), multi_gpu=True, master_and_nightly=True), - # VS2019 CPU-only - WindowsJob(None, _VC2019, None), - WindowsJob(1, _VC2019, None), - WindowsJob(2, _VC2019, None), - WindowsJob(1, _VC2019, CudaVersion(10, 1), force_on_cpu=True, master_only=True), ] diff --git a/.circleci/config.yml b/.circleci/config.yml index a242ccf193f089..1e012438af9ecb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2258,6 +2258,50 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda111" + - binary_linux_build: + name: binary_linux_manywheel_3_6m_cu113_devtoolset7_nightly_build + build_environment: "manywheel 3.6m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-cuda113" + - binary_linux_build: + name: binary_linux_manywheel_3_7m_cu113_devtoolset7_nightly_build + build_environment: "manywheel 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-cuda113" + - binary_linux_build: + name: binary_linux_manywheel_3_8m_cu113_devtoolset7_nightly_build + build_environment: "manywheel 3.8m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-cuda113" + - binary_linux_build: + name: binary_linux_manywheel_3_9m_cu113_devtoolset7_nightly_build + build_environment: "manywheel 3.9m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-cuda113" - binary_linux_build: name: binary_linux_manywheel_3_6m_rocm4_0_1_devtoolset7_nightly_build build_environment: "manywheel 3.6m rocm4.0.1 devtoolset7" @@ -2666,6 +2710,54 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "static-without-deps" docker_image: "pytorch/manylinux-cuda111" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-with-deps_build + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-with-deps" + docker_image: "pytorch/manylinux-cuda113" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-without-deps_build + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-without-deps" + docker_image: "pytorch/manylinux-cuda113" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-with-deps_build + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-with-deps" + docker_image: "pytorch/manylinux-cuda113" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-without-deps_build + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-without-deps" + docker_image: "pytorch/manylinux-cuda113" - binary_linux_build: name: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_nightly_shared-with-deps_build build_environment: "libtorch 3.7m cpu gcc5.4_cxx11-abi" @@ -2810,6 +2902,54 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "static-without-deps" docker_image: "pytorch/libtorch-cxx11-builder:cuda111" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-with-deps_build + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-with-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-without-deps_build + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-without-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-with-deps_build + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-with-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + - binary_linux_build: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-without-deps_build + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-without-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" - binary_mac_build: name: binary_macos_wheel_3_6_cpu_nightly_build build_environment: "wheel 3.6 cpu" @@ -3060,6 +3200,46 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_6_cu113_nightly_build + build_environment: "wheel 3.6 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_7_cu113_nightly_build + build_environment: "wheel 3.7 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_8_cu113_nightly_build + build_environment: "wheel 3.8 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_9_cu113_nightly_build + build_environment: "wheel 3.9 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_conda_3_6_cpu_nightly_build build_environment: "conda 3.6 cpu" @@ -3210,6 +3390,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_libtorch_3_7_cu113_debug_nightly_build + build_environment: "libtorch 3.7 cu113 debug" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_libtorch_3_7_cpu_release_nightly_build build_environment: "libtorch 3.7 cpu release" @@ -3240,6 +3430,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_libtorch_3_7_cu113_release_nightly_build + build_environment: "libtorch 3.7 cu113 release" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_linux_test: name: binary_linux_manywheel_3_6m_cpu_devtoolset7_nightly_test build_environment: "manywheel 3.6m cpu devtoolset7" @@ -3412,6 +3612,66 @@ workflows: docker_image: "pytorch/manylinux-cuda111" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_6m_cu113_devtoolset7_nightly_test + build_environment: "manywheel 3.6m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_6m_cu113_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_7m_cu113_devtoolset7_nightly_test + build_environment: "manywheel 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_7m_cu113_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_8m_cu113_devtoolset7_nightly_test + build_environment: "manywheel 3.8m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_8m_cu113_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_9m_cu113_devtoolset7_nightly_test + build_environment: "manywheel 3.9m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_9m_cu113_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: name: binary_linux_manywheel_3_6m_rocm4_0_1_devtoolset7_nightly_test build_environment: "manywheel 3.6m rocm4.0.1 devtoolset7" @@ -3948,6 +4208,70 @@ workflows: docker_image: "pytorch/manylinux-cuda111" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-with-deps_test + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-with-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-with-deps_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-without-deps_test + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-without-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-without-deps_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-with-deps_test + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-with-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-with-deps_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-without-deps_test + build_environment: "libtorch 3.7m cu113 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-without-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-without-deps_build + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: name: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_nightly_shared-with-deps_test build_environment: "libtorch 3.7m cpu gcc5.4_cxx11-abi" @@ -4132,9 +4456,9 @@ workflows: docker_image: "pytorch/libtorch-cxx11-builder:cuda111" use_cuda_docker_runtime: "1" resource_class: gpu.medium - - binary_windows_test: - name: binary_windows_wheel_3_6_cpu_nightly_test - build_environment: "wheel 3.6 cpu" + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-with-deps_test + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" filters: branches: only: @@ -4142,11 +4466,75 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-with-deps" requires: - - binary_windows_wheel_3_6_cpu_nightly_build - - binary_windows_test: - name: binary_windows_wheel_3_7_cpu_nightly_test - build_environment: "wheel 3.7 cpu" + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-with-deps_build + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-without-deps_test + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "shared-without-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-without-deps_build + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-with-deps_test + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-with-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-with-deps_build + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-without-deps_test + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + libtorch_variant: "static-without-deps" + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-without-deps_build + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_windows_test: + name: binary_windows_wheel_3_6_cpu_nightly_test + build_environment: "wheel 3.6 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_6_cpu_nightly_build + - binary_windows_test: + name: binary_windows_wheel_3_7_cpu_nightly_test + build_environment: "wheel 3.7 cpu" filters: branches: only: @@ -4284,6 +4672,58 @@ workflows: requires: - binary_windows_wheel_3_9_cu111_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_6_cu113_nightly_test + build_environment: "wheel 3.6 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_6_cu113_nightly_build + executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_7_cu113_nightly_test + build_environment: "wheel 3.7 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_7_cu113_nightly_build + executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_8_cu113_nightly_test + build_environment: "wheel 3.8 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_8_cu113_nightly_build + executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_9_cu113_nightly_test + build_environment: "wheel 3.9 cu113" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_9_cu113_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_conda_3_6_cpu_nightly_test build_environment: "conda 3.6 cpu" @@ -4474,6 +4914,19 @@ workflows: requires: - binary_windows_libtorch_3_7_cu111_debug_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_libtorch_3_7_cu113_debug_nightly_test + build_environment: "libtorch 3.7 cu113 debug" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_libtorch_3_7_cu113_debug_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_libtorch_3_7_cpu_release_nightly_test build_environment: "libtorch 3.7 cpu release" @@ -4512,6 +4965,19 @@ workflows: requires: - binary_windows_libtorch_3_7_cu111_release_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_libtorch_3_7_cu113_release_nightly_test + build_environment: "libtorch 3.7 cu113 release" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_libtorch_3_7_cu113_release_nightly_build + executor: windows-with-nvidia-gpu - binary_upload: name: binary_linux_manywheel_3_6m_cpu_devtoolset7_nightly_upload context: org-member @@ -4680,6 +5146,62 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel upload_subfolder: cu111 + - binary_upload: + name: binary_linux_manywheel_3_6m_cu113_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_6m_cu113_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_manywheel_3_7m_cu113_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_7m_cu113_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_manywheel_3_8m_cu113_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_8m_cu113_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_manywheel_3_9m_cu113_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_9m_cu113_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: cu113 - binary_upload: name: binary_linux_manywheel_3_6m_rocm4_0_1_devtoolset7_nightly_upload context: org-member @@ -5184,6 +5706,62 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch upload_subfolder: cu111 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-with-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-with-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-without-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-without-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-with-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-with-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-without-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-without-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 - binary_upload: name: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_nightly_shared-with-deps_upload context: org-member @@ -5352,6 +5930,62 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch upload_subfolder: cu111 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-with-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-with-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-without-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-without-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-with-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-with-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 + - binary_upload: + name: binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-without-deps_upload + context: org-member + requires: + - binary_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-without-deps_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 - binary_upload: name: binary_macos_wheel_3_6_cpu_nightly_upload context: org-member @@ -5702,6 +6336,62 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: wheel upload_subfolder: cu111 + - binary_upload: + name: binary_windows_wheel_3_6_cu113_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_6_cu113_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu113 + - binary_upload: + name: binary_windows_wheel_3_7_cu113_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_7_cu113_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu113 + - binary_upload: + name: binary_windows_wheel_3_8_cu113_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_8_cu113_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu113 + - binary_upload: + name: binary_windows_wheel_3_9_cu113_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_9_cu113_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu113 - binary_upload: name: binary_windows_conda_3_6_cpu_nightly_upload context: org-member @@ -5912,6 +6602,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch upload_subfolder: cu111 + - binary_upload: + name: binary_windows_libtorch_3_7_cu113_debug_nightly_upload + context: org-member + requires: + - binary_windows_libtorch_3_7_cu113_debug_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu113 - binary_upload: name: binary_windows_libtorch_3_7_cpu_release_nightly_upload context: org-member @@ -5927,10 +6631,24 @@ workflows: package_type: libtorch upload_subfolder: cpu - binary_upload: - name: binary_windows_libtorch_3_7_cu102_release_nightly_upload + name: binary_windows_libtorch_3_7_cu102_release_nightly_upload + context: org-member + requires: + - binary_windows_libtorch_3_7_cu102_release_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: libtorch + upload_subfolder: cu102 + - binary_upload: + name: binary_windows_libtorch_3_7_cu111_release_nightly_upload context: org-member requires: - - binary_windows_libtorch_3_7_cu102_release_nightly_test + - binary_windows_libtorch_3_7_cu111_release_nightly_test filters: branches: only: @@ -5939,12 +6657,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch - upload_subfolder: cu102 + upload_subfolder: cu111 - binary_upload: - name: binary_windows_libtorch_3_7_cu111_release_nightly_upload + name: binary_windows_libtorch_3_7_cu113_release_nightly_upload context: org-member requires: - - binary_windows_libtorch_3_7_cu111_release_nightly_test + - binary_windows_libtorch_3_7_cu113_release_nightly_test filters: branches: only: @@ -5953,7 +6671,7 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch - upload_subfolder: cu111 + upload_subfolder: cu113 when: << pipeline.parameters.run_binary_tests >> build: jobs: @@ -6946,6 +7664,24 @@ workflows: vc_product: BuildTools vc_version: "" vc_year: "2019" + - pytorch_windows_test: + build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 + cuda_version: "10.1" + filters: + branches: + only: + - master + - /ci-all\/.*/ + - /release\/.*/ + name: pytorch_windows_vs2019_py36_cuda10.1_on_cpu_test1 + python_version: "3.6" + requires: + - pytorch_windows_vs2019_py36_cuda10.1_build + test_name: pytorch-windows-test1 + use_cuda: "0" + vc_product: BuildTools + vc_version: "" + vc_year: "2019" - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 cuda_version: "11.1" @@ -7006,57 +7742,6 @@ workflows: name: pytorch_windows_vs2019_py36_cuda11.1_test_azure_multi_gpu requires: - pytorch_windows_vs2019_py36_cuda11.1_build - - pytorch_windows_build: - build_environment: pytorch-win-vs2019-cpu-py3 - cuda_version: cpu - name: pytorch_windows_vs2019_py36_cpu_build - python_version: "3.6" - use_cuda: "0" - vc_product: BuildTools - vc_version: "" - vc_year: "2019" - - pytorch_windows_test: - build_environment: pytorch-win-vs2019-cpu-py3 - cuda_version: cpu - name: pytorch_windows_vs2019_py36_cpu_test1 - python_version: "3.6" - requires: - - pytorch_windows_vs2019_py36_cpu_build - test_name: pytorch-windows-test1 - use_cuda: "0" - vc_product: BuildTools - vc_version: "" - vc_year: "2019" - - pytorch_windows_test: - build_environment: pytorch-win-vs2019-cpu-py3 - cuda_version: cpu - name: pytorch_windows_vs2019_py36_cpu_test2 - python_version: "3.6" - requires: - - pytorch_windows_vs2019_py36_cpu_build - test_name: pytorch-windows-test2 - use_cuda: "0" - vc_product: BuildTools - vc_version: "" - vc_year: "2019" - - pytorch_windows_test: - build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10.1" - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - name: pytorch_windows_vs2019_py36_cuda10.1_on_cpu_test1 - python_version: "3.6" - requires: - - pytorch_windows_vs2019_py36_cuda10.1_build - test_name: pytorch-windows-test1 - use_cuda: "0" - vc_product: BuildTools - vc_version: "" - vc_year: "2019" - update_s3_htmls: context: org-member filters: @@ -7200,6 +7885,54 @@ workflows: docker_image: "pytorch/manylinux-cuda111" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_6m_cu113_devtoolset7_nightly + build_environment: "manywheel 3.6m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_7m_cu113_devtoolset7_nightly + build_environment: "manywheel 3.7m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_8m_cu113_devtoolset7_nightly + build_environment: "manywheel 3.8m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_9m_cu113_devtoolset7_nightly + build_environment: "manywheel 3.9m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_manywheel_3_6m_rocm4_0_1_devtoolset7_nightly build_environment: "manywheel 3.6m rocm4.0.1 devtoolset7" @@ -7628,6 +8361,58 @@ workflows: docker_image: "pytorch/manylinux-cuda111" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-with-deps + build_environment: "libtorch 3.7m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "shared-with-deps" + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_devtoolset7_nightly_shared-without-deps + build_environment: "libtorch 3.7m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "shared-without-deps" + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-with-deps + build_environment: "libtorch 3.7m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "static-with-deps" + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_devtoolset7_nightly_static-without-deps + build_environment: "libtorch 3.7m cu113 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "static-without-deps" + docker_image: "pytorch/manylinux-cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_nightly_shared-with-deps build_environment: "libtorch 3.7m cpu gcc5.4_cxx11-abi" @@ -7776,6 +8561,58 @@ workflows: docker_image: "pytorch/libtorch-cxx11-builder:cuda111" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-with-deps + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "shared-with-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_shared-without-deps + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "shared-without-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-with-deps + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "static-with-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_libtorch_3_7m_cu113_gcc5_4_cxx11-abi_nightly_static-without-deps + build_environment: "libtorch 3.7m cu113 gcc5.4_cxx11-abi" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + libtorch_variant: "static-without-deps" + docker_image: "pytorch/libtorch-cxx11-builder:cuda113" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_mac_test: name: smoke_macos_wheel_3_6_cpu_nightly build_environment: "wheel 3.6 cpu" @@ -7973,6 +8810,46 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_6_cu113_nightly + build_environment: "wheel 3.6 cu113" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_7_cu113_nightly + build_environment: "wheel 3.7 cu113" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_8_cu113_nightly + build_environment: "wheel 3.8 cu113" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_9_cu113_nightly + build_environment: "wheel 3.9 cu113" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_conda_3_6_cpu_nightly build_environment: "conda 3.6 cpu" @@ -8118,6 +8995,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_libtorch_3_7_cu113_debug_nightly + build_environment: "libtorch 3.7 cu113 debug" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_libtorch_3_7_cpu_release_nightly build_environment: "libtorch 3.7 cpu release" @@ -8147,6 +9034,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_libtorch_3_7_cu113_release_nightly + build_environment: "libtorch 3.7 cu113 release" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu when: << pipeline.parameters.run_build >> master_build: jobs: @@ -8409,6 +9306,18 @@ workflows: vc_product: BuildTools vc_version: "" vc_year: "2019" + - pytorch_windows_test: + build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 + cuda_version: "10.1" + name: pytorch_windows_vs2019_py36_cuda10.1_on_cpu_test1 + python_version: "3.6" + requires: + - pytorch_windows_vs2019_py36_cuda10.1_build + test_name: pytorch-windows-test1 + use_cuda: "0" + vc_product: BuildTools + vc_version: "" + vc_year: "2019" - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 cuda_version: "11.1" @@ -8448,18 +9357,6 @@ workflows: name: pytorch_windows_vs2019_py36_cuda11.1_test_azure_multi_gpu requires: - pytorch_windows_vs2019_py36_cuda11.1_build - - pytorch_windows_test: - build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10.1" - name: pytorch_windows_vs2019_py36_cuda10.1_on_cpu_test1 - python_version: "3.6" - requires: - - pytorch_windows_vs2019_py36_cuda10.1_build - test_name: pytorch-windows-test1 - use_cuda: "0" - vc_product: BuildTools - vc_version: "" - vc_year: "2019" when: << pipeline.parameters.run_master_build >> scheduled-ci: triggers: diff --git a/.circleci/scripts/binary_linux_build.sh b/.circleci/scripts/binary_linux_build.sh index e36d06906246dc..055bba4d2f5d57 100755 --- a/.circleci/scripts/binary_linux_build.sh +++ b/.circleci/scripts/binary_linux_build.sh @@ -4,10 +4,14 @@ echo "RUNNING ON $(uname -a) WITH $(nproc) CPUS AND $(free -m)" set -eux -o pipefail source /env -# Defaults here so they can be changed in one place -export MAX_JOBS=${MAX_JOBS:-$(( $(nproc) - 2 ))} +# Because most Circle executors only have 20 CPUs, using more causes OOMs w/ Ninja and nvcc parallelization +MEMORY_LIMIT_MAX_JOBS=18 +NUM_CPUS=$(( $(nproc) - 2 )) -if [[ "${DESIRED_CUDA}" == "cu111" ]]; then +# Defaults here for **binary** linux builds so they can be changed in one place +export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} + +if [[ "${DESIRED_CUDA}" == "cu111" || "${DESIRED_CUDA}" == "cu113" ]]; then export BUILD_SPLIT_CUDA="ON" fi @@ -22,5 +26,9 @@ else build_script='manywheel/build.sh' fi +if [[ "$CIRCLE_BRANCH" == "master" ]] || [[ "$CIRCLE_BRANCH" == release/* ]]; then + export BUILD_DEBUG_INFO=1 +fi + # Build the package SKIP_ALL_TESTS=1 "/builder/$build_script" diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index d377093fd7ebd7..8d171e8101b4f7 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -38,6 +38,10 @@ if [[ "$DESIRED_CUDA" == "cu112" ]]; then EXTRA_CONDA_FLAGS="-c=conda-forge" fi +# Move debug wheels out of the the package dir so they don't get installed +mkdir -p /tmp/debug_final_pkgs +mv /final_pkgs/debug-*.zip /tmp/debug_final_pkgs || echo "no debug packages to move" + # Install the package # These network calls should not have 'retry's because they are installing # locally and aren't actually network calls diff --git a/.circleci/scripts/binary_windows_build.sh b/.circleci/scripts/binary_windows_build.sh index 04d31002106cc1..2d1f228ec5a005 100644 --- a/.circleci/scripts/binary_windows_build.sh +++ b/.circleci/scripts/binary_windows_build.sh @@ -15,7 +15,7 @@ else export VC_YEAR=2019 fi -if [[ "${DESIRED_CUDA}" == "cu111" ]]; then +if [[ "${DESIRED_CUDA}" == "cu111" || "${DESIRED_CUDA}" == "cu113" ]]; then export BUILD_SPLIT_CUDA="ON" fi diff --git a/.github/scale-config.yml b/.github/scale-config.yml index cfe100131d73fb..dee5222837ef32 100644 --- a/.github/scale-config.yml +++ b/.github/scale-config.yml @@ -32,3 +32,8 @@ runner_types: os: windows max_available: 200 disk_size: 256 + windows.8xlarge.nvidia.gpu: + instance_type: g3.8xlarge + os: windows + max_available: 25 + disk_size: 256 diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 304a2dd1979ed8..31bdeba4817116 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -71,6 +71,7 @@ def generate_workflow_file( PyTorchWindowsWorkflow( build_environment="pytorch-win-vs2019-cpu-py3", test_runner_type=WINDOWS_CPU_TEST_RUNNER, + on_pull_request=True ) ] diff --git a/.github/scripts/report_git_status.sh b/.github/scripts/report_git_status.sh index 357bacfecb2471..738fbcfd1451ae 100755 --- a/.github/scripts/report_git_status.sh +++ b/.github/scripts/report_git_status.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -CHANGES=$(git status --porcelain) +CHANGES=$(git status --porcelain "$1") echo "$CHANGES" -git diff +git diff "$1" [ -z "$CHANGES" ] diff --git a/.github/templates/windows_ci_workflow.yml.in b/.github/templates/windows_ci_workflow.yml.in index 5a1c602b40f229..9544b83138e27d 100644 --- a/.github/templates/windows_ci_workflow.yml.in +++ b/.github/templates/windows_ci_workflow.yml.in @@ -31,6 +31,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -73,6 +77,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | diff --git a/.github/workflows/cancel_redundant_workflows.yml b/.github/workflows/cancel_redundant_workflows.yml index a3dcf0d419a064..968ad48c15f7d5 100644 --- a/.github/workflows/cancel_redundant_workflows.yml +++ b/.github/workflows/cancel_redundant_workflows.yml @@ -9,6 +9,7 @@ on: - Linux CI (pytorch-linux-xenial-py3.6-gcc5.4) - Test tools - TorchBench CI (pytorch-linux-py3.7-cu102) + - Windows CI (pytorch-win-vs2019-cpu-py3) - clang-format jobs: cancel: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4bdbecebe64db0..e4c1bf1d330a2b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,27 +35,6 @@ jobs: run: | pip install ruamel.yaml==0.17.4 .github/scripts/lint_native_functions.py - - name: Extract scripts from GitHub Actions workflows - if: always() && steps.requirements.outcome == 'success' - run: | - # For local lints, remove the .extracted_scripts folder if it was already there - rm -rf .extracted_scripts - tools/extract_scripts.py --out=.extracted_scripts - - name: Install ShellCheck - id: install_shellcheck - if: always() - # https://github.com/koalaman/shellcheck/tree/v0.7.2#installing-a-pre-compiled-binary - run: | - set -x - scversion="v0.7.2" - wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv - sudo cp "shellcheck-${scversion}/shellcheck" /usr/bin/ - rm -r "shellcheck-${scversion}" - shellcheck --version - - name: Run ShellCheck - if: always() && steps.install_shellcheck.outcome == 'success' - run: | - tools/run_shellcheck.sh .jenkins/pytorch .extracted_scripts - name: Ensure correct trailing newlines if: always() && steps.requirements.outcome == 'success' run: | @@ -109,7 +88,7 @@ jobs: if: always() && steps.requirements.outcome == 'success' run: | set -eux - python torch/testing/check_kernel_launches.py |& tee "${GITHUB_WORKSPACE}"/cuda_kernel_launch_checks.txt + python torch/testing/_check_kernel_launches.py |& tee "${GITHUB_WORKSPACE}"/cuda_kernel_launch_checks.txt - name: Ensure no direct cub include if: always() run: | @@ -129,7 +108,7 @@ jobs: run: | python2 setup.py | grep "Python 2 has reached end-of-life and is no longer supported by PyTorch." - templates: + shellcheck: runs-on: ubuntu-18.04 steps: - name: Setup Python @@ -137,14 +116,42 @@ jobs: with: python-version: 3.x architecture: x64 + - name: Checkout PyTorch + uses: actions/checkout@v2 + - name: Install requirements + id: requirements + run: | + pip install -r requirements.txt - name: Install Jinja2 - run: pip install Jinja2 + run: | + pip install Jinja2==3.0.1 - name: Checkout PyTorch uses: actions/checkout@v2 - name: Regenerate workflows run: .github/scripts/generate_ci_workflows.py - name: Assert that regenerating the workflows didn't change them - run: .github/scripts/report_git_status.sh + run: .github/scripts/report_git_status.sh .github/workflows + - name: Install ShellCheck + id: install_shellcheck + if: always() + # https://github.com/koalaman/shellcheck/tree/v0.7.2#installing-a-pre-compiled-binary + run: | + set -x + scversion="v0.7.2" + wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv + sudo cp "shellcheck-${scversion}/shellcheck" /usr/bin/ + rm -r "shellcheck-${scversion}" + shellcheck --version + - name: Extract scripts from GitHub Actions workflows + if: always() && steps.install_shellcheck.outcome == 'success' + run: | + # For local lints, remove the .extracted_scripts folder if it was already there + rm -rf .extracted_scripts + tools/extract_scripts.py --out=.extracted_scripts + - name: Run ShellCheck + if: always() && steps.install_shellcheck.outcome == 'success' + run: | + tools/run_shellcheck.sh .extracted_scripts .jenkins/pytorch toc: runs-on: ubuntu-18.04 @@ -222,12 +229,10 @@ jobs: clang-tidy: if: github.event_name == 'pull_request' runs-on: ubuntu-18.04 + container: + # ubuntu18.04-cuda10.2-py3.6-tidy11 + image: ghcr.io/pytorch/cilint-clang-tidy:e2cfc57ce4fa3a257a4b78fdfdc2b065c167b9c5 steps: - - name: Setup Python - uses: actions/setup-python@v2 - with: - python-version: 3.x - architecture: x64 - name: Checkout PyTorch uses: actions/checkout@v2 with: @@ -236,47 +241,32 @@ jobs: env: HEAD_SHA: ${{ github.event.pull_request.head.sha }} run: | + cd "${GITHUB_WORKSPACE}" mkdir clang-tidy-output cd clang-tidy-output echo "$HEAD_SHA" > commit-sha.txt - - name: Install dependencies - run: | - set -eux - # Install CUDA - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin - sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600 - sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub - sudo add-apt-repository "deb http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /" - sudo apt-get update - sudo apt-get --no-install-recommends -y install cuda-toolkit-10-2 - # Install dependencies - pip install pyyaml typing_extensions - wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - sudo apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-11 main" - sudo apt-get update - sudo apt-get install -y clang-tidy-11 - sudo update-alternatives --install /usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-11 1000 - name: Generate build files run: | + cd "${GITHUB_WORKSPACE}" set -eux git remote add upstream https://github.com/pytorch/pytorch git fetch upstream "$GITHUB_BASE_REF" - if [[ ! -d build ]]; then + if [ ! -d build ]; then git submodule update --init --recursive export USE_NCCL=0 export USE_DEPLOY=1 # We really only need compile_commands.json, so no need to build! - time python setup.py --cmake-only build + time python3 setup.py --cmake-only build # Generate ATen files. - time python -m tools.codegen.gen \ + time python3 -m tools.codegen.gen \ -s aten/src/ATen \ -d build/aten/src/ATen # Generate PyTorch files. - time python tools/setup_helpers/generate_code.py \ + time python3 tools/setup_helpers/generate_code.py \ --declarations-path build/aten/src/ATen/Declarations.yaml \ --native-functions-path aten/src/ATen/native/native_functions.yaml \ --nn-path aten/src @@ -286,6 +276,7 @@ jobs: BASE_SHA: ${{ github.event.pull_request.base.sha }} HEAD_SHA: ${{ github.event.pull_request.head.sha }} run: | + cd "${GITHUB_WORKSPACE}" set -eux # Run Clang-Tidy @@ -296,7 +287,7 @@ jobs: # /torch/csrc/generic/*.cpp is excluded because those files aren't actually built. # deploy/interpreter files are excluded due to using macros and other techniquies # that are not easily converted to accepted c++ - python tools/clang_tidy.py \ + python3 tools/clang_tidy.py \ --verbose \ --paths torch/csrc/ \ --diff "$BASE_SHA" \ diff --git a/.github/workflows/pytorch-win-vs2019-cpu-py3.yml b/.github/workflows/pytorch-win-vs2019-cpu-py3.yml index f2ac1d573a89a7..d3166967ed8c4e 100644 --- a/.github/workflows/pytorch-win-vs2019-cpu-py3.yml +++ b/.github/workflows/pytorch-win-vs2019-cpu-py3.yml @@ -4,6 +4,7 @@ name: Windows CI (pytorch-win-vs2019-cpu-py3) on: + pull_request: push: branches: - master @@ -29,6 +30,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -71,6 +76,10 @@ jobs: steps: - name: Checkout PyTorch uses: actions/checkout@v2 + - name: Install 7zip if not already installed + shell: powershell + run: | + choco install 7zip.install -y - name: Install Visual Studio 2019 toolchain shell: powershell run: | diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index bc309b8a54d832..c2be6c96b3e720 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -24,7 +24,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-mobile-code-analysis* ]]; then exec "$(dirname "${BASH_SOURCE[0]}")/build-mobile-code-analysis.sh" "$@" fi -if [[ "$BUILD_ENVIRONMENT" == pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7* ]]; then +if [[ "$BUILD_ENVIRONMENT" == pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7* ]]; then # Enabling DEPLOY build (embedded torch python interpreter, experimental) # only on one config for now, can expand later export USE_DEPLOY=ON diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 48840ad6c1bd58..3bce691f8cf85e 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -452,7 +452,7 @@ elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then # TODO: run some C++ tests echo "no-op at the moment" elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; then - if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1 ]]; then + if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-test1 ]]; then test_torch_deploy fi test_without_numpy diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f308a75f07273..4818b5012b576f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -351,6 +351,7 @@ option(USE_SYSTEM_CPUINFO "Use system-provided cpuinfo." OFF) option(USE_SYSTEM_SLEEF "Use system-provided sleef." OFF) option(USE_SYSTEM_GLOO "Use system-provided gloo." OFF) option(USE_SYSTEM_FP16 "Use system-provided fp16." OFF) +option(USE_SYSTEM_PYBIND11 "Use system-provided PyBind11." OFF) option(USE_SYSTEM_PTHREADPOOL "Use system-provided pthreadpool." OFF) option(USE_SYSTEM_PSIMD "Use system-provided psimd." OFF) option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." OFF) @@ -371,6 +372,7 @@ if(USE_SYSTEM_LIBS) set(USE_SYSTEM_BENCHMARK ON) set(USE_SYSTEM_ONNX ON) set(USE_SYSTEM_XNNPACK ON) + set(USE_SYSTEM_PYBIND11 ON) endif() # Used when building Caffe2 through setup.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c5a363f8a5c0e3..02d4f63a354710 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -910,7 +910,7 @@ tensor([1., 2., 3., 4.], dtype=torch.float64) ``` GDB tries to automatically load `pytorch-gdb` thanks to the -[.gdbinit](.gdbinit) at the root of the pytorch repo. Howevever, auto-loadings is disabled by default, because of security reasons: +[.gdbinit](.gdbinit) at the root of the pytorch repo. However, auto-loadings is disabled by default, because of security reasons: ``` $ gdb diff --git a/Makefile b/Makefile index 59cdae8f9faed2..8d61fd0c543ec4 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,8 @@ # This makefile does nothing but delegating the actual building to cmake. +PYTHON = python3 all: - @mkdir -p build && cd build && cmake .. $(shell python ./scripts/get_python_cmake_flags.py) && $(MAKE) + @mkdir -p build && cd build && cmake .. $(shell $(PYTHON) ./scripts/get_python_cmake_flags.py) && $(MAKE) local: @./scripts/build_local.sh @@ -31,13 +32,32 @@ generate-gha-workflows: .github/scripts/generate_ci_workflows.py $(MAKE) shellcheck-gha +shellcheck: + @$(PYTHON) tools/actions_local_runner.py \ + --file .github/workflows/lint.yml \ + --job 'shellcheck' \ + --step "Regenerate workflows" + @$(PYTHON) tools/actions_local_runner.py \ + --file .github/workflows/lint.yml \ + --job 'shellcheck' \ + --step "Assert that regenerating the workflows didn't change them" + @$(PYTHON) tools/actions_local_runner.py \ + --file .github/workflows/lint.yml \ + --job 'shellcheck' \ + --step 'Extract scripts from GitHub Actions workflows' + @$(PYTHON) tools/actions_local_runner.py \ + $(CHANGED_ONLY) \ + --job 'shellcheck' + setup_lint: - python tools/actions_local_runner.py --file .github/workflows/lint.yml \ - --job 'flake8-py3' --step 'Install dependencies' --no-quiet - python tools/actions_local_runner.py --file .github/workflows/lint.yml \ - --job 'cmakelint' --step 'Install dependencies' --no-quiet - python tools/actions_local_runner.py --file .github/workflows/lint.yml \ - --job 'mypy' --step 'Install dependencies' --no-quiet + $(PYTHON) tools/actions_local_runner.py --file .github/workflows/lint.yml \ + --job 'flake8-py3' --step 'Install dependencies' --no-quiet + $(PYTHON) tools/actions_local_runner.py --file .github/workflows/lint.yml \ + --job 'cmakelint' --step 'Install dependencies' --no-quiet + $(PYTHON) tools/actions_local_runner.py --file .github/workflows/lint.yml \ + --job 'mypy' --step 'Install dependencies' --no-quiet + $(PYTHON) tools/actions_local_runner.py --file .github/workflows/lint.yml \ + --job 'shellcheck' --step 'Install Jinja2' --no-quiet @if [ "$$(uname)" = "Darwin" ]; then \ if [ -z "$$(which brew)" ]; then \ @@ -46,20 +66,15 @@ setup_lint: fi; \ brew install shellcheck; \ else \ - python tools/actions_local_runner.py --file .github/workflows/lint.yml \ - --job 'quick-checks' --step 'Install ShellCheck' --no-quiet; \ + $(PYTHON) tools/actions_local_runner.py --file .github/workflows/lint.yml \ + --job 'shellcheck' --step 'Install ShellCheck' --no-quiet; \ fi pip install jinja2 quick_checks: - @python tools/actions_local_runner.py \ - --file .github/workflows/lint.yml \ - --job 'quick-checks' \ - --step 'Extract scripts from GitHub Actions workflows' - # TODO: This is broken when 'git config submodule.recurse' is 'true' since the # lints will descend into third_party submodules - @python tools/actions_local_runner.py \ + @$(PYTHON) tools/actions_local_runner.py \ --file .github/workflows/lint.yml \ --job 'quick-checks' \ --step 'Ensure no trailing spaces' \ @@ -70,23 +85,20 @@ quick_checks: --step 'Ensure no unqualified noqa' \ --step 'Ensure no unqualified type ignore' \ --step 'Ensure no direct cub include' \ - --step 'Run ShellCheck' \ --step 'Ensure correct trailing newlines' flake8: - @python tools/actions_local_runner.py \ - --file-filter '.py' \ + @$(PYTHON) tools/actions_local_runner.py \ $(CHANGED_ONLY) \ --job 'flake8-py3' mypy: - @python tools/actions_local_runner.py \ - --file-filter '.py' \ + @$(PYTHON) tools/actions_local_runner.py \ $(CHANGED_ONLY) \ --job 'mypy' cmakelint: - @python tools/actions_local_runner.py \ + @$(PYTHON) tools/actions_local_runner.py \ --file .github/workflows/lint.yml \ --job 'cmakelint' \ --step 'Run cmakelint' @@ -96,12 +108,12 @@ clang_tidy: exit 1 toc: - @python tools/actions_local_runner.py \ + @$(PYTHON) tools/actions_local_runner.py \ --file .github/workflows/lint.yml \ --job 'toc' \ --step "Regenerate ToCs and check that they didn't change" -lint: flake8 mypy quick_checks cmakelint generate-gha-workflows +lint: flake8 mypy quick_checks cmakelint shellcheck quicklint: CHANGED_ONLY=--changed-only -quicklint: mypy flake8 mypy quick_checks cmakelint generate-gha-workflows +quicklint: mypy flake8 mypy quick_checks cmakelint shellcheck diff --git a/README.md b/README.md index 0827c8c4e731a1..99a0209ee9bb5d 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to | Linux (ppc64le) GPU |
—
| [![Build Status](https://powerci.osuosl.org/job/pytorch-master-nightly-py3-linux-ppc64le-gpu/badge/icon)](https://powerci.osuosl.org/job/pytorch-master-nightly-py3-linux-ppc64le-gpu/) |
—
| | Linux (aarch64) CPU | [![Build Status](http://openlabtesting.org:15000/badge?project=pytorch%2Fpytorch&job_name=pytorch-arm64-build-daily-master-py36)](https://status.openlabtesting.org/builds/builds?project=pytorch%2Fpytorch&job_name=pytorch-arm64-build-daily-master-py36) | [![Build Status](http://openlabtesting.org:15000/badge?project=pytorch%2Fpytorch&job_name=pytorch-arm64-build-daily-master-py37)](https://status.openlabtesting.org/builds/builds?project=pytorch%2Fpytorch&job_name=pytorch-arm64-build-daily-master-py37) | [![Build Status](http://openlabtesting.org:15000/badge?project=pytorch%2Fpytorch&job_name=pytorch-arm64-build-daily-master-py38)](https://status.openlabtesting.org/builds/builds?project=pytorch%2Fpytorch&job_name=pytorch-arm64-build-daily-master-py38) | -See also the [ci.pytorch.org HUD](https://ezyang.github.io/pytorch-ci-hud/build/pytorch-master). +See also the [ci.pytorch.org HUD](https://hud.pytorch.org/build2/pytorch-master). ## More About PyTorch @@ -270,13 +270,13 @@ Sometimes there are regressions in new versions of Visual Studio, so it's best to use the same Visual Studio Version [16.8.5](https://github.com/pytorch/pytorch/blob/master/.circleci/scripts/vs_install.ps1) as Pytorch CI's. You can use Visual Studio Enterprise, Professional or Community though PyTorch CI uses Visual Studio BuildTools. -If you want to build legacy python code, please refert to [Building on legacy code and CUDA](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md#building-on-legacy-code-and-cuda) +If you want to build legacy python code, please refer to [Building on legacy code and CUDA](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md#building-on-legacy-code-and-cuda) Build with CPU It's fairly easy to build with CPU. -Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the buliding environment by tweaking `CMAKE_INCLUDE_PATH` and `LIB`. The instruction [here](https://github.com/pytorch/pytorch/blob/master/docs/source/notes/windows.rst#building-from-source) is an example for setting up both MKL and Intel OpenMP. Without these configuraions for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used. +Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking `CMAKE_INCLUDE_PATH` and `LIB`. The instruction [here](https://github.com/pytorch/pytorch/blob/master/docs/source/notes/windows.rst#building-from-source) is an example for setting up both MKL and Intel OpenMP. Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used. Build with CUDA diff --git a/RELEASE.md b/RELEASE.md index 5ecc207a8ce67d..44369894d7c8b1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -37,7 +37,7 @@ An example of this would look like: release/1.8 ``` -Please make sure to create branch that pins divergent point of release branch from the main thunk, i.e. `orig/release/{MAJOR}.{MINOR}` +Please make sure to create branch that pins divergent point of release branch from the main branch, i.e. `orig/release/{MAJOR}.{MINOR}` ### Making release branch specific changes These are examples of changes that should be made to release branches so that CI / tooling can function normally on diff --git a/android/README.md b/android/README.md index f4a0b193c07c04..cea8cc0c2ff1a0 100644 --- a/android/README.md +++ b/android/README.md @@ -95,13 +95,12 @@ dependencies { implementation(name:'pytorch_android', ext:'aar') implementation(name:'pytorch_android_torchvision', ext:'aar') ... - implementation 'com.android.support:appcompat-v7:28.0.0' implementation 'com.facebook.soloader:nativeloader:0.8.0' implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3' } ``` We also have to add all transitive dependencies of our aars. -As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L62-L63) on `'com.android.support:appcompat-v7:28.0.0'`, `'com.facebook.soloader:nativeloader:0.8.0'` and 'com.facebook.fbjni:fbjni-java-only:0.0.3', we need to add them. +As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.8.0'` and `'com.facebook.fbjni:fbjni-java-only:0.0.3'`, we need to add them. (In case of using maven dependencies they are added automatically from `pom.xml`). You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly. diff --git a/android/build.gradle b/android/build.gradle index 60cd969bf652d3..e168347a2d99d3 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -12,7 +12,6 @@ allprojects { rulesVersion = "1.2.0" junitVersion = "4.12" - androidSupportAppCompatV7Version = "28.0.0" fbjniJavaOnlyVersion = "0.0.3" soLoaderNativeLoaderVersion = "0.8.0" } diff --git a/android/pytorch_android/build.gradle b/android/pytorch_android/build.gradle index 978389cbd242b0..f9a7559e598f2d 100644 --- a/android/pytorch_android/build.gradle +++ b/android/pytorch_android/build.gradle @@ -74,7 +74,6 @@ android { dependencies { implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion - implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion testImplementation 'junit:junit:' + rootProject.junitVersion diff --git a/android/pytorch_android_torchvision/build.gradle b/android/pytorch_android_torchvision/build.gradle index 77ac8d6fbbdec0..06d8d4db264f67 100644 --- a/android/pytorch_android_torchvision/build.gradle +++ b/android/pytorch_android_torchvision/build.gradle @@ -42,7 +42,6 @@ android { dependencies { implementation project(':pytorch_android') - implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion testImplementation 'junit:junit:' + rootProject.junitVersion diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index afa51a2694afbb..d6d110b55ade9b 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -1144,10 +1144,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { BINARY_POINTWISE(mul); BINARY_POINTWISE(div); { - using Binop = Tensor (*)(const Tensor&, const Tensor&, c10::optional); - using Unop = Tensor (*)(const Tensor&, const Scalar&, c10::optional); - m.impl("div.Tensor_mode", binary_pointwise_batching_rule>); - m.impl("div.Scalar_mode", unwrap_and_call>); + using Binop = Tensor (*)(const Tensor&, const Tensor&, c10::optional); + using Unop = Tensor (*)(const Tensor&, const Scalar&, c10::optional); + m.impl("div.Tensor_mode", binary_pointwise_batching_rule>); + m.impl("div.Scalar_mode", unwrap_and_call>); } // at::pow has three out-of-place overloads diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 804e8ff971b53c..bd94b95087fa7e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -309,6 +309,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE) set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE) set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE}) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") + set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE) + endif() if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER 6.9 AND CMAKE_C_COMPILER_VERSION VERSION_LESS 8) set(GCC_7 True) diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 76a2c38244b068..4b90f59f5adab1 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -8,9 +8,9 @@ MemOverlap has_internal_overlap(const Tensor& tensor) { } MemOverlap has_internal_overlap(TensorImpl* t) { - AT_ASSERT(t->layout() == kStrided); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided); - if (t->is_contiguous()) { + if (t->is_non_overlapping_and_dense()) { return MemOverlap::NO; } @@ -45,7 +45,7 @@ MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) { if (a->numel() == 0 || b->numel() == 0) { return MemOverlapStatus::NO; } - if (!a->is_contiguous() || !b->is_contiguous()) { + if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) { return MemOverlapStatus::TOO_HARD; } if (!a->has_storage() || !b->has_storage()) { diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 122b8ea7548bdb..3da7c8155c8523 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -128,7 +128,7 @@ void launch_no_thread_state(std::function fn); TORCH_API void intraop_launch(std::function func); // Launches intra-op parallel task, returns a future -TORCH_API std::shared_ptr intraop_launch_future( +TORCH_API c10::intrusive_ptr intraop_launch_future( std::function func); // Returns number of intra-op threads used by default diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index 6aa2c64825912d..52da493c7d5759 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -271,10 +271,10 @@ void intraop_launch(std::function func) { #endif // C10_MOBILE } -std::shared_ptr intraop_launch_future( +c10::intrusive_ptr intraop_launch_future( std::function func) { #ifndef C10_MOBILE - auto future = std::make_shared(c10::NoneType::get()); + auto future = c10::make_intrusive(c10::NoneType::get()); if (!in_parallel_region() && get_num_threads() > 1) { _get_intraop_pool().run( [func, future]() { @@ -290,7 +290,7 @@ std::shared_ptr intraop_launch_future( #else // TODO: caffe2::PThreadPool only provides a data-parallel API. // Task parallelism is not currently supported. - auto future = std::make_shared(NoneType::get()); + auto future = c10::make_intrusive(NoneType::get()); func(); future->markCompleted(); return future; diff --git a/aten/src/ATen/ParallelNativeTBB.cpp b/aten/src/ATen/ParallelNativeTBB.cpp index 9def3aaa9fc2c2..7ea69f6a019b33 100644 --- a/aten/src/ATen/ParallelNativeTBB.cpp +++ b/aten/src/ATen/ParallelNativeTBB.cpp @@ -85,9 +85,9 @@ void intraop_launch(std::function func) { } } -std::shared_ptr intraop_launch_future( +c10::intrusive_ptr intraop_launch_future( std::function func) { - auto future = std::make_shared(NoneType::get()); + auto future = c10::make_intrusive(NoneType::get()); if (get_num_threads() > 1) { tg_.run( [func, future]() { diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index 48b826a00c8952..070e8373dac32d 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -101,10 +101,10 @@ void intraop_launch(std::function func) { func(); } -std::shared_ptr intraop_launch_future( +c10::intrusive_ptr intraop_launch_future( std::function func) { func(); - auto future = std::make_shared(NoneType::get()); + auto future = c10::make_intrusive(NoneType::get()); future->markCompleted(); return future; } diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index a416e5e5305186..e2fc89a9db8498 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -29,6 +29,11 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { // because many algorithms proceed by merging two sorted lists (of indices). bool coalesced_ = false; + // compute_numel with integer multiplication overflow check, see gh-57542 + void refresh_numel() { + TensorImpl::safe_refresh_numel(); + } + public: // Public for now... explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta); diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index bb1be1d97e13bd..e52a62ba723318 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -655,7 +655,7 @@ StrideVector TensorIteratorBase::get_strides() const { StrideVector strides; for (int dim = 0; dim < ndim(); dim++) { for (int arg = 0; arg < ntensors(); arg++) { - strides.push_back(operands_[arg].stride_bytes[dim]); + strides.emplace_back(operands_[arg].stride_bytes[dim]); } } return strides; @@ -670,10 +670,15 @@ void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const { strides.push_back(0); } + auto base_ptrs = get_base_ptrs(); if (ndim() <= 1) { - auto ptrs = get_data_ptrs(base_ptrs, { range.begin }); - loop(ptrs.data(), strides.data(), range.size(), 1); + if (range.begin > 0) { + auto ptrs = get_data_ptrs(base_ptrs, {range.begin}); + loop(ptrs.data(), strides.data(), range.size(), 1); + } else { + loop(base_ptrs.data(), strides.data(), range.size(), 1); + } } else { auto counter = DimCounter(shape_, range); while (!counter.is_done()) { @@ -894,13 +899,22 @@ TensorIterator TensorIterator::unary_float_op(Tensor& out, const Tensor& a) { return iter; } +#define NULLARY_OP_CONFIG() \ + TensorIteratorConfig() \ + .set_check_mem_overlap(true) \ + .check_all_same_dtype(false) \ + /* FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 */ \ + .resize_outputs(false) + TensorIterator TensorIterator::nullary_op(Tensor& out) { - return TensorIteratorConfig() - .set_check_mem_overlap(true) - .check_all_same_dtype(false) + return NULLARY_OP_CONFIG() .add_output(out) - // FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 - .resize_outputs(false) + .build(); +} + +TensorIterator TensorIterator::borrowing_nullary_op(Tensor& out) { + return NULLARY_OP_CONFIG() + .add_borrowed_output(out) .build(); } @@ -1003,7 +1017,9 @@ void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config assert_no_internal_overlap(*output); for (int j = num_outputs_; j < ntensors(); j++) { const auto& input = operands_[j].tensor; - assert_no_partial_overlap(*output, *input); + if (input->unsafeGetTensorImpl()!=output->unsafeGetTensorImpl()) { + assert_no_partial_overlap(*output, *input); + } } } } diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index 20eec492b0e305..efb87d7dc96149 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -471,6 +471,7 @@ struct TORCH_API TensorIterator final : public TensorIteratorBase { static TensorIterator unary_op(Tensor& out, const Tensor& a); static TensorIterator unary_float_op(Tensor& out, const Tensor& a); static TensorIterator nullary_op(Tensor& out); + static TensorIterator borrowing_nullary_op(Tensor& out); static TensorIterator reduce_op(Tensor& out, const Tensor& a); static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a); diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index ee5786b4355d01..0e81c969064b35 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -13,19 +13,24 @@ namespace at { namespace autocast { bool is_enabled() { - //return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA) || - // !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU); return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA); } void set_enabled(bool new_enabled) { - //c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled); c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled); } +bool is_cpu_enabled() { + return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU); +} + +void set_cpu_enabled(bool new_enabled) { + c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled); +} + namespace { // Imitate Apex and cache some of the casts to streamline parameter reuse. -// Our heuristic is to cache fp16 casts of fp32 model weights (see cached_cast below). +// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below). // // After discussion with @ezyang, the cache uses the following structure: // The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's @@ -51,6 +56,9 @@ thread_local std::unordered_map cached_casts; // it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region. // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) thread_local int nesting = 0; + +// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU. +thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16; } void clear_cache() { @@ -65,15 +73,28 @@ int decrement_nesting() { return --nesting; } +at::ScalarType get_autocast_cpu_dtype() { + return autocast_cpu_dtype; +} + +void set_autocast_cpu_dtype(at::ScalarType dtype) { + TORCH_CHECK( + dtype == at::kBFloat16, + "Currently, AutocastCPU only support Bfloat16 as the autocast_cpu_dtype"); + autocast_cpu_dtype = dtype; +} + // Overload to catch Tensor args // TODO (possible optimization): // Move cast_cache to an inline function in a header with cached_casts declared as // extern thread_local in the header. -Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) { - if (is_eligible(arg) && (arg.scalar_type() != to_type)) { - // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves). +Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) { + if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) { + // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves). // See cached_casts declaration above for detailed strategy. - bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf() && !arg.is_view()); + bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) && + arg.scalar_type() == at::kFloat && arg.requires_grad() && + arg.is_leaf() && !arg.is_view()); if (can_try_cache) { auto it = cached_casts.find(arg.unsafeGetTensorImpl()); if (it != cached_casts.end()) { @@ -94,7 +115,8 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) { // Policies correspond to op categories that need code-divergent handling. // Wrapper templates below are specialized based on a policy template parameter. enum class CastPolicy : uint8_t { - fp16 = 0, // Cast all inputs to at::kHalf before running the op. + lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before running the op. + // Currently, lower_precision_fp is fp16 for AutocastCUDA, and is defined by user(default bf16) for AutocastCPU. fp32, // Cast all inputs to at::kFloat before running the op. fp32_set_opt_dtype, // Treats functions (like softmax) that // 1. we'd like to run in fp32 and @@ -122,29 +144,29 @@ Interior WrapFunction_ specializations are defined for each CastPolicy. ********************************************************************************************************/ // Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy -template struct WrapFunction_ {}; +template struct WrapFunction_ {}; -// CastPolicy::fp16 -template -struct WrapFunction_> { +// CastPolicy::lower_precision_fp General_DeviceType +template +struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); - return (*F)(cached_cast(at::kHalf, args)...); + c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast(get_lower_precision_fp_from_device_type(device_type), args, device_type)...); } }; -// CastPolicy::fp32 -template -struct WrapFunction_> { +// CastPolicy::fp32 General_DeviceType +template +struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); - return (*F)(cached_cast(at::kFloat, args)...); + c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast(at::kFloat, args, device_type)...); } }; -// CastPolicy::fp32_set_opt_dtype +// CastPolicy::fp32_set_opt_dtype DeviceType::CUDA template -struct WrapFunction_> { +struct WrapFunction_> { static Ret call(Args... args) { c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); if (firstarg_is_eligible(args...)) { @@ -157,9 +179,9 @@ struct WrapFunction_ -struct WrapFunction_> { +struct WrapFunction_> { static Ret call(Args... args) { c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); at::ScalarType out_type = type_from_firstarg(at::kFloat, args...); @@ -167,18 +189,19 @@ struct WrapFunction_ -struct WrapFunction_> { +// CastPolicy::promote General_DeviceType +template +struct WrapFunction_> { static Ret call(Args... args) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); - auto to_type = promote_type(at::kHalf, args...); - return (*F)(cached_cast(to_type, args)...); + c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type)); + auto to_type = promote_type(get_lower_precision_fp_from_device_type(device_type), device_type, args...); + return (*F)(cached_cast(to_type, args, device_type)...); } }; // Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h) template // The actual function we're redispatching to. struct WrapFunction final { using type = WrapFunction_::return_type, @@ -213,14 +237,15 @@ namespace { This section performs load-time registration for autocast wrappers. It's debatable at what level operations should be patched. We'd like casts to be autograd-exposed -and precede autograd history recording, so that for fp16 ops, input tensors are saved for backward -in fp16 rather than fp32. Saving inputs in fp16 can significantly reduce a model's memory footprint. +and precede autograd history recording, so that for lower_precision_fp ops, input tensors are saved for backward +in lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp can significantly reduce +a model's memory footprint. Option 1 (strawman): Patch only at the level of explicit calls into cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are guaranteed to use Tensor Cores, therefore they're the ones that -will benefit most from fp16. Potential pitfall: convolutions (and other ops) are wrapped in several +will benefit most from lower_precision_fp. Potential pitfall: convolutions (and other ops) are wrapped in several layers of at::* calls. If one of those happens to record autograd history, then we've lost the -opportunity to save inputs in fp16. +opportunity to save inputs in lower_precision_fp. Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd history recording can't sneak in ahead of autocast. This mirrors Apex most closely. @@ -242,12 +267,17 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin // (that's why SIGNATURE is repeated in the WrapFunction instantiation) #define KERNEL(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \ m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ - &WrapFunction::type::call); + &WrapFunction::type::call); // Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype) #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \ m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ - &WrapFunction::type::call); + &WrapFunction::type::call); + +// KERNEL_CPU registration for AutocastCPU +#define KERNEL_CPU(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \ + m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ + &WrapFunction::type::call); /***************************************** Explicit registration for out-of-place ops @@ -257,65 +287,65 @@ TORCH_LIBRARY_IMPL(_, Autocast, m) { } TORCH_LIBRARY_IMPL(aten, Autocast, m) { - // fp16 - KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16) - KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16) - KERNEL(ADD_NS(_convolution_nogroup), "_convolution_nogroup", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef), fp16) - KERNEL(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16) - KERNEL(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16) - KERNEL(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16) - KERNEL(ADD_NS(conv_tbc), "conv_tbc", Tensor (const Tensor &, const Tensor &, const Tensor &, int64_t), fp16) - KERNEL(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp16) - KERNEL(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp16) - KERNEL(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp16) - KERNEL(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp16) - KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) - KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) - KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) - KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) - KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16) - KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16) - KERNEL(ADD_NS(prelu), "prelu", Tensor (const Tensor &, const Tensor &), fp16) - KERNEL(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16) - KERNEL(ADD_NS(addmv), "addmv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16) - KERNEL(ADD_NS(addr), "addr", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16) - KERNEL(ADD_NS(matmul), "matmul", Tensor (const Tensor &, const Tensor &), fp16) - KERNEL(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), fp16) - KERNEL(ADD_NS(mv), "mv", Tensor (const Tensor &, const Tensor &), fp16) - KERNEL(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional&), fp16) - KERNEL(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16) - KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16) - KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), fp16) - KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), fp16) - KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), fp16) + // lower_precision_fp + KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(_convolution_nogroup), "_convolution_nogroup", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef), lower_precision_fp) + KERNEL(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL(ADD_NS(conv_tbc), "conv_tbc", Tensor (const Tensor &, const Tensor &, const Tensor &, int64_t), lower_precision_fp) + KERNEL(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), lower_precision_fp) + KERNEL(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), lower_precision_fp) + KERNEL(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), lower_precision_fp) + KERNEL(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), lower_precision_fp) + KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp) + KERNEL(ADD_NS(prelu), "prelu", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL(ADD_NS(addmv), "addmv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL(ADD_NS(addr), "addr", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL(ADD_NS(matmul), "matmul", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL(ADD_NS(mv), "mv", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional&), lower_precision_fp) + KERNEL(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), lower_precision_fp) + KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), lower_precision_fp) // The macro doesn't like these (I think it chokes on commas inside <>) so write them manually m.impl(TORCH_SELECTIVE_NAME("aten::_thnn_fused_lstm_cell"), - TORCH_FN((&WrapFunction (const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), std::tuple (const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), &ADD_NS(_thnn_fused_lstm_cell)>::type::call))); m.impl("_thnn_fused_gru_cell", - TORCH_FN((&WrapFunction (const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), std::tuple (const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), &ADD_NS(_thnn_fused_gru_cell)>::type::call))); m.impl("lstm_cell", - TORCH_FN((&WrapFunction (const Tensor &, TensorList, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), std::tuple (const Tensor &, TensorList, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), &ADD_NS(lstm_cell)>::type::call))); m.impl("gru_cell", - TORCH_FN((&WrapFunction&, const c10::optional&), Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), &ADD_NS(gru_cell)>::type::call))); m.impl("rnn_tanh_cell", // tanh unary op is executed as a cuda math library call. - TORCH_FN((&WrapFunction&, const c10::optional&), Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), &ADD_NS(rnn_tanh_cell)>::type::call))); m.impl("rnn_relu_cell", - TORCH_FN((&WrapFunction&, const c10::optional&), Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional&, const c10::optional&), &ADD_NS(rnn_relu_cell)>::type::call))); @@ -342,7 +372,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional&, const c10::optional&, double, bool), fp32) // The macro doesn't like this one (I think it chokes on commas inside <>) so write it manually m.impl(TORCH_SELECTIVE_NAME("aten::native_layer_norm"), - TORCH_FN((&WrapFunction (const Tensor&, IntArrayRef, const c10::optional&, const c10::optional&, double), std::tuple (const Tensor&, IntArrayRef, const c10::optional&, const c10::optional&, double), &ADD_NS(native_layer_norm)>::type::call))); @@ -372,6 +402,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(pdist), "pdist", Tensor (const Tensor &, double), fp32) KERNEL(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional), fp32) KERNEL(ADD_NS(renorm), "renorm", Tensor (const Tensor &, const Scalar&, int64_t, const Scalar&), fp32) + KERNEL(ADD_NS(grid_sampler), "grid_sampler", Tensor (const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32) // fp32_set_opt_dtype KERNEL(ADD_NS(prod), "prod", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(prod), "prod.dim_int", Tensor (const Tensor &, int64_t, bool, c10::optional), fp32_set_opt_dtype) @@ -418,7 +449,86 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } +TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { + m.fallback(torch::CppFunction::makeFallthrough()); } +TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { + // lower_precision_fp cast policy + KERNEL_CPU(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL_CPU(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL_CPU(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL_CPU(ADD_NS(_log_softmax), "_log_softmax", Tensor (const Tensor &, int64_t, bool), lower_precision_fp) + KERNEL_CPU(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL_CPU(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) + KERNEL_CPU(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL_CPU(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL_CPU(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) + KERNEL_CPU(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional &), lower_precision_fp) + + // fp32 cast policy + KERNEL_CPU(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(batch_norm), "batch_norm", Tensor (const Tensor &, const c10::optional &, const c10::optional &, const c10::optional &, const c10::optional &, bool, double, double, bool), fp32) + KERNEL_CPU(ADD_NS(max_pool2d), "max_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32) + KERNEL_CPU(ADD_NS(adaptive_avg_pool2d), "adaptive_avg_pool2d", Tensor (const Tensor &, IntArrayRef), fp32) + + KERNEL_CPU(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp32) + KERNEL_CPU(ADD_NS(dropout), "dropout", Tensor (const Tensor &, double, bool), fp32) + KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32) + KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional), fp32) + KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(upsample_nearest2d), "upsample_nearest2d", Tensor (const Tensor &, IntArrayRef, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(upsample_nearest2d), "upsample_nearest2d.vec", Tensor (const Tensor &, c10::optional, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(upsample_nearest3d), "upsample_nearest3d", Tensor (const Tensor &, IntArrayRef, c10::optional, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(upsample_nearest3d), "upsample_nearest3d.vec", Tensor (const Tensor &, c10::optional, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(upsample_linear1d), "upsample_linear1d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(upsample_linear1d), "upsample_linear1d.vec", Tensor (const Tensor &, c10::optional, bool, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d.vec", Tensor (const Tensor &, c10::optional, bool, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d.vec", Tensor (const Tensor &, c10::optional, bool, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(binary_cross_entropy), "binary_cross_entropy", Tensor (const Tensor &, const Tensor &, const c10::optional&, int64_t), fp32) + KERNEL_CPU(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional&, const c10::optional&, int64_t), fp32) + KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Scalar", Tensor (const Tensor &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Tensor", Tensor (const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(pow), "pow.Scalar", Tensor (const Scalar&, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32) + KERNEL_CPU(ADD_NS(reflection_pad1d), "reflection_pad1d", Tensor (const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(std), "std", Tensor (const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(std), "std.dim", Tensor (const Tensor &, IntArrayRef, bool, bool), fp32) + KERNEL_CPU(ADD_NS(instance_norm), "instance_norm", Tensor (const Tensor &, const c10::optional&, const c10::optional&, const c10::optional&, const c10::optional&, bool, double, double, bool), fp32) + KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32) + + // promote + KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote) + KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote) + + m.impl(TORCH_SELECTIVE_NAME("aten::topk"), + TORCH_FN((&WrapFunction (const Tensor &, int64_t, int64_t, bool, bool), + std::tuple (const Tensor &, int64_t, int64_t, bool, bool), + &ADD_NS(topk)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::sort"), + TORCH_FN((&WrapFunction (const Tensor &, int64_t, bool), + std::tuple (const Tensor &, int64_t, bool), + &ADD_NS(sort)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue"), + TORCH_FN((&WrapFunction (const Tensor &, int64_t, int64_t, bool), + std::tuple (const Tensor &, int64_t, int64_t, bool), + &ADD_NS(kthvalue)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue.dimname"), + TORCH_FN((&WrapFunction (const Tensor &, int64_t, at::Dimname, bool), + std::tuple (const Tensor &, int64_t, at::Dimname, bool), + &ADD_NS(kthvalue)>::type::call))); +} +} // namespace } // namespace autocast } // namespace at diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 85db1c2e1a45da..e78a9c8b54ca28 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -3,17 +3,49 @@ namespace at { namespace autocast { -namespace { - bool is_autocast_eligible(const Tensor& tensor) { - return (tensor.is_cuda() || tensor.is_xla()) && tensor.is_floating_point(); - } -} // namespace - TORCH_API bool is_enabled(); TORCH_API void set_enabled(bool enabled); TORCH_API void clear_cache(); TORCH_API int increment_nesting(); TORCH_API int decrement_nesting(); +TORCH_API bool is_cpu_enabled(); +TORCH_API void set_cpu_enabled(bool enabled); +TORCH_API at::ScalarType get_autocast_cpu_dtype(); +TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype); + +namespace { + bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) { + return device_type == DeviceType::CUDA + ? (tensor.is_cuda() || tensor.is_xla()) && tensor.is_floating_point() + : (tensor.is_cpu() || tensor.is_mkldnn()) && tensor.is_floating_point(); + } +} // namespace + +inline DispatchKey get_autocast_dispatch_key_from_device_type( + DeviceType device_type) { + switch (device_type) { + case DeviceType::CUDA: + return DispatchKey::Autocast; + case DeviceType::CPU: + return DispatchKey::AutocastCPU; + default: + throw std::runtime_error( + "unknown device type for autocast in get_autocast_dispatch_key_from_device_type"); + } +} + +inline at::ScalarType get_lower_precision_fp_from_device_type( + DeviceType device_type) { + switch (device_type) { + case DeviceType::CUDA: + return at::kHalf; + case DeviceType::CPU: + return get_autocast_cpu_dtype(); + default: + throw std::runtime_error( + "unknown device type for autocast in get_lower_precision_fp_from_device_type"); + } +} /******************************************************************** Logic to extract the promote type from any Tensor or TensorList args. @@ -22,19 +54,24 @@ Logic to extract the promote type from any Tensor or TensorList args. // Overload to catch Tensor args. // If nextArg is floating-point, compare its scalar_type with our // current best guess for the promote type, and update if necessary. -inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) { +inline at::ScalarType prioritize( + at::ScalarType current, + const Tensor& nextArg, + DeviceType device_type=DeviceType::CUDA) { if (current == at::kDouble) { AT_ERROR("promote type is double in at::autocast::prioritize"); return current; } - if (is_autocast_eligible(nextArg)) { + at::ScalarType lower_precision_fp = + get_lower_precision_fp_from_device_type(device_type); + if (is_autocast_eligible(nextArg, device_type)) { auto next = nextArg.scalar_type(); if (next == at::kDouble) { return current; // ignores double tensors } else if (current == at::kFloat || next == at::kFloat) { - return at::kFloat; // prioritizes float over half - } else if (current == at::kHalf && next == at::kHalf) { - return at::kHalf; + return at::kFloat; // prioritizes float over lower_precision_fp + } else if (current == lower_precision_fp && next == lower_precision_fp) { + return lower_precision_fp; } else { AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize"); return current; @@ -46,64 +83,92 @@ inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) // Overload to catch TensorList args (for e.g. cat, stack). // Reuses the overload above to process each Tensor in the list. -inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) { +inline at::ScalarType prioritize( + at::ScalarType current, + const TensorList& list, + DeviceType device_type=DeviceType::CUDA) { for (const auto& tensor : list) { - current = prioritize(current, tensor); + current = prioritize(current, tensor, device_type); } return current; } // Template to catch non-Tensor args (no-op that returns current best guess) template -inline at::ScalarType prioritize(at::ScalarType current, T nextArg) { +inline at::ScalarType prioritize( + at::ScalarType current, + T nextArg, + DeviceType device_type=DeviceType::CUDA) { return current; } // Overload for the tail case. -inline at::ScalarType promote_type(at::ScalarType current) { +inline at::ScalarType promote_type( + at::ScalarType current, + DeviceType device_type) { return current; } -// Unpack args and determine if incoming float16 tensors need to be promoted to float32. +// Unpack args and determine if incoming lower_precision_fp tensors need to be promoted to float32. // Non-Tensor arguments are ignored. template -inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) { - auto new_current = prioritize(current, arg0); - return promote_type(new_current, args...); +inline at::ScalarType promote_type( + at::ScalarType current, + DeviceType device_type, + Arg0 arg0, + Args... args) { + auto new_current = prioritize(current, arg0, device_type); + return promote_type(new_current, device_type, args...); } /**************************************************** Logic to apply cached casting to any Tensor argument. ****************************************************/ -inline bool is_eligible(const Tensor& arg) { - return (arg.defined() && is_autocast_eligible(arg) && (arg.scalar_type() != at::kDouble)); +inline bool is_eligible( + const Tensor& arg, + DeviceType device_type=DeviceType::CUDA) { + return (arg.defined() && + is_autocast_eligible(arg, device_type) && + (arg.scalar_type() != at::kDouble)); } // Overload to catch Tensor args -TORCH_API Tensor cached_cast(at::ScalarType to_type, const Tensor& arg); +TORCH_API Tensor cached_cast( + at::ScalarType to_type, + const Tensor& arg, + DeviceType device_type=DeviceType::CUDA); // Overload to process optional -inline c10::optional cached_cast(at::ScalarType to_type, const c10::optional& arg) { +inline c10::optional cached_cast( + at::ScalarType to_type, + const c10::optional& arg, + DeviceType device_type=DeviceType::CUDA) { if (arg.has_value()) { - return cached_cast(to_type, *arg); + return cached_cast(to_type, *arg, device_type); } else { return c10::nullopt; } } // Overload to process TensorLists -inline std::vector cached_cast(at::ScalarType to_type, const TensorList& arg) { +inline std::vector cached_cast( + at::ScalarType to_type, + const TensorList& arg, + DeviceType device_type=DeviceType::CUDA) { std::vector vec; vec.reserve(arg.size()); for (const auto& t : arg) { - vec.push_back(cached_cast(to_type, t)); + vec.push_back(cached_cast(to_type, t, device_type)); } return vec; } // Template to catch non-Tensor args. template -inline T cached_cast(at::ScalarType to_type, T arg) { +inline T cached_cast( + at::ScalarType to_type, + T arg, + DeviceType device_type=DeviceType::CUDA) { return arg; } diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h index cf04a9d950ed9b..9e743558e5a1b4 100644 --- a/aten/src/ATen/core/Dict_inl.h +++ b/aten/src/ATen/core/Dict_inl.h @@ -41,7 +41,7 @@ inline size_t DictKeyHash::operator()(const IValue& ivalue) const { if (ivalue.isInt()) { return std::hash()(ivalue.toInt()); } else if (ivalue.isString()) { - return std::hash()(ivalue.toStringRef()); + return std::hash()(ivalue.toStringView()); } else if (ivalue.isDouble()) { return std::hash()(ivalue.toDouble()); } else if (ivalue.isComplexDouble()) { diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index af1452ecddd8ad..d2251c06a4f4ed 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -495,6 +495,7 @@ _(aten, miopen_depthwise_convolution_backward_input) \ _(aten, miopen_depthwise_convolution_backward_weight) \ _(aten, miopen_rnn) \ _(aten, miopen_rnn_backward) \ +_(aten, mish) \ _(aten, mkldnn_convolution) \ _(aten, mkldnn_convolution_backward) \ _(aten, mkldnn_convolution_backward_input) \ diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 04286d28c6edd4..bb01559265fbeb 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -91,7 +91,7 @@ namespace impl { int64_t, double, bool, - std::string, + c10::string_view, at::Tensor, at::Scalar, c10::QScheme, @@ -199,7 +199,7 @@ namespace impl { template struct assert_is_valid_input_type::value>> { static_assert(guts::false_t::value, - "You tried to register a kernel with an unsupported input type: const char*. Please use std::string instead."); + "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead."); }; template struct assert_is_valid_input_type, T>::value>> { @@ -287,7 +287,7 @@ namespace impl { template struct assert_is_valid_output_type::value>> { static_assert(guts::false_t::value, - "You tried to register a kernel with an unsupported output type: const char*. Please use std::string instead."); + "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead."); }; template struct assert_is_valid_output_type, T>::value>> { diff --git a/aten/src/ATen/core/custom_class.cpp b/aten/src/ATen/core/custom_class.cpp index c396e810eabe15..8f1a66452576b4 100644 --- a/aten/src/ATen/core/custom_class.cpp +++ b/aten/src/ATen/core/custom_class.cpp @@ -50,5 +50,51 @@ std::vector customClassSchemasForBCCheck() { }); } +namespace detail { +class_base::class_base( + const std::string& namespaceName, + const std::string& className, + std::string doc_string, + const std::type_info& intrusivePtrClassTypeid, + const std::type_info& taggedCapsuleClassTypeid) + : qualClassName("__torch__.torch.classes." + namespaceName + '.' + className), + classTypePtr(at::ClassType::create( + c10::QualifiedName(qualClassName), + std::weak_ptr(), + /*is_module=*/false, + std::move(doc_string))) +{ + detail::checkValidIdent(namespaceName, "Namespace name"); + detail::checkValidIdent(className, "Class name"); + classTypePtr->addAttribute("capsule", at::CapsuleType::get()); + c10::getCustomClassTypeMap().insert( + {std::type_index(intrusivePtrClassTypeid), classTypePtr}); + c10::getCustomClassTypeMap().insert( + {std::type_index(taggedCapsuleClassTypeid), classTypePtr}); + registerCustomClass(classTypePtr); +} + +c10::FunctionSchema class_base::withNewArguments( + const c10::FunctionSchema& schema, + std::initializer_list default_args) { + const auto& old_args = schema.arguments(); + std::vector new_args; + new_args.reserve(old_args.size()); + + new_args.emplace_back(old_args[0]); + // Skip self. + size_t argIdx = 1; + for (const auto& default_arg : default_args) { + auto& old_arg = old_args[argIdx++]; + new_args.emplace_back( + default_arg.name_, + old_arg.type(), + old_arg.N(), + default_arg.value_); + } + return schema.cloneWithArguments(std::move(new_args)); +} + +} // namespace detail } // namespace torch diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 5d963098117b6f..c3156628007b62 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -39,6 +39,16 @@ TORCH_API c10::intrusive_ptr ConstantString::create( return c10::make_intrusive(std::move(str_)); } +TORCH_API c10::intrusive_ptr ConstantString::create( + c10::string_view str_) { + return c10::make_intrusive(std::string(str_)); +} + +TORCH_API c10::intrusive_ptr ConstantString::create( + const char* str_) { + return c10::make_intrusive(std::string(str_)); +} + bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) { return lhs.elements_.size() == rhs.elements_.size() && // see [container equality] @@ -616,7 +626,7 @@ IValueComparator getLessThanComparator(const IValue& v) { if (v.isString()) { return [](const IValue& a, const IValue& b) { - return a.toString()->string() < b.toString()->string(); + return a.toStringRef() < b.toStringRef(); }; } diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 57ac050a96e966..ec699fa44af033 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -555,6 +555,7 @@ struct TORCH_API IValue final { IValue(c10::intrusive_ptr v); IValue(std::string v); IValue(const char* v) : IValue(std::string(v)) {} + IValue(c10::string_view v) : IValue(std::string(v)) {}; bool isString() const { return Tag::String == tag; } @@ -563,6 +564,7 @@ struct TORCH_API IValue final { const std::string& toStringRef() const; c10::optional> toOptionalStringRef() const; + c10::string_view toStringView() const; // DoubleList bool isDoubleList() const; diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index ed09e4b8d2667c..10d0a7a985577d 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -227,10 +227,18 @@ struct TORCH_API ConstantString final : c10::intrusive_ptr_target { public: ConstantString(std::string str) : str_(std::move(str)) {} + ConstantString(c10::string_view str) : str_(std::string(str)) {} static c10::intrusive_ptr create(std::string str_); + static c10::intrusive_ptr create(c10::string_view str_); + static c10::intrusive_ptr create(const char* str_); + const std::string& string() const { return str_; } + c10::string_view string_view() const { + return str_; + } + operator const std::string&() const { return string(); } @@ -306,12 +314,22 @@ struct EnumHolder; // Future struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { - public: + private: + // Keep this private in order to force users to go through make_intrusive and + // thus prevent creating a Future that's not held by an intrusive_ptr. explicit Future(TypePtr type, std::vector devices={}) : type_(std::move(type)), impl_(getTypeOfDevices(devices)), devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {} + friend c10::intrusive_ptr; + + public: + Future(const Future&) = delete; + Future(Future&&) = delete; + Future& operator=(const Future&) = delete; + Future& operator=(Future&&) = delete; + struct TORCH_API FutureError final : public std::exception { explicit FutureError(std::string&& error_msg_) : error_msg(std::move(error_msg_)) {} @@ -477,7 +495,13 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { * If the future has already completed, * this function will execute the callback immediately. */ - void addCallback(std::function callback) { + template + void addCallback(T callback) { +#if __cpp_lib_is_invocable >= 201703 + static_assert( + std::is_invocable_r::value, + "The callback must have signature void(Future&)"); +#endif std::unique_lock lock(mutex_); if (completed()) { lock.unlock(); @@ -492,12 +516,16 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { * value of the callback. This is necessary when the callback provider needs * to know for sure when the callback has finished. */ - c10::intrusive_ptr then( - std::function callback, - TypePtr type) { + template + c10::intrusive_ptr then(T callback, TypePtr type) { +#if __cpp_lib_is_invocable >= 201703 + static_assert( + std::is_invocable_r::value, + "The callback must have signature IValue(Future&)"); +#endif auto childFut = createInstance(std::move(type)); addCallback( - [childFut, cb = std::move(callback)](Future& parentFut) { + [childFut, cb = std::move(callback)](Future& parentFut) mutable { try { childFut->markCompleted(cb(parentFut)); } catch (std::exception&) { @@ -507,6 +535,36 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { return childFut; } + template + c10::intrusive_ptr thenAsync(T callback, TypePtr type) { +#if __cpp_lib_is_invocable >= 201703 + static_assert( + std::is_invocable_r, T, Future&>::value, + "The callback must have signature c10::intrusive_ptr(Future&)"); +#endif + auto childFut = createInstance(std::move(type)); + addCallback( + [childFut, cb = std::move(callback)](Future& parentFut) mutable { + c10::intrusive_ptr intermediateFut; + try { + intermediateFut = cb(parentFut); + } catch (std::exception&) { + childFut->setError(std::current_exception()); + return; + } + intermediateFut->addCallback( + [childFut = std::move(childFut)](Future& intermediateFut) { + if (intermediateFut.hasError()) { + childFut->setError(intermediateFut.exception_ptr()); + } else { + childFut->markCompleted( + intermediateFut.value(), intermediateFut.dataPtrs()); + } + }); + }); + return childFut; + } + // Tries to retrieve the error message from std::exception_ptr. std::string tryRetrieveErrorMessage() const { TORCH_CHECK(hasError(), "No error present on the future."); @@ -558,7 +616,14 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { // how/when that happens) as it will ensure that the proper "environment" is // set up before running the callback, as in, it will set up the CUDA streams, // synchronize them with the value, and so on (if needed). - void invokeCallback(std::function callback) { + template + void invokeCallback(T callback) { +#if __cpp_lib_is_invocable >= 201703 + static_assert( + std::is_invocable_r::value, + "The callback must have signature void(Future&)"); +#endif + c10::OptionalDeviceGuard deviceGuard(currentDevice_); std::vector streams; @@ -989,6 +1054,7 @@ DEFINE_TO(c10::impl::GenericList, toList) DEFINE_TO(c10::impl::GenericDict, toGenericDict) DEFINE_TO(c10::intrusive_ptr, toTuple) DEFINE_TO(std::string, toStringRef) +DEFINE_TO(c10::string_view, toStringView) DEFINE_TO(c10::intrusive_ptr, toFuture) DEFINE_TO(c10::intrusive_ptr, toRRef) DEFINE_TO(c10::intrusive_ptr, toQuantizer) @@ -1198,6 +1264,14 @@ inline T IValue::to() && { return generic_to(std::move(*this), _fake_type{}); } +template <> +inline c10::optional IValue::to() && { + // In the default implementation, the IValue is destroyed with std::move. + // But if the unboxed type is optional we cannot destroy + // the IValue. + return generic_to(*this, _fake_type>{}); +} + template inline typename c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& { return generic_to(*this, _fake_type{}); @@ -1495,6 +1569,16 @@ inline c10::optional> IValue:: ->string()); } +inline c10::string_view IValue::toStringView() const { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringView on null intrusive_ptr IValue"); + return static_cast( + payload.u.as_intrusive_ptr) + ->string_view(); +} + inline PyObject* IValue::toPyObject() const { return toPyObjectHolder()->getPyObject(); } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index d0a4f6295e98f4..d20231d19530ca 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -244,6 +244,10 @@ struct TORCH_API ShapeSymbol { return value_; }; + int64_t value() const { + return value_; + }; + static ShapeSymbol newSymbol() { return fromStaticSize(-static_cast(++num_symbols)); }; @@ -300,6 +304,8 @@ struct TORCH_API SymbolicShape { dims_ = shape_symbols; } + void dump() const; + SymbolicShape(std::vector dims) : dims_(std::move(dims)) {} SymbolicShape(c10::IntArrayRef dims) { @@ -318,6 +324,13 @@ struct TORCH_API SymbolicShape { return (*dims_).at(i); } + ShapeSymbol at(size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + // Returns rank or nullopt in case of unranked shape. c10::optional rank() const { if(!dims_) { @@ -946,6 +959,9 @@ struct TORCH_API TupleType : public NamedType { c10::nullopt, nullptr)); // NOLINT(modernize-make-shared) } + static TupleTypePtr create() { + return create({}); + } at::ArrayRef elements() const { return elements_; @@ -1482,6 +1498,7 @@ enum class TypeVerbosity { Type, TypeAndStride, Full, + Symbolic, Default = Full, }; @@ -1677,6 +1694,12 @@ struct getTypePtr_ final { } }; template <> +struct getTypePtr_ final { + static TypePtr call() { + return StringType::get(); + } +}; +template <> struct getTypePtr_ final { static TypePtr call() { return StringType::get(); diff --git a/aten/src/ATen/core/op_registration/README.md b/aten/src/ATen/core/op_registration/README.md index ffbc5f3b3da67e..edd9f911cd0e1a 100644 --- a/aten/src/ATen/core/op_registration/README.md +++ b/aten/src/ATen/core/op_registration/README.md @@ -174,13 +174,13 @@ The kernel function can take any of the following types as inputs or outputs: * `double` (note: `float` is not supported) * `int64_t` (note: other integer types like `int`, `uint64_t`, `int32_t`, `...` are not supported) * `bool` -* `std::string` +* `c10::string_view` * `at::Scalar` (this is a type that can hold either an integer or a floating point value) * `at::optional` with T being any type from the list above The kernel function can take and return list inputs by using `torch::List`. `T` must be one of the supported types from above excluding `at::Scalar`. -The kernel function can take and return dicts by using `torch::Dict`. `Key` must be `int64_t`, `std::string`, `double` or `bool`, and `Value` must be from the list of supported types above excluding `at::Scalar`. +The kernel function can take and return dicts by using `torch::Dict`. `Key` must be `int64_t`, `c10::string_view`, `double` or `bool`, and `Value` must be from the list of supported types above excluding `at::Scalar`. When taken as input, any of these types can be taken by value (i.e. `Tensor`) or by const-reference (i.e. `const Tensor&`). We recommend taking all arguments by value, even Tensors. They will be moved in, so there is no performance overhead. diff --git a/aten/src/ATen/core/op_registration/op_allowlist.h b/aten/src/ATen/core/op_registration/op_allowlist.h index 629a307823e3b8..9f96e9d0c23824 100644 --- a/aten/src/ATen/core/op_registration/op_allowlist.h +++ b/aten/src/ATen/core/op_registration/op_allowlist.h @@ -60,6 +60,10 @@ constexpr bool op_allowlist_contains(string_view allowlist, string_view item) { // and should be registered constexpr bool op_allowlist_check(string_view op_name) { assert(op_name.find("::") != string_view::npos); + // Use assert() instead of throw() due to a gcc bug. See: + // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function + // https://github.com/fmtlib/fmt/issues/682 + assert(op_name.find("(") == string_view::npos); #if !defined(TORCH_OPERATOR_WHITELIST) // If the TORCH_OPERATOR_WHITELIST parameter is not defined, // all ops are to be registered diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index ac82a9c684609d..50a39d46e20fbf 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -33,12 +33,15 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { out << "("; size_t i = 0; + bool symbolic = type_verbosity() == TypeVerbosity::Symbolic; for (i = 0; i < *ndim; ++i) { if (i > 0) { out << ", "; } if (auto s = value->sizes()[i]) { out << *s; + } else if (symbolic) { + out << value->symbolic_sizes().at(i); } else { out << "*"; } @@ -710,7 +713,11 @@ std::ostream& operator<<( } std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) { - os << "SS(" << s.value_ << ')'; + if (s.value_ >= 0) { + os << s.value_; + } else { + os << "SS(" << s.value_ << ')'; + } return os; } @@ -1835,6 +1842,10 @@ SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const { return SymbolicShape(std::move(dims)); } +void SymbolicShape::dump() const { + std::cout << *this << "\n"; +} + bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { return rhs->kind() == TypeKind::AnyType || rhs->kind() == TypeKind::AnyEnumType || *this == *rhs; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h index 18165a073340a2..97822a83de686a 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -141,7 +141,8 @@ class Vectorized { vec_sel(a._vec0, b._vec0, mask._vecb0), vec_sel(a._vec1, b._vec1, mask._vecb1)}; } - static Vectorized arange(double base = 0., double step = 1.) { + template + static Vectorized arange(double base = 0., step_t step = static_cast(1)) { return Vectorized(base, base + step, base + 2 * step, base + 3 * step); } diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index 3fbaee69b8ba91..47c8931a2c5b41 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -135,7 +135,8 @@ class Vectorized { vec_sel(a._vec1, b._vec1, mask._vecb1)}; } - static Vectorized arange(float base = 0.f, float step = 1.f) { + template + static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { return Vectorized( base, base + step, diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index 420a0fe23da2b0..5d50ac3d6e8ba8 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -201,7 +201,8 @@ class Vectorized { vec_sel(a._vec1, b._vec1, mask._vecb1)}; } - static Vectorized arange(int16_t base = 0, int16_t step = 1) { + template + static Vectorized arange(int16_t base = 0, step_t step = static_cast(1)) { return Vectorized( base, base + step, diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index ade4be2fb99833..370bc89b240580 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -154,7 +154,8 @@ class Vectorized { vec_sel(a._vec1, b._vec1, mask._vecb1)}; } - static Vectorized arange(int32_t base = 0.f, int32_t step = 1.f) { + template + static Vectorized arange(int32_t base = 0.f, step_t step = static_cast(1)) { return Vectorized( base, base + step, diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index 395997d5caa2eb..945319fc0dc108 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -116,7 +116,8 @@ class Vectorized { vec_sel(a._vec0, b._vec0, mask._vecb0), vec_sel(a._vec1, b._vec1, mask._vecb1)}; } - static Vectorized arange(int64_t base = 0., int64_t step = 1.) { + template + static Vectorized arange(int64_t base = 0., step_t step = static_cast(1)) { return Vectorized(base, base + step, base + 2 * step, base + 3 * step); } diff --git a/aten/src/ATen/detail/CUDAHooksInterface.cpp b/aten/src/ATen/detail/CUDAHooksInterface.cpp index 3ada97350374a4..2a4dcc23aebd12 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.cpp +++ b/aten/src/ATen/detail/CUDAHooksInterface.cpp @@ -35,6 +35,7 @@ const CUDAHooksInterface& getCUDAHooks() { // for an example where we relax this restriction (but if you try to avoid // needing a lock, be careful; it doesn't look like Registry.h is thread // safe...) +#if !defined C10_MOBILE static std::once_flag once; std::call_once(once, [] { cuda_hooks = CUDAHooksRegistry()->Create("CUDAHooks", CUDAHooksArgs{}).release(); @@ -42,6 +43,11 @@ const CUDAHooksInterface& getCUDAHooks() { cuda_hooks = new CUDAHooksInterface(); } }); +#else + if (cuda_hooks == nullptr) { + cuda_hooks = new CUDAHooksInterface(); + } +#endif return *cuda_hooks; } } // namespace detail diff --git a/aten/src/ATen/detail/HIPHooksInterface.cpp b/aten/src/ATen/detail/HIPHooksInterface.cpp index 02a67667d661af..3ef47b3f16b1bf 100644 --- a/aten/src/ATen/detail/HIPHooksInterface.cpp +++ b/aten/src/ATen/detail/HIPHooksInterface.cpp @@ -12,6 +12,7 @@ namespace detail { // See getCUDAHooks for some more commentary const HIPHooksInterface& getHIPHooks() { static std::unique_ptr hip_hooks; +#if !defined C10_MOBILE static std::once_flag once; std::call_once(once, [] { hip_hooks = HIPHooksRegistry()->Create("HIPHooks", HIPHooksArgs{}); @@ -21,6 +22,13 @@ const HIPHooksInterface& getHIPHooks() { std::unique_ptr(new HIPHooksInterface()); } }); +#else + if (hip_hooks == nullptr) { + hip_hooks = + // NOLINTNEXTLINE(modernize-make-unique) + std::unique_ptr(new HIPHooksInterface()); + } +#endif return *hip_hooks; } } // namespace detail diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index 266f55d53f3d5c..eeb563d34ff732 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -21,9 +21,9 @@ TORCH_META_FUNC(threshold)(const Tensor& self, const Scalar& threshold, const Sc const Tensor& result = maybe_get_output(); build(TensorIteratorConfig() .set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay - .add_output(result) - .add_input(self) - .add_input(self) // other + .add_borrowed_output(result) + .add_borrowed_input(self) + .add_borrowed_input(self) // other .allow_cpu_scalars(true) .promote_inputs_to_common_dtype(true) .cast_common_dtype_to_outputs(true) @@ -35,9 +35,9 @@ TORCH_META_FUNC(threshold_backward)(const Tensor& grad, const Tensor& self, cons const Tensor& gradInput = maybe_get_output(); build(TensorIteratorConfig() .set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay - .add_output(gradInput) - .add_input(self) - .add_input(grad) // other + .add_borrowed_output(gradInput) + .add_borrowed_input(self) + .add_borrowed_input(grad) // other .allow_cpu_scalars(true) .promote_inputs_to_common_dtype(true) .cast_common_dtype_to_outputs(true) @@ -50,26 +50,83 @@ TORCH_META_FUNC(elu) ( build_unary_op(maybe_get_output(), self); } +TORCH_META_FUNC(elu_backward) ( + const Tensor& grad_output, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result, + const Tensor& self_or_result +) { + TORCH_CHECK( + !is_result || alpha.to() >= 0.0, + "In-place elu backward calculation is triggered with a negative slope which is not supported. " + "This is caused by calling in-place forward function with a negative slope, " + "please call out-of-place version instead."); + + build_borrowing_binary_op(maybe_get_output(), grad_output, self_or_result); +} + TORCH_META_FUNC(silu) (const Tensor& self) { build_unary_op(maybe_get_output(), self); } +TORCH_META_FUNC(mish) (const Tensor& self) { + build_unary_op(maybe_get_output(), self); +} + TORCH_META_FUNC(softplus) ( const Tensor& self, const Scalar& beta, const Scalar& threshold ) { build_unary_op(maybe_get_output(), self); } +TORCH_META_FUNC(softplus_backward) ( + const Tensor& grad_output, + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + const Tensor& output +) { + build_borrowing_binary_op(maybe_get_output(), grad_output, self); +} + TORCH_META_FUNC(leaky_relu) ( const Tensor& self, const Scalar& negval ) { build_unary_op(maybe_get_output(), self); } +// Note: leakyReLu backward calculation doesn't support in-place call with negative slope. +// The reason is that for in-place forward call, the forward result will be saved into autograd +// node instead of the input itself, when calculating backward gradient, there is no way to know +// whether the original input for current node is positive or not if the input slope is +// negative. eg. forward is 2, slope is -0.2, the original input for this node could be +// either 2, or -10, so no way to get a correct backward gradient in this case. +TORCH_META_FUNC(leaky_relu_backward) ( + const Tensor& grad_output, + const Tensor& self_or_result, + const Scalar& negval, + bool is_result +) { + TORCH_CHECK( + !is_result || negval.to() >= 0.0, + "In-place leakyReLu backward calculation is triggered with a negative slope which is not supported. " + "This is caused by calling in-place forward function with a negative slope, " + "please call out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you do " + "require supporting in-place leakRelu backward calculation with negative slope"); + + build_borrowing_binary_op(maybe_get_output(), self_or_result, grad_output); +} + TORCH_META_FUNC(hardsigmoid) (const Tensor& self) { build_unary_op(maybe_get_output(), self); } +TORCH_META_FUNC(hardsigmoid_backward) (const Tensor& grad_output, const Tensor& self) { + build_borrowing_binary_op(maybe_get_output(), grad_output, self); +} + static inline void softshrink_check(const Scalar& lambd) { double lamb = lambd.to(); TORCH_CHECK(lamb >= 0, "lambda must be greater or equal to 0, but found to be ", lamb, "."); @@ -127,6 +184,10 @@ DEFINE_DISPATCH(leaky_relu_backward_stub); DEFINE_DISPATCH(silu_stub); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(silu_backward_stub); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(mish_stub); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(mish_backward_stub); TORCH_IMPL_FUNC(elu_out) ( const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result @@ -134,30 +195,75 @@ TORCH_IMPL_FUNC(elu_out) ( elu_stub(device_type(), *this, alpha, scale, input_scale); } +TORCH_IMPL_FUNC(elu_backward_out) ( + const Tensor& grad_output, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result, + const Tensor& self_or_result, + const Tensor& grad_input +) { + elu_backward_stub(device_type(), *this, alpha, scale, input_scale, is_result); +} + TORCH_IMPL_FUNC(silu_out) ( const Tensor& self, const Tensor& result ) { silu_stub(device_type(), *this); } +TORCH_IMPL_FUNC(mish_out) ( + const Tensor& self, const Tensor& result +) { + mish_stub(device_type(), *this); +} + TORCH_IMPL_FUNC(softplus_out) ( const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result ) { softplus_stub(device_type(), *this, beta, threshold); } +TORCH_IMPL_FUNC(softplus_backward_out) ( + const Tensor& grad_output, + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + const Tensor& output, + const Tensor& grad_input +) { + softplus_backward_stub(device_type(), *this, beta, threshold); +} + TORCH_IMPL_FUNC(leaky_relu_out) ( const Tensor& self, const Scalar& negval, const Tensor& result ) { leaky_relu_stub(device_type(), *this, negval); } +TORCH_IMPL_FUNC(leaky_relu_backward_out) ( + const Tensor& grad_output, + const Tensor& self_or_result, + const Scalar& negval, + bool is_result, + const Tensor& grad_input +) { + leaky_relu_backward_stub(device_type(), *this, negval); +} + TORCH_IMPL_FUNC(hardsigmoid_out) ( const Tensor& self, const Tensor& result ) { hardsigmoid_stub(device_type(), *this); } +TORCH_IMPL_FUNC(hardsigmoid_backward_out) ( + const Tensor& grad_output, const Tensor& self, const Tensor& grad_input +) { + hardsigmoid_backward_stub(device_type(), *this); +} + TORCH_IMPL_FUNC(softshrink_out) ( const Tensor & self, const Scalar& lambd, const Tensor& result ) { @@ -177,44 +283,18 @@ Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) { } Tensor& hardtanh_backward_out(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max, Tensor& grad_input) { - auto iter = TensorIterator::binary_op(grad_input, grad_output, self); + auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self); hardtanh_backward_stub(iter.device_type(), iter, min, max); return grad_input; } Tensor hardtanh_backward(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max) { Tensor result; - auto iter = TensorIterator::binary_op(result, grad_output, self); + auto iter = TensorIterator::borrowing_binary_op(result, grad_output, self); hardtanh_backward_stub(iter.device_type(), iter, min, max); return iter.output(); } -Tensor hardsigmoid_backward(const Tensor& grad_output, const Tensor& self) { - Tensor result; - auto iter = TensorIterator::binary_op(result, grad_output, self); - hardsigmoid_backward_stub(iter.device_type(), iter); - return iter.output(); -} - -Tensor elu_backward( - const Tensor& grad_output, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - bool is_result, - const Tensor& self_or_result) { - TORCH_CHECK( - !is_result || alpha.to() >= 0.0, - "In-place elu backward calculation is triggered with a negative slope which is not supported. " - "This is caused by calling in-place forward function with a negative slope, " - "please call out-of-place version instead."); - - Tensor result; - auto iter = TensorIterator::binary_op(result, grad_output, self_or_result); - elu_backward_stub(iter.device_type(), iter, alpha, scale, input_scale, is_result); - return iter.output(); -} - Tensor hardswish(const Tensor& self) { #if defined(C10_MOBILE) && defined(USE_XNNPACK) if (xnnpack::use_hardswish(self)) { @@ -247,7 +327,7 @@ Tensor& hardswish_(Tensor& self) { Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) { Tensor grad_input; - auto iter = TensorIterator::binary_op(grad_input, grad_output, self); + auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, self); hardswish_backward_stub(iter.device_type(), iter); return iter.output(); } @@ -294,7 +374,7 @@ Tensor silu_backward( const Tensor& grad_output, const Tensor& input) { Tensor grad_input = at::empty({0}, input.options()); - auto iter = TensorIterator::binary_op(grad_input, grad_output, input); + auto iter = TensorIterator::borrowing_binary_op(grad_input, grad_output, input); silu_backward_stub(iter.device_type(), iter); return grad_input; } @@ -306,6 +386,23 @@ Tensor math_silu_backward( return grad_output * (input_sigmoid * (1 + input * (1 - input_sigmoid))); } +Tensor mish_backward( + const Tensor& grad_output, + const Tensor& input) { + Tensor grad_input = at::empty({0}, input.options()); + auto iter = TensorIterator::binary_op(grad_input, grad_output, input); + mish_backward_stub(iter.device_type(), iter); + return grad_input; +} + +Tensor math_mish_backward( + const Tensor& grad_output, + const Tensor& input) { + auto input_tanh_softplus = at::tanh(at::softplus(input)); + auto input_sigmoid = at::sigmoid(input); + return grad_output * (input_tanh_softplus + (input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus))); +} + template inline void _rrelu_with_noise_train( Tensor& output, @@ -409,29 +506,6 @@ Tensor & rrelu_(Tensor & self, const Scalar& lower, const Scalar& upper, bool tr return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, generator); } -Tensor & softplus_backward_out(const Tensor& grad_output, - const Tensor& self, - const Scalar& beta, - const Scalar& threshold, - const Tensor& output, - Tensor& grad_input) { - auto iter = TensorIterator::binary_op(grad_input, grad_output, self); - softplus_backward_stub(iter.device_type(), iter, beta, threshold); - return grad_input; -} - -Tensor softplus_backward( - const Tensor& grad_output, - const Tensor& self, - const Scalar& beta, - const Scalar& threshold, - const Tensor& output) { - Tensor grad_input; - auto iter = TensorIterator::binary_op(grad_input, grad_output, self); - softplus_backward_stub(iter.device_type(), iter, beta, threshold); - return iter.output(); -} - TORCH_IMPL_FUNC(threshold_out)(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) { threshold_stub(device_type(), *this, threshold, value); } @@ -703,21 +777,21 @@ Tensor hardshrink(const Tensor & self, const Scalar& lambd) { Tensor hardshrink_backward(const Tensor & grad, const Tensor & self, const Scalar& lambd) { auto out_tensor = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto iter = TensorIterator::binary_op(out_tensor, grad, self); + auto iter = TensorIterator::borrowing_binary_op(out_tensor, grad, self); shrink_backward_stub(iter.device_type(), iter, lambd); return out_tensor; } Tensor& softshrink_backward_out(const Tensor & grad, const Tensor & self, const Scalar& lambd, Tensor& grad_input) { - auto iter = TensorIterator::binary_op(grad_input, grad, self); + auto iter = TensorIterator::borrowing_binary_op(grad_input, grad, self); shrink_backward_stub(iter.device_type(), iter, lambd); return grad_input; } Tensor softshrink_backward(const Tensor & grad, const Tensor & self, const Scalar& lambd) { Tensor result; - auto iter = TensorIterator::binary_op(result, grad, self); + auto iter = TensorIterator::borrowing_binary_op(result, grad, self); shrink_backward_stub(iter.device_type(), iter, lambd); return iter.output(); } @@ -743,7 +817,7 @@ Tensor gelu_backward_cpu(const Tensor& grad, const Tensor& self) { c10::nullopt /* device */, c10::nullopt /* pin_memory */, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto it = TensorIterator::binary_op(dX, grad, self); + auto it = TensorIterator::borrowing_binary_op(dX, grad, self); GeluBackwardKernel(kCPU, it); return dX; } @@ -757,30 +831,6 @@ Tensor infinitely_differentiable_gelu_backward( return cdf.addcmul_(self, pdf, kAlpha).mul_(grad); } -// Note: leakyReLu backward calculation doesn't support in-place call with negative slope. -// The reason is that for in-place forward call, the forward result will be saved into autograd -// node instead of the input itself, when calculating backward gradient, there is no way to know -// whether the original input for current node is positive or not if the input slope is -// negative. eg. forward is 2, slope is -0.2, the original input for this node could be -// either 2, or -10, so no way to get a correct backward gradient in this case. -Tensor leaky_relu_backward( - const Tensor& grad_output, - const Tensor& self_or_result, - const Scalar& negval, - bool is_result) { - TORCH_CHECK( - !is_result || negval.to() >= 0.0, - "In-place leakyReLu backward calculation is triggered with a negative slope which is not supported. " - "This is caused by calling in-place forward function with a negative slope, " - "please call out-of-place version instead. File an issue at https://github.com/pytorch/pytorch if you do " - "require supporting in-place leakRelu backward calculation with negative slope"); - - Tensor result; - auto iter = TensorIterator::binary_op(result, self_or_result, grad_output); - leaky_relu_backward_stub(iter.device_type(), iter, negval); - return iter.output(); -} - std::tuple log_sigmoid_forward_cpu(const Tensor& input) { // FIXME: do these actually need to be zeros_like or can they be empty_like? auto result = at::zeros_like(input, at::MemoryFormat::Contiguous); @@ -813,10 +863,10 @@ Tensor log_sigmoid(const Tensor & self) { Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) { Tensor grad_input; auto iter = at::TensorIteratorConfig() - .add_output(grad_input) - .add_input(input) - .add_input(buffer) - .add_input(grad_output) + .add_borrowed_output(grad_input) + .add_borrowed_input(input) + .add_borrowed_input(buffer) + .add_borrowed_input(grad_output) .build(); log_sigmoid_backward_cpu_stub(kCPU, iter); return iter.output(); @@ -827,10 +877,10 @@ Tensor& log_sigmoid_backward_out_cpu(const Tensor& grad_output, const Tensor& buffer, Tensor& grad_input) { auto iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(input) - .add_input(buffer) - .add_input(grad_output) + .add_borrowed_output(grad_input) + .add_borrowed_input(input) + .add_borrowed_input(buffer) + .add_borrowed_input(grad_output) .build(); log_sigmoid_backward_cpu_stub(kCPU, iter); return grad_input; diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h index 12072681abe9d2..96a94314adacbe 100644 --- a/aten/src/ATen/native/Activation.h +++ b/aten/src/ATen/native/Activation.h @@ -15,20 +15,20 @@ using structured_activation_fn = void (*)(TensorIteratorBase&); using activation_fn = void (*)(TensorIterator&); using activation_backward_fn = void (*)(TensorIterator&); using softplus_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&); -using softplus_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&); +using softplus_backward_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&); using threshold_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&); using hardtanh_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&); using hardsigmoid_fn = void(*)(TensorIteratorBase&); -using hardsigmoid_backward_fn = void(*)(TensorIterator&); +using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&); using hardswish_fn = void(*)(TensorIterator&); using hardswish_backward_fn = void(*)(TensorIterator&); using shrink_fn = void (*)(TensorIterator&, const Scalar&); using softshrink_fn = void (*)(TensorIteratorBase&, const Scalar&); using shrink_backward_fn = void (*)(TensorIterator&, const Scalar&); using elu_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&, const Scalar&); -using elu_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&, bool); +using elu_backward_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&, const Scalar&, bool); using leaky_relu_fn = void (*)(TensorIteratorBase&, const Scalar&); -using leaky_relu_backward_fn = void (*)(TensorIterator&, const Scalar&); +using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const Scalar&); using log_sigmoid_cpu_fn = void (*)(Tensor& , Tensor&, const Tensor& ); DECLARE_DISPATCH(elu_fn, elu_stub); @@ -54,6 +54,8 @@ DECLARE_DISPATCH(activation_fn, glu_stub); DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub); DECLARE_DISPATCH(structured_activation_fn, silu_stub); DECLARE_DISPATCH(activation_backward_fn, silu_backward_stub); +DECLARE_DISPATCH(structured_activation_fn, mish_stub); +DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub); } // namespace native diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 33207b04a3c4b9..6a709ea53a7b32 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1877,7 +1877,7 @@ void linalg_qr_out_helper(const Tensor& input, const Tensor& Q, const Tensor& R, orgqr_stub(input.device().type(), const_cast(Q), tau); } -std::tuple _linalg_qr_helper_default(const Tensor& input, std::string mode) { +std::tuple _linalg_qr_helper_default(const Tensor& input, c10::string_view mode) { bool compute_q, reduced_mode; std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); auto m = input.size(-2); @@ -1908,13 +1908,13 @@ std::tuple _linalg_qr_helper_default(const Tensor& input, std::s return std::make_tuple(Q, R); } -std::tuple linalg_qr(const Tensor& self, std::string mode) { +std::tuple linalg_qr(const Tensor& self, c10::string_view mode) { TORCH_CHECK(self.dim() >= 2, "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_linalg_qr_helper(self, mode); } -std::tuple linalg_qr_out(const Tensor& self, std::string mode, Tensor& Q, Tensor& R) { +std::tuple linalg_qr_out(const Tensor& self, c10::string_view mode, Tensor& Q, Tensor& R) { TORCH_CHECK(self.dim() >= 2, "torch.linalg.qr: input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); checkSameDevice("torch.linalg.qr", Q, self, "Q"); @@ -1938,7 +1938,7 @@ std::tuple qr(const Tensor& self, bool some) { "should be replaced with\n", "Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete')" ); - std::string mode = some ? "reduced" : "complete"; + const char* mode = some ? "reduced" : "complete"; return at::linalg_qr(self, mode); } @@ -1950,7 +1950,7 @@ std::tuple qr_out(const Tensor& self, bool some, Tensor& Q, Ten "should be replaced with\n", "Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete')" ); - std::string mode = some ? "reduced" : "complete"; + const char* mode = some ? "reduced" : "complete"; return at::linalg_qr_out(Q, R, self, mode); } @@ -2260,7 +2260,7 @@ std::tuple linalg_eigh_out_info( Tensor& vectors, Tensor& infos, bool compute_eigenvectors, - const std::string& uplo_str) { + const c10::string_view uplo_str) { // These internal asserts make explicit the assumptions in the implementation // Error check with the actual error messages are done on the higher level of // the hierarchy of calls @@ -2318,7 +2318,7 @@ std::tuple linalg_eigh_out_info( return std::tuple(values, vectors); } -std::tuple linalg_eigh(const Tensor& input, std::string uplo) { +std::tuple linalg_eigh(const Tensor& input, c10::string_view uplo) { squareCheckInputs(input); checkUplo(uplo); ScalarType real_dtype = toValueType(input.scalar_type()); @@ -2339,7 +2339,7 @@ std::tuple linalg_eigh(const Tensor& input, std::string uplo) { // TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out // TODO: implement _out variant avoiding copy and using already allocated storage directly -std::tuple linalg_eigh_out(const Tensor& input, std::string uplo, Tensor& eigvals, Tensor& eigvecs) { +std::tuple linalg_eigh_out(const Tensor& input, c10::string_view uplo, Tensor& eigvals, Tensor& eigvecs) { checkSameDevice("torch.linalg.eigh", eigvecs, input, "eigenvectors"); checkSameDevice("torch.linalg.eigh", eigvals, input, "eigenvalues"); checkLinalgCompatibleDtype("torch.linalg.eigh", eigvecs, input, "eigenvectors"); @@ -2359,7 +2359,7 @@ std::tuple linalg_eigh_out(const Tensor& input, std::string up return std::tuple(eigvals, eigvecs); } -Tensor linalg_eigvalsh(const Tensor& input, std::string uplo) { +Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) { // if input requires grad we must compute the eigenvectors to make this function differentiable // the eigenvectors are not exposed to the user if (at::GradMode::is_enabled() && input.requires_grad()) { @@ -2388,7 +2388,7 @@ Tensor linalg_eigvalsh(const Tensor& input, std::string uplo) { // TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out // TODO: implement _out variant avoiding copy and using already allocated storage directly -Tensor& linalg_eigvalsh_out(const Tensor& input, std::string uplo, Tensor& result) { +Tensor& linalg_eigvalsh_out(const Tensor& input, c10::string_view uplo, Tensor& result) { checkSameDevice("torch.linalg.eigvalsh", result, input); ScalarType real_dtype = toValueType(input.scalar_type()); checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype); @@ -3345,17 +3345,17 @@ static void linalg_lstsq_out_info( } } -static std::string get_default_lstsq_driver(c10::optional driver, const Tensor& input) { +static std::string get_default_lstsq_driver(c10::optional driver, const Tensor& input) { // if `driver` is empty, we set driver_str to "gels" if working with CUDA tensors, // otherwise to "gelsy" driver. std::string driver_str; // check whether the user provided name is a valid driver name if (driver.has_value()) { - driver_str = driver.value(); + driver_str = std::string(driver.value()); // convert `driver_str` to lower case inplace. std::transform(driver_str.begin(), driver_str.end(), driver_str.begin(), [](unsigned char c) { return std::tolower(c); }); - static std::unordered_set allowed_drivers = { + static std::unordered_set allowed_drivers = { "gels", "gelsy", "gelsd", "gelss" }; if (input.device() == at::kCPU) { @@ -3382,7 +3382,7 @@ std::tuple linalg_lstsq_out( const Tensor& input, const Tensor& other, c10::optional rcond, - c10::optional driver, + c10::optional driver, Tensor& solution, Tensor& residuals, Tensor& rank, @@ -3550,7 +3550,7 @@ std::tuple linalg_lstsq_out( std::tuple linalg_lstsq( const Tensor& input, const Tensor& other, c10::optional rcond, - c10::optional driver) { + c10::optional driver) { Tensor solution = at::empty({0}, input.options()); Tensor residuals = at::empty({0}, input.options().dtype(toValueType(input.scalar_type()))); Tensor rank = at::empty({0}, input.options().dtype(at::kLong)); diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 1a832b0645c215..8c3a4c15d9864a 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -696,7 +696,7 @@ void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_valu // This is a type and driver dispatching helper function for 'apply_lstsq' void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) { - static auto driver_string_to_type = std::unordered_map({ + static auto driver_string_to_type = std::unordered_map({ {"gels", at::native::LapackLstsqDriverType::Gels}, {"gelsy", at::native::LapackLstsqDriverType::Gelsy}, {"gelsd", at::native::LapackLstsqDriverType::Gelsd}, diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 8b6f47c7155d10..263b829d6809fa 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -38,7 +38,7 @@ TORCH_META_FUNC2(div, Tensor) (const Tensor& self, const Tensor& other) { build_borrowing_binary_float_op(maybe_get_output(), self, other); } -TORCH_META_FUNC2(div, Tensor_mode) (const Tensor& self, const Tensor& other, c10::optional rounding_mode) { +TORCH_META_FUNC2(div, Tensor_mode) (const Tensor& self, const Tensor& other, c10::optional rounding_mode) { if (!rounding_mode.has_value()) { build_borrowing_binary_float_op(maybe_get_output(), self, other); // NOLINTNEXTLINE(bugprone-branch-clone) @@ -80,6 +80,10 @@ TORCH_META_FUNC(atan2) (const Tensor& self, const Tensor& other) { build_borrowing_binary_float_op(maybe_get_output(), self, other); } +TORCH_META_FUNC2(remainder, Tensor)(const Tensor& self, const Tensor& other) { + build_borrowing_binary_op(maybe_get_output(), self, other); +} + // These are normal binary ops that preserve dtype #define CREATE_BINARY_META_FUNC(func) \ TORCH_META_FUNC(func) (const Tensor& self, const Tensor& other) { \ @@ -231,7 +235,7 @@ TORCH_IMPL_FUNC(div_out) (const Tensor& self, const Tensor& other, const Tensor& } TORCH_IMPL_FUNC(div_out_mode) ( - const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& result + const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& result ) { if (!rounding_mode.has_value()) { div_true_stub(device_type(), *this); @@ -263,6 +267,7 @@ CREATE_BINARY_TORCH_IMPL_FUNC(hypot); CREATE_BINARY_TORCH_IMPL_FUNC(igamma); CREATE_BINARY_TORCH_IMPL_FUNC(igammac); CREATE_BINARY_TORCH_IMPL_FUNC(nextafter); +CREATE_BINARY_TORCH_IMPL_FUNC(remainder); Tensor special_xlog1py(const Scalar& x, const Tensor& y) { return at::special_xlog1py(wrapped_scalar_tensor(x), y); @@ -365,11 +370,11 @@ Tensor& div_(Tensor& self, const Scalar& other) { return self.div_(wrapped_scalar_tensor(other)); // redispatch! } -Tensor div(const Tensor& self, const Scalar& other, c10::optional rounding_mode) { +Tensor div(const Tensor& self, const Scalar& other, c10::optional rounding_mode) { return self.div(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch! } -Tensor& div_(Tensor& self, const Scalar& other, c10::optional rounding_mode) { +Tensor& div_(Tensor& self, const Scalar& other, c10::optional rounding_mode) { return self.div_(wrapped_scalar_tensor(other), std::move(rounding_mode)); // redispatch! } @@ -394,23 +399,23 @@ Tensor& divide_(Tensor& self, const Scalar& other) { return self.div_(other); } -Tensor& divide_out(const Tensor& self, const Tensor& other, c10::optional rounding_mode, Tensor& result) { +Tensor& divide_out(const Tensor& self, const Tensor& other, c10::optional rounding_mode, Tensor& result) { return at::div_out(result, self, other, std::move(rounding_mode)); } -Tensor divide(const Tensor& self, const Tensor& other, c10::optional rounding_mode) { +Tensor divide(const Tensor& self, const Tensor& other, c10::optional rounding_mode) { return self.div(other, std::move(rounding_mode)); } -Tensor& divide_(Tensor& self, const Tensor& other, c10::optional rounding_mode) { +Tensor& divide_(Tensor& self, const Tensor& other, c10::optional rounding_mode) { return self.div_(other, std::move(rounding_mode)); } -Tensor divide(const Tensor& self, const Scalar& other, c10::optional rounding_mode) { +Tensor divide(const Tensor& self, const Scalar& other, c10::optional rounding_mode) { return self.div(other, std::move(rounding_mode)); } -Tensor& divide_(Tensor& self, const Scalar& other, c10::optional rounding_mode) { +Tensor& divide_(Tensor& self, const Scalar& other, c10::optional rounding_mode) { return self.div_(other, std::move(rounding_mode)); } @@ -435,23 +440,6 @@ Tensor& true_divide_(Tensor& self, const Scalar& divisor) { return self.div_(divisor); } -Tensor& remainder_out(const Tensor& self, const Tensor& other, Tensor& result) { - auto iter = TensorIterator::binary_op(result, self, other); - remainder_stub(iter.device_type(), iter); - return result; -} - -Tensor remainder(const Tensor& self, const Tensor& other) { - Tensor result; - auto iter = TensorIterator::binary_op(result, self, other); - remainder_stub(iter.device_type(), iter); - return iter.output(); -} - -Tensor& remainder_(Tensor& self, const Tensor& other) { - return native::remainder_out(self, other, self); -} - Tensor& floor_divide_out(const Tensor& self, const Tensor& other, Tensor& result) { TORCH_WARN_ONCE( "floor_divide is deprecated, and will be removed in a future version of pytorch. " @@ -627,15 +615,18 @@ Tensor& add_(Tensor& self, const Scalar& other, const Scalar& alpha) { } Tensor remainder(const Tensor& self, const Scalar& other) { - return native::remainder(self, wrapped_scalar_tensor(other)); + // redispatch + return at::remainder(self, wrapped_scalar_tensor(other)); } Tensor& remainder_(Tensor& self, const Scalar& other) { - return native::remainder_(self, wrapped_scalar_tensor(other)); + // redispatch + return self.remainder_(wrapped_scalar_tensor(other)); } Tensor& remainder_out(const Tensor& self, const Scalar& other, Tensor& result) { - return native::remainder_out(self, wrapped_scalar_tensor(other), result); + // redispatch + return at::remainder_out(result, self, wrapped_scalar_tensor(other)); } Tensor rsub(const Tensor& self, const Scalar& other, const Scalar& alpha) { diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index e7d36d9298842a..de95035c81f643 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -53,7 +53,7 @@ DECLARE_DISPATCH(structured_binary_fn, div_true_stub); DECLARE_DISPATCH(structured_binary_fn, div_floor_stub); DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub); DECLARE_DISPATCH(structured_binary_fn, atan2_stub); -DECLARE_DISPATCH(binary_fn, remainder_stub); +DECLARE_DISPATCH(structured_binary_fn, remainder_stub); DECLARE_DISPATCH(binary_fn, bitwise_and_stub); DECLARE_DISPATCH(binary_fn, bitwise_or_stub); DECLARE_DISPATCH(binary_fn, bitwise_xor_stub); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index d774af9031d7f1..712f500a73f46e 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -678,7 +678,7 @@ static Tensor convolution_same( Tensor _convolution_mode( const Tensor& input, const Tensor& weight, const c10::optional& bias_opt, - IntArrayRef stride, std::string padding, IntArrayRef dilation, + IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, int64_t groups) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); @@ -698,7 +698,7 @@ Tensor _convolution_mode( at::Tensor conv1d( const Tensor& input, const Tensor& weight, const c10::optional& bias, - IntArrayRef stride, std::string padding, IntArrayRef dilation, + IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, int64_t groups) { return at::_convolution_mode( input, weight, bias, stride, std::move(padding), dilation, groups); @@ -706,7 +706,7 @@ at::Tensor conv1d( at::Tensor conv2d( const Tensor& input, const Tensor& weight, const c10::optional& bias, - IntArrayRef stride, std::string padding, IntArrayRef dilation, + IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, int64_t groups) { return at::_convolution_mode( input, weight, bias, stride, std::move(padding), dilation, groups); @@ -714,7 +714,7 @@ at::Tensor conv2d( at::Tensor conv3d( const Tensor& input, const Tensor& weight, const c10::optional& bias, - IntArrayRef stride, std::string padding, IntArrayRef dilation, + IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, int64_t groups) { return at::_convolution_mode( input, weight, bias, stride, std::move(padding), dilation, groups); diff --git a/aten/src/ATen/native/DistributionTemplates.h b/aten/src/ATen/native/DistributionTemplates.h index ad7c9a38e76347..1a071d11fd8964 100644 --- a/aten/src/ATen/native/DistributionTemplates.h +++ b/aten/src/ATen/native/DistributionTemplates.h @@ -63,7 +63,7 @@ int64_t update_to(int64_t to) { template class random_kernel, typename RNG> at::Tensor& random_impl(at::Tensor& self, c10::optional generator) { - auto iter = at::TensorIterator::nullary_op(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); random_kernel()(iter, generator); return self; } @@ -106,7 +106,7 @@ static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMet template class random_from_to_kernel, typename RNG> at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional to_opt, c10::optional generator) { uint64_t range = 0; - auto iter = at::TensorIterator::nullary_op(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); if (to_opt.has_value()) { // [from, to) int64_t to = *to_opt; @@ -299,7 +299,7 @@ at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, c10::optiona from = std::min(std::max(from, min), max); to = std::max(std::min(to, max), min); }); - auto iter = at::TensorIterator::nullary_op(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); uniform_kernel()(iter, from, to, generator); } return self; @@ -310,7 +310,7 @@ at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, c10::optiona template class log_normal_kernel, typename RNG> at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::optional gen) { TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std); - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); log_normal_kernel()(iter, mean, std, gen); return self; } @@ -320,7 +320,7 @@ at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::opt template class geometric_kernel, typename RNG> Tensor& geometric_impl_(Tensor& self, double p, c10::optional gen) { TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p); - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); geometric_kernel()(iter, p, gen); return self; } @@ -330,7 +330,7 @@ Tensor& geometric_impl_(Tensor& self, double p, c10::optional gen) { template class exponential_kernel, typename RNG> Tensor& exponential_impl_(Tensor& self, double lambda, c10::optional gen) { TORCH_CHECK(lambda >= 0.0, "exponential_ expects lambda >= 0.0, but found lambda=", lambda); - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); exponential_kernel()(iter, lambda, gen); return self; } @@ -339,7 +339,7 @@ Tensor& exponential_impl_(Tensor& self, double lambda, c10::optional template class cauchy_kernel, typename RNG> Tensor& cauchy_impl_(Tensor& self, double median, double sigma, c10::optional gen) { - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); cauchy_kernel()(iter, median, sigma, gen); return self; } diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 5f98c1c6060a9e..6ad8495b19026e 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -360,9 +360,9 @@ Tensor& random_meta_(Tensor& self, int64_t to, c10::optional gen) { Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) { Tensor ret = at::empty(self.sizes(), self.options()); auto iter = TensorIteratorConfig() - .add_output(ret) - .add_input(self) - .add_input(output) + .add_borrowed_output(ret) + .add_borrowed_input(self) + .add_borrowed_input(output) .build(); AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] { cpu_serial_kernel(iter, [](scalar_t self_val, scalar_t output_val) -> scalar_t{ @@ -375,10 +375,10 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) { Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& total) { Tensor ret = at::empty(x.sizes(), x.options()); auto iter = TensorIteratorConfig() - .add_output(ret) - .add_input(x) - .add_input(alpha) - .add_input(total) + .add_borrowed_output(ret) + .add_borrowed_input(x) + .add_borrowed_input(alpha) + .add_borrowed_input(total) .build(); AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "_dirichlet_grad_cpu", [&] { cpu_serial_kernel(iter, [](scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t{ @@ -395,9 +395,9 @@ Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& t Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, c10::optional gen) { Tensor ret = at::zeros(count.sizes(), count.options()); auto iter = TensorIteratorConfig() - .add_output(ret) - .add_input(count) - .add_input(prob) + .add_borrowed_output(ret) + .add_borrowed_input(count) + .add_borrowed_input(prob) .build(); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "binomial_cpu", [&] { CPUGeneratorImpl* generator = get_generator_or_default(gen, detail::getDefaultCPUGenerator()); @@ -420,8 +420,8 @@ Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, c10::optional gen) { Tensor ret = at::zeros(lambda.sizes(), lambda.options()); auto iter = TensorIteratorConfig() - .add_output(ret) - .add_input(lambda) + .add_borrowed_output(ret) + .add_borrowed_input(lambda) .build(); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "poisson_cpu", [&] { CPUGeneratorImpl* generator = get_generator_or_default(gen, detail::getDefaultCPUGenerator()); @@ -437,8 +437,8 @@ Tensor _s_poisson_cpu(const Tensor& lambda, c10::optional gen) { Tensor _s_gamma_cpu(const Tensor& alpha, c10::optional gen) { Tensor ret = at::zeros(alpha.sizes(), alpha.options()); auto iter = TensorIteratorConfig() - .add_output(ret) - .add_input(alpha) + .add_borrowed_output(ret) + .add_borrowed_input(alpha) .build(); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "gamma_cpu", [&] { CPUGeneratorImpl* generator = get_generator_or_default(gen, detail::getDefaultCPUGenerator()); @@ -473,8 +473,8 @@ Tensor _s_dirichlet_cpu(const Tensor& alpha, c10::optional gen) { std::lock_guard lock(generator->mutex_); /* Generate gamma sample by casting alpha to double to prevent underflow. */ auto iter1 = TensorIteratorConfig() - .add_output(gamma) - .add_input(alpha) + .add_borrowed_output(gamma) + .add_borrowed_input(alpha) .check_all_same_dtype(false) .build(); cpu_serial_kernel(iter1, [generator](scalar_t alpha_val) -> double{ @@ -496,9 +496,9 @@ Tensor _s_dirichlet_cpu(const Tensor& alpha, c10::optional gen) { /* Normalize and cast back to scalar_t. */ Tensor gamma_sum = gamma.sum(-1, true).expand(alpha.sizes()); auto iter2 = TensorIteratorConfig() - .add_output(ret) - .add_input(gamma) - .add_input(gamma_sum) + .add_borrowed_output(ret) + .add_borrowed_input(gamma) + .add_borrowed_input(gamma_sum) .check_all_same_dtype(false) .build(); cpu_serial_kernel(iter2, [](double gamma_val, double gamma_sum_val) -> scalar_t{ diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index b1c65e1ff40b1f..1bab9ec6beea40 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -25,7 +25,7 @@ Tensor& fill_out(Tensor& self, const Scalar& value) { auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) // Fill is idempotent, so overlap is okay .check_all_same_dtype(false) - .add_output(self) + .add_borrowed_output(self) .resize_outputs(false) .build(); fill_stub(iter.device_type(), iter, value); diff --git a/aten/src/ATen/native/Fill.h b/aten/src/ATen/native/Fill.h index c8d06ca57c9dea..e1903a379a0c3f 100644 --- a/aten/src/ATen/native/Fill.h +++ b/aten/src/ATen/native/Fill.h @@ -10,4 +10,6 @@ namespace at { namespace native { DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&), fill_stub); +Tensor& fill_out(Tensor& self, const Scalar& value); + }} // namespace at::native diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index 9c4215f9a77325..ca50fc6f90f20f 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -6,17 +6,23 @@ namespace at { namespace native { namespace { +// Check if tensor list has either a boolean tensor or a integer tensor +bool has_integral_tensor(TensorList tensors, const bool includeBool) { + return std::any_of(tensors.begin(), tensors.end(), + [&includeBool](const auto & t) { return at::isIntegralType(t.scalar_type(), includeBool); }); +} +// check if tensor list has bool tensors +bool has_bool_tensor(TensorList tensors) { + return std::any_of(tensors.begin(), tensors.end(), + [](const auto & t) -> bool { return t.scalar_type() == ScalarType::Bool; }); +} + // Check foreach API restrictions // - Tensor lists must be non-empty. -// - All tensors in all lists must have the same dtype. // - All TensorLists and ScalarLists must have the same number of elements. // - Corresponding tensors must have the same size. void check_foreach_api_restrictions(TensorList tensors) { TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); - auto expected_dtype = tensors[0].dtype(); - for (const auto& t : tensors) { - TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); - } } void check_foreach_api_restrictions(TensorList tensors, ArrayRef scalars) { @@ -29,11 +35,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) { TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor."); TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size()); - auto expected_dtype = tensors1[0].dtype(); - for (const auto i : c10::irange(tensors1.size())) { - TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); - TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes()); } } @@ -45,11 +47,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size()); TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size()); - auto expected_dtype = tensors1[0].dtype(); - for (const auto i : c10::irange(tensors1.size())) { - TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); - TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes()); TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes()); } @@ -61,20 +59,24 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te } // To go via 'fast' path, several conditions must be satisfied +// - All tensors in all lists must have the same dtype. // - All tensors must be on the same device // - All tensors must have strided layout // - All tensors must be non-overlapping and dense // - Resulting tensor must have the same dtype as the input one +// TODO(mkozuki): Consider whether we really need this function or not. +// Note that, there is a possibility that foreach fastpath supports type promotion in the future, +// which might complicate the functionality this function should provides. +// However, as of now, the check of division op with integer inputs is duplicated. +// `check_fast_path_restrictions` does the same thing in it before calling this function. bool will_promote_tensor(const Tensor& tensor, const Scalar& scalar, bool does_op_promote_integer_inputs_to_float = false) { // In case of division, integer inputs will result in float - if (does_op_promote_integer_inputs_to_float) { - if (at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) { - return true; - } + if (does_op_promote_integer_inputs_to_float && + at::isIntegralType(tensor.scalar_type(), /* includeBool */ true)) { + return true; } - auto result_dtype = at::result_type(tensor, scalar); - return result_dtype != tensor.scalar_type(); + return tensor.scalar_type() != at::native::result_type(scalar, tensor); } // Please, make sure to call check_foreach_api_restrictions before calling this method. @@ -83,10 +85,12 @@ bool check_fast_path_restrictions( ArrayRef tensorLists, ArrayRef scalarList = {}, bool does_op_promote_integer_inputs_to_float = false) { - auto expected_device = tensorLists[0][0].device(); + const auto expected_dtype = tensorLists[0][0].dtype(); + const auto expected_device = tensorLists[0][0].device(); auto is_tensor_okay = [&](const Tensor& tensor) { - return tensor.device() == expected_device && + return tensor.dtype() == expected_dtype && + tensor.device() == expected_device && tensor.layout() == at::kStrided && tensor.is_non_overlapping_and_dense(); }; @@ -108,9 +112,11 @@ bool check_fast_path_restrictions( } } - // For all j, tensorList[j][0] have the same shape and dtype. (this was a precondition - // checked by `check_foreach_api_restrictions`). This means we only need to check if - // {tensorList[0][0], tensorList[0][1], tensorList[0][2], ...} do type promotion with scalarLIst. + // This function has already checked that `tensorList[j][i]` for all j, i has the same dtype + // using `is_tensor_okay` function above. + // checked by `check_foreach_api_restrictions`). + // This means we only need to check if {tensorList[0][0], tensorList[0][1], tensorList[0][2], ...} + // do type promotion with scalarLIst. for (int i=0; i < tensorLists[0].size(); i++) { if (does_op_promote_integer_inputs_to_float) { if (at::isIntegralType(tensorLists[0][i].scalar_type(), /*includeBool*/ true)) { @@ -123,6 +129,8 @@ bool check_fast_path_restrictions( return false; } } else if (scalarList.size() > 1) { + // FIXME(mkozuki): Consider specializing `TensorListScalarListMetadata` for complex dtypes + // to access the following comment. // Complex scalar list is not supported due to the limit for kernel launch argument (4KB) if (scalarList[i].isComplex()) { return false; diff --git a/aten/src/ATen/native/FunctionOfAMatrixUtils.cpp b/aten/src/ATen/native/FunctionOfAMatrixUtils.cpp index b30330c501579f..5b40c365ad4f30 100644 --- a/aten/src/ATen/native/FunctionOfAMatrixUtils.cpp +++ b/aten/src/ATen/native/FunctionOfAMatrixUtils.cpp @@ -86,9 +86,9 @@ Tensor& _compute_linear_combination_out(const Tensor& input, const Tensor& coeff .set_check_mem_overlap(false) // Output is intentionally 0 strided above .check_all_same_dtype(false) .resize_outputs(false) - .add_output(output_restrided) - .add_input(input_restrided) - .add_input(coefficients_restrided) + .add_borrowed_output(output_restrided) + .add_borrowed_input(input_restrided) + .add_borrowed_input(coefficients_restrided) .build(); // The dimension of size n is traversed inside the kernels, diff --git a/aten/src/ATen/native/GatedLinearUnit.cpp b/aten/src/ATen/native/GatedLinearUnit.cpp index 1086a6f3a74237..b8321ccbb545e6 100644 --- a/aten/src/ATen/native/GatedLinearUnit.cpp +++ b/aten/src/ATen/native/GatedLinearUnit.cpp @@ -28,7 +28,7 @@ Tensor& glu_out(const Tensor& self, int64_t dim, Tensor &result) { Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize); Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize); - auto iter = TensorIterator::binary_op(result, firstHalf, secondHalf); + auto iter = TensorIterator::borrowing_binary_op(result, firstHalf, secondHalf); glu_stub(iter.device_type(), iter); return result; } @@ -56,10 +56,10 @@ Tensor& glu_backward_out(const Tensor& grad_output, const Tensor& input, int64_t at::sigmoid_out(gradInputfirstHalf, secondHalf); // for second gradinput half, can get a better performance by fusion auto iter = at::TensorIteratorConfig() - .add_output(gradInputsecondHalf) - .add_input(gradInputfirstHalf) - .add_input(firstHalf) - .add_input(grad_output) + .add_borrowed_output(gradInputsecondHalf) + .add_borrowed_input(gradInputfirstHalf) + .add_borrowed_input(firstHalf) + .add_borrowed_input(grad_output) .build(); glu_backward_stub(iter.device_type(), iter); gradInputfirstHalf.mul_(grad_output); diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 19eeb9cde58e25..7f9a110b7a5ce3 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -148,12 +148,12 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra namespace { -bool einsum_check_label(char label) { +bool einsum_check_label(unsigned char label) { return std::isalpha(label); } -int einsum_label_to_index(char label) { - constexpr int NUM_OF_LETTERS = 'z' - 'a' + 1; +uint8_t einsum_label_to_index(unsigned char label) { + constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1; return std::islower(label) ? label - 'a' : NUM_OF_LETTERS + (label - 'A'); } @@ -164,12 +164,12 @@ int einsum_label_to_index(char label) { // 2. Unsqueeze missing dimensions from input operands and permute to align them // 3. Compute result by multiplying input operands and summing contraction // dimensions We do the last part by reducing to bmm. -Tensor einsum(std::string equation, TensorList operands) { +Tensor einsum(c10::string_view equation, TensorList operands) { TORCH_CHECK(!operands.empty(), "einsum(): must provide at least one operand"); checkDeviceType("einsum():", operands, operands[0].device().type()); // Code used to identify ELLIPSIS ("...") - constexpr int ELLIPSIS = '.'; + constexpr uint8_t ELLIPSIS = 52; // Find arrow (->) to split equation into lhs and rhs const auto arrow_pos = equation.find("->"); @@ -177,13 +177,14 @@ Tensor einsum(std::string equation, TensorList operands) { const auto num_ops = operands.size(); - // Convert labels for input operands into an index in [0, 25] and store + // Convert labels for input operands into an index in [0, 52) and store // them in op_labels for each operand along with ELLIPSIS if present. - std::vector> op_labels(num_ops); + std::vector> op_labels(num_ops); bool found_ell = false; std::size_t curr_op = 0; for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) { - switch (lhs[i]) { + const unsigned char label = lhs[i]; + switch (label) { case ' ': // Ignore spaces break; @@ -217,12 +218,11 @@ Tensor einsum(std::string equation, TensorList operands) { default: // Parse label TORCH_CHECK( - einsum_check_label(lhs[i]), - "einsum(): operand subscript must be in [a-zA-Z] but found ", - lhs[i], - " for operand ", - curr_op); - op_labels[curr_op].push_back(einsum_label_to_index(lhs[i])); + einsum_check_label(label), + "einsum(): invalid subscript given at index ", + i, + " in the equation string, subscripts must be in [a-zA-Z]"); + op_labels[curr_op].push_back(einsum_label_to_index(label)); } } @@ -231,8 +231,8 @@ Tensor einsum(std::string equation, TensorList operands) { "einsum(): more operands were provided than specified in the equation"); // Labels must be within [a-zA-Z]. - constexpr int TOTAL_LABELS = 52; - std::vector label_count(TOTAL_LABELS, 0); + constexpr uint8_t TOTAL_LABELS = 52; + std::vector label_count(TOTAL_LABELS, 0); // The maximum number of dimensions covered by any ellipsis, needed when // unsqueezing missing dimensions from operands to permute and broadcast @@ -244,7 +244,7 @@ Tensor einsum(std::string equation, TensorList operands) { for(const auto i : c10::irange(num_ops)) { const auto operand = operands[i]; const auto labels = op_labels[i]; - const int64_t ndims = operand.dim(); + const auto ndims = operand.dim(); int64_t nlabels = labels.size(); bool has_ellipsis = false; @@ -295,7 +295,8 @@ Tensor einsum(std::string equation, TensorList operands) { // Parse explicit output const auto rhs = equation.substr(arrow_pos + 2); for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) { - switch (rhs[i]) { + const unsigned char label = rhs[i]; + switch (label) { case ' ': // Ignore spaces break; @@ -316,21 +317,21 @@ Tensor einsum(std::string equation, TensorList operands) { default: TORCH_CHECK( - einsum_check_label(rhs[i]), - "einsum(): subscripts must be in [a-zA-Z] but found ", - rhs[i], - " for the output"); - const auto label = einsum_label_to_index(rhs[i]); + einsum_check_label(label), + "einsum(): invalid subscript given at index ", + lhs.size() + 2 + i, + " in the equation string, subscripts must be in [a-zA-Z]"); + const auto index = einsum_label_to_index(label); TORCH_CHECK( // Ensure label appeared at least once for some input operand and at // most once for the output - label_count[label] > 0 && label_perm_index[label] == -1, + label_count[index] > 0 && label_perm_index[index] == -1, "einsum(): output subscript ", - rhs[i], - label_perm_index[label] > -1 + label, + label_perm_index[index] > -1 ? " appears more than once in the output" : " does not appear in the equation for any input operand"); - label_perm_index[label] = perm_index++; + label_perm_index[index] = perm_index++; } } } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 2dacbda7f8dd72..7fe4f2e18c8683 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -32,13 +32,26 @@ TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self); - set_output(0, IntArrayRef({mat1.sizes()[0], mat2.sizes()[1]}), {}, self.options(), names); + set_output(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names); auto result = maybe_get_output(0); //this check can fire for inplace op only, for all other versions result is guaranteed to be correct size TORCH_CHECK(((result.dim() == 2) && (result.sizes()[0] == mat1.sizes()[0]) && (result.sizes()[1] == mat2.sizes()[1])), "The input tensor must be a matrix with size ", mat1.sizes()[0], "x", mat2.sizes()[1], ", but got a ", result.dim(), "-D tensor with size ", result.sizes()[0], "x", result.sizes()[1]); } + +TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) { + TORCH_CHECK(self.dim() == 2, "self must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + + auto names = at::namedinference::compute_matmul_outnames(self, mat2); + set_output(0, {self.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names); + auto result = maybe_get_output(0); + //this check can fire for inplace op only, for all other versions result is guaranteed to be correct size + TORCH_CHECK(((result.dim() == 2) && (result.sizes()[0] == self.sizes()[0]) && (result.sizes()[1] == mat2.sizes()[1])), + "The input tensor must be a matrix with size ", self.sizes()[0], "x", mat2.sizes()[1], ", but got a ", result.dim(), + "-D tensor with size ", result.sizes()[0], "x", result.sizes()[1]); +} } // namespace meta namespace native { @@ -767,10 +780,10 @@ static TensorIterator build_addr_iter(Tensor& result, auto iter = TensorIteratorConfig() .set_check_mem_overlap(true) - .add_output(result) + .add_borrowed_output(result) .add_input(*self_) .add_input(vec1.reshape({vec1_size0, 1})) - .add_input(vec2) + .add_borrowed_input(vec2) .allow_cpu_scalars(true) .promote_inputs_to_common_dtype(true) .cast_common_dtype_to_outputs(true) @@ -1108,19 +1121,11 @@ TORCH_IMPL_FUNC(addmm_out_cpu)(const Tensor& self, const Tensor& mat1, const Ten } } -Tensor& mm_cpu_out(const Tensor & self, const Tensor & mat2, Tensor & result) { - TORCH_CHECK(self.dim() == 2, "self must be a matrix"); - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); - native::resize_(result, {self.sizes()[0], mat2.sizes()[1]}); - addmm_impl_cpu_(result, result, self, mat2, 0, 1); - auto names = at::namedinference::propagate_names_for_addmm(self, mat2, result); - at::namedinference::propagate_names_if_nonempty(result, names); - return result; -} - -Tensor mm_cpu(const Tensor & self, const Tensor & mat2) { - Tensor result = at::empty({self.sizes()[0], mat2.sizes()[1]}, self.options()); - return native::mm_cpu_out(self, mat2, result); +TORCH_IMPL_FUNC(mm_out_cpu)(const Tensor & self, const Tensor & mat2, const Tensor & result) { + { + at::NoNamesGuard guard; + addmm_impl_cpu_(const_cast(result), result, self, mat2, 0, 1); + } } template @@ -1254,7 +1259,7 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& if (is_bmm_out) { for (int64_t b = 0; b < bs; b++) { auto r = self_or_result.select(0, b); - native::mm_cpu_out(batch1.select(0, b), batch2.select(0, b), r); + addmm_impl_cpu_(r, r, batch1.select(0, b), batch2.select(0, b), 0, 1); } } else { for (int64_t b = 0; b < bs; b++) { @@ -2105,7 +2110,7 @@ static std::vector make_dim_list(int64_t ndim) { } // Checks for valid arguments to linalg_norm when type(ord) == str -static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim) { +static void check_str_ord_valid(const c10::string_view str_ord, optional opt_dim, int64_t ndim) { TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); TORCH_CHECK(dims_valid, "order \"", str_ord, @@ -2206,7 +2211,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const return result; } -static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, const optional& opt_num_ord, optional opt_str_ord, optional opt_dim, bool keepdim, optional opt_dtype) { +static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, const optional& opt_num_ord, optional opt_str_ord, optional opt_dim, bool keepdim, optional opt_dtype) { // Callers must give the ord argument as either a number, a string, or neither. // Since the user-facing API has no direct control over how this function is called, this is an internal assert. TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value())); @@ -2380,7 +2385,7 @@ Tensor& linalg_matrix_norm_out( Tensor linalg_matrix_norm( const Tensor& self, - std::string ord, + c10::string_view ord, IntArrayRef dim, bool keepdim, optional dtype) { @@ -2390,7 +2395,7 @@ Tensor linalg_matrix_norm( Tensor& linalg_matrix_norm_out( const Tensor& self, - std::string ord, + c10::string_view ord, IntArrayRef dim, bool keepdim, optional dtype, @@ -2408,7 +2413,7 @@ Tensor linalg_norm(const Tensor& self, const optional& opt_ord, optional } // Frobenius and nuclear norms -Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { +Tensor linalg_norm(const Tensor& self, c10::string_view ord, optional opt_dim, bool keepdim, optional opt_dtype) { auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device()); Tensor result = at::empty({0}, options); return at::native::linalg_norm_out( @@ -2421,12 +2426,12 @@ Tensor& linalg_norm_out(const Tensor& self, const optional& opt_ord, opt } // Frobenius and nuclear norms -Tensor& linalg_norm_out(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype, Tensor& result) { +Tensor& linalg_norm_out(const Tensor& self, c10::string_view ord, optional opt_dim, bool keepdim, optional opt_dtype, Tensor& result) { return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); } // This function helps to dispatch norm computations depending on 'ord' of variant type -Tensor _linalg_cond_helper(const Tensor& self, c10::variant ord_variant) { +Tensor _linalg_cond_helper(const Tensor& self, c10::variant ord_variant) { Tensor inverse, info; std::tie(inverse, info) = at::linalg_inv_ex(self); info.unsqueeze_(-1).unsqueeze_(-1); @@ -2449,14 +2454,14 @@ Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) { return at::zeros(result_shape, options); } -void _linalg_cond_check_ord(c10::variant ord_variant) { +void _linalg_cond_check_ord(c10::variant ord_variant) { if (ord_variant.index() == 0) { Scalar* ord = c10::get_if(&ord_variant); double abs_ord = std::abs(ord->toDouble()); TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY, "linalg_cond got an invalid norm type: ", ord->toDouble()); } else if (ord_variant.index() == 1) { - std::string* ord = c10::get_if(&ord_variant); + c10::string_view* ord = c10::get_if(&ord_variant); TORCH_CHECK(*ord == "fro" || *ord == "nuc", "linalg_cond got an invalid norm type: ", *ord); } else { @@ -2473,7 +2478,7 @@ Tensor linalg_cond(const Tensor& self, const optional& opt_ord) { // The default case is using 2-norm Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; - c10::variant ord_variant = ord; + c10::variant ord_variant = ord; _linalg_cond_check_ord(ord_variant); // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input @@ -2520,7 +2525,7 @@ Tensor& linalg_cond_out(const Tensor& self, const optional& opt_ord, Ten } // Frobenius or nuclear norms -Tensor linalg_cond(const Tensor& self, std::string ord) { +Tensor linalg_cond(const Tensor& self, c10::string_view ord) { // the same checks as squareCheckInputs(self) but with a slightly more informative error message TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", self.dim(), " dimensions."); @@ -2528,7 +2533,7 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); - c10::variant ord_variant = ord; + c10::variant ord_variant = ord; _linalg_cond_check_ord(ord_variant); // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input @@ -2548,7 +2553,7 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { } // TODO: implement _out variant avoiding copy and using already allocated storage directly -Tensor& linalg_cond_out(const Tensor& self, std::string ord, Tensor& result) { +Tensor& linalg_cond_out(const Tensor& self, c10::string_view ord, Tensor& result) { checkSameDevice("linalg_cond", result, self); ScalarType real_dtype = toValueType(self.scalar_type()); checkLinalgCompatibleDtype("linalg_cond", result.scalar_type(), real_dtype); @@ -2775,8 +2780,8 @@ std::tuple lu_unpack( .check_all_same_dtype(false) .resize_outputs(false) .declare_static_shape(LU_pivots.sizes(), /*squash_dim=*/LU_pivots.dim() - 1) - .add_output(unpacked_pivots) - .add_input(LU_pivots_zero_idx) + .add_borrowed_output(unpacked_pivots) + .add_borrowed_input(LU_pivots_zero_idx) .build(); // } diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index fbe0963d3c2b0b..0056bedec36e85 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -116,8 +116,8 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu .set_check_mem_overlap(false) .check_all_same_dtype(false) .resize_outputs(false) - .add_output(b_linear_batch_idx) - .add_input(a_linear_batch_idx) + .add_borrowed_output(b_linear_batch_idx) + .add_borrowed_input(a_linear_batch_idx) .build(); auto m = a.size(-2); @@ -328,7 +328,7 @@ static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { } // parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) -static inline std::tuple _parse_qr_mode(std::string mode) { +static inline std::tuple _parse_qr_mode(c10::string_view mode) { bool compute_q; bool reduced; if (mode == "reduced") { @@ -480,7 +480,7 @@ static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) { // This function checks whether the uplo argument input is valid // Allowed strings are "u", "U", "l", "L" -static inline void checkUplo(const std::string& uplo) { +static inline void checkUplo(const c10::string_view uplo) { // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char char uplo_uppercase = static_cast(std::toupper(static_cast(uplo[0]))); TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'), diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 9e24cf7f8c1496..43f4ad5095b4c8 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -103,9 +103,9 @@ Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor auto grad_expand = grad.expand_as(input); if (!log_target) { auto iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(target) - .add_input(grad_expand) + .add_borrowed_output(grad_input) + .add_borrowed_input(target) + .add_borrowed_input(grad_expand) .build(); AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kl_div_backward_cpu", [&]() { cpu_serial_kernel(iter, [](scalar_t target_val, scalar_t grad_val) -> scalar_t{ @@ -141,7 +141,7 @@ Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target, Tensor loss_squeezed = at::squeeze(loss); auto iter = TensorIteratorConfig() - .add_output(loss_squeezed) + .add_borrowed_output(loss_squeezed) .add_input(at::squeeze(input)) .add_input(at::squeeze(target)) .build(); @@ -191,7 +191,7 @@ Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor& Tensor grad_input_squeezed = at::squeeze(grad_input); auto iter = TensorIteratorConfig() - .add_output(grad_input_squeezed) + .add_borrowed_output(grad_input_squeezed) .add_input(at::squeeze(grad)) .add_input(at::squeeze(input)) .add_input(at::squeeze(target)) @@ -332,7 +332,7 @@ Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t r return at::native::l1_loss(input, target, reduction); } Tensor loss; - auto iter = TensorIterator::binary_op(loss, input, target); + auto iter = TensorIterator::borrowing_binary_op(loss, input, target); smooth_l1_stub(iter.device_type(), iter, beta); return apply_loss_reduction(iter.output(), reduction); } @@ -344,7 +344,7 @@ Tensor& smooth_l1_loss_out(const Tensor& input, const Tensor& target, int64_t re } if (reduction != Reduction::None) { Tensor loss; - auto iter = TensorIterator::binary_op(loss, input, target); + auto iter = TensorIterator::borrowing_binary_op(loss, input, target); smooth_l1_stub(iter.device_type(), iter, beta); if (reduction == Reduction::Mean) { at::mean_out(result, iter.output(), 0); @@ -352,7 +352,7 @@ Tensor& smooth_l1_loss_out(const Tensor& input, const Tensor& target, int64_t re at::sum_out(result, iter.output(), 0); } } else { - auto iter = TensorIterator::binary_op(result, input, target); + auto iter = TensorIterator::borrowing_binary_op(result, input, target); smooth_l1_stub(iter.device_type(), iter, beta); } return result; @@ -364,10 +364,10 @@ Tensor& smooth_l1_loss_backward_out(const Tensor& grad_output, const Tensor& inp grad_output, input, target, reduction, grad_input); auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.; auto iter = at::TensorIteratorConfig() - .add_output(grad_input) - .add_input(input) - .add_input(target) - .add_input(grad_output) + .add_borrowed_output(grad_input) + .add_borrowed_input(input) + .add_borrowed_input(target) + .add_borrowed_input(grad_output) .build(); smooth_l1_backward_stub(iter.device_type(), iter, norm, beta); return grad_input; @@ -383,14 +383,14 @@ Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, c Tensor huber_loss(const Tensor& input, const Tensor& target, int64_t reduction, double delta) { TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") Tensor loss = at::empty_like(input); - auto iter = TensorIterator::binary_op(loss, input, target); + auto iter = TensorIterator::borrowing_binary_op(loss, input, target); huber_stub(iter.device_type(), iter, delta); return apply_loss_reduction(loss, reduction); } Tensor& huber_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& result) { TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") - auto iter = TensorIterator::binary_op(result, input, target); + auto iter = TensorIterator::borrowing_binary_op(result, input, target); huber_stub(iter.device_type(), iter, delta); if (reduction != Reduction::None) { auto reduced = apply_loss_reduction(result, reduction); @@ -408,10 +408,10 @@ Tensor huber_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& huber_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& grad_input) { auto norm = (reduction == Reduction::Mean) ? (1. / input.numel()) : 1.; auto iter = at::TensorIteratorConfig() - .add_output(grad_input) - .add_input(input) - .add_input(target) - .add_input(grad_output) + .add_borrowed_output(grad_input) + .add_borrowed_input(input) + .add_borrowed_input(target) + .add_borrowed_input(grad_output) .build(); huber_backward_stub(iter.device_type(), iter, norm, delta); return grad_input; @@ -419,7 +419,7 @@ Tensor& huber_loss_backward_out(const Tensor& grad_output, const Tensor& input, Tensor mse_loss(const Tensor& input, const Tensor& target, int64_t reduction) { Tensor loss; - auto iter = TensorIterator::binary_op(loss, input, target); + auto iter = TensorIterator::borrowing_binary_op(loss, input, target); mse_stub(iter.device_type(), iter); return apply_loss_reduction(iter.output(), reduction); } @@ -427,7 +427,7 @@ Tensor mse_loss(const Tensor& input, const Tensor& target, int64_t reduction) { Tensor& mse_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, Tensor&result) { if (reduction != Reduction::None) { Tensor loss; - auto iter = TensorIterator::binary_op(loss, input, target); + auto iter = TensorIterator::borrowing_binary_op(loss, input, target); mse_stub(iter.device_type(), iter); if (reduction == Reduction::Mean) { at::mean_out(result, iter.output(), 0); @@ -435,8 +435,8 @@ Tensor& mse_loss_out(const Tensor& input, const Tensor& target, int64_t reductio at::sum_out(result, iter.output(), 0); } } else { - auto iter = TensorIterator::binary_op(result, input, target); - mse_stub(iter.device_type(), iter);; + auto iter = TensorIterator::borrowing_binary_op(result, input, target); + mse_stub(iter.device_type(), iter); } return result; } @@ -450,10 +450,10 @@ Tensor& mse_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) { auto norm = reduction == Reduction::Mean ? 2. / input.numel() : 2.; auto iter = at::TensorIteratorConfig() - .add_output(grad_input) - .add_input(input) - .add_input(target) - .add_input(grad_output) + .add_borrowed_output(grad_input) + .add_borrowed_input(input) + .add_borrowed_input(target) + .add_borrowed_input(grad_output) .build(); mse_backward_stub(iter.device_type(), iter, norm); return grad_input; diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index fd4cb8402948f7..f848aceb199561 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -202,12 +202,12 @@ std::tuple batch_norm_cpu_transform_input_template( Tensor output = at::empty(input.sizes(), input.options()); auto iter = TensorIteratorConfig() - .add_output(output) - .add_input(input) - .add_input(mean) - .add_input(invstd) - .add_input(w) - .add_input(b) + .add_borrowed_output(output) + .add_borrowed_input(input) + .add_borrowed_input(mean) + .add_borrowed_input(invstd) + .add_borrowed_input(w) + .add_borrowed_input(b) .build(); cpu_kernel(iter, [=](scalar_t input, scalar_t mean, scalar_t invstd, scalar_t weight, scalar_t bias) { @@ -248,7 +248,7 @@ std::tuple batch_norm_cpu_update_stats_template( // compute variance per input auto iter = TensorIteratorConfig() - .add_input(in) + .add_borrowed_input(in) .build(); accscalar_t var_sum = 0; auto mean = static_cast(save_mean_a[f]); @@ -335,8 +335,8 @@ std::tuple batch_norm_backward_cpu_template(const Tensor // dot product of the Q(X) and gradOuput accscalar_t dotp = 0; auto iter = TensorIteratorConfig() - .add_input(in) - .add_input(grad_out) + .add_borrowed_input(in) + .add_borrowed_input(grad_out) .build(); cpu_serial_kernel(iter, [&](const scalar_t i, const scalar_t go) -> void { dotp += (i - mean) * go; @@ -361,7 +361,7 @@ std::tuple batch_norm_backward_cpu_template(const Tensor scalar_t grad_mean = sum_a[f] / n; { - auto iter = TensorIterator::binary_op(grad_in, grad_in, grad_out); + auto iter = TensorIterator::borrowing_binary_op(grad_in, grad_in, grad_out); cpu_serial_kernel(iter, [&](scalar_t gi, scalar_t go) -> scalar_t { return (go - grad_mean - gi) * invstd * w; }); diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index 6893525fedc890..1bf75ac46f66c2 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -2,9 +2,10 @@ namespace at { namespace native { -Tensor one_hot(const Tensor &self, int64_t num_classes) { +Tensor one_hot(const Tensor &self, int64_t num_classes, ScalarType dtype) { TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor."); auto shape = self.sizes().vec(); + auto output_options = self.options().dtype(dtype); // empty tensor could be converted to one hot representation, // but shape inference is not possible. @@ -13,7 +14,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { AT_ERROR("Can not infer total number of classes from empty tensor."); } else { shape.push_back(num_classes); - return at::empty(shape, self.options()); + return at::empty(shape, output_options); } } @@ -35,7 +36,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { } shape.push_back(num_classes); - Tensor ret = at::zeros(shape, self.options()); + Tensor ret = at::zeros(shape, output_options); ret.scatter_(-1, self.unsqueeze(-1), 1); return ret; } diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp index bae2b670d15fbe..32349e1db3ac65 100644 --- a/aten/src/ATen/native/PointwiseOps.cpp +++ b/aten/src/ATen/native/PointwiseOps.cpp @@ -35,10 +35,10 @@ Tensor& addcmul_out(const Tensor& self, Tensor& result) { checkBackend("addcmul_cpu", result, self.options().backend()); auto iter = at::TensorIteratorConfig() - .add_output(result) - .add_input(self) - .add_input(tensor1) - .add_input(tensor2) + .add_borrowed_output(result) + .add_borrowed_input(self) + .add_borrowed_input(tensor1) + .add_borrowed_input(tensor2) .build(); addcmul_stub(iter.device_type(), iter, value); return result; @@ -80,10 +80,10 @@ Tensor& addcdiv_out(const Tensor& self, } checkBackend("addcdiv_cpu", result, self.options().backend()); auto iter = at::TensorIteratorConfig() - .add_output(result) - .add_input(self) - .add_input(tensor1) - .add_input(tensor2) + .add_borrowed_output(result) + .add_borrowed_input(self) + .add_borrowed_input(tensor1) + .add_borrowed_input(tensor2) .build(); addcdiv_stub(iter.device_type(), iter, value); return result; diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 836296f795cc08..3b7d26a67a872e 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -35,7 +35,7 @@ Tensor& linspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional result{true}; auto iter = TensorIteratorConfig() - .add_input(self) - .add_input(other) + .add_borrowed_input(self) + .add_borrowed_input(other) .allow_cpu_scalars(true) .promote_inputs_to_common_dtype(true) .build(); diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp index 758e7de8dde7e6..6ae54aad06b1bd 100644 --- a/aten/src/ATen/native/Repeat.cpp +++ b/aten/src/ATen/native/Repeat.cpp @@ -7,11 +7,16 @@ static void compute_cpu( index_t* repeat_ptr, int64_t* cumsum_ptr, index_t* result_ptr, - int64_t size) { + int64_t size, + int64_t result_size) { + TORCH_CHECK( + (result_size == cumsum_ptr[size - 1]), + "allocated size does not match required size"); at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) { for (int64_t i = i_begin; i < i_end; i++) { int64_t end = cumsum_ptr[i]; index_t size = repeat_ptr[i]; + TORCH_CHECK((size >= 0), "repeats can not be negative"); int64_t start = end - size; for (int64_t j = start; j < end; j++) { result_ptr[j] = i; @@ -23,10 +28,13 @@ static void compute_cpu( namespace at { namespace native { -Tensor repeat_interleave_cpu(const Tensor& repeat) { +Tensor repeat_interleave_cpu( + const Tensor& repeat, + c10::optional output_size) { Tensor output; AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cpu", [&]() { - output = repeat_interleave_common>(repeat); + output = repeat_interleave_common>( + repeat, output_size); }); return output; @@ -35,7 +43,8 @@ Tensor repeat_interleave_cpu(const Tensor& repeat) { Tensor repeat_interleave( const Tensor& self, const Tensor& repeats, - c10::optional dim) { + c10::optional dim, + c10::optional output_size) { Tensor input = self; if (!dim) { input = self.flatten(); @@ -53,15 +62,20 @@ Tensor repeat_interleave( AT_ERROR("repeats must be 0-dim or 1-dim tensor"); } - return input.index_select(dim.value(), at::repeat_interleave(repeats_)); + return input.index_select( + dim.value(), at::repeat_interleave(repeats_, output_size)); } Tensor repeat_interleave( const Tensor& self, int64_t repeats, - c10::optional dim) { + c10::optional dim, + c10::optional output_size) { return at::native::repeat_interleave( - self, at::tensor({repeats}, self.options().dtype(kLong)), dim); + self, + at::tensor({repeats}, self.options().dtype(kLong)), + dim, + output_size); } } // namespace native diff --git a/aten/src/ATen/native/Repeat.h b/aten/src/ATen/native/Repeat.h index 3db3196127ec35..9751f2ec8be7a4 100644 --- a/aten/src/ATen/native/Repeat.h +++ b/aten/src/ATen/native/Repeat.h @@ -5,26 +5,36 @@ namespace at { namespace native { -template -static inline Tensor repeat_interleave_common(const Tensor& repeats) { +template < + typename index_t, + void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)> +static inline Tensor repeat_interleave_common( + const Tensor& repeats, + c10::optional output_size) { TORCH_CHECK( repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat"); TORCH_CHECK( repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt, "repeats has to be Long or Int tensor"); - TORCH_CHECK( - (repeats >= 0).all().item(), "repeats can not be negative"); if (repeats.size(0) == 0) { return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } Tensor repeats_ = repeats.contiguous(); Tensor cumsum = repeats.cumsum(0); - int64_t total = cumsum[-1].item(); + int64_t total; + if (output_size.has_value()) { + total = output_size.value(); + } else { + total = cumsum[-1].item(); + TORCH_CHECK( + (repeats >= 0).all().item(), "repeats can not be negative"); + } + Tensor result = at::empty({total}, repeats.options()); index_t* repeat_ptr = repeats_.data_ptr(); int64_t* cumsum_ptr = cumsum.data_ptr(); index_t* result_ptr = result.data_ptr(); - compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0)); + compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total); return result; } diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index d5cc96fe59b9c6..b1bb8a0cb040c4 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -99,14 +99,14 @@ Tensor _segment_reduce_cpu_backward_kernel( } // namespace enum SegmentReductionType { MAX }; -static const std::map segmentReduce2REDUCE = +static const std::map segmentReduce2REDUCE = { {"max", MAX}, }; Tensor segment_reduce_kernel( const Tensor& data, - std::string reduce, + c10::string_view reduce, const c10::optional& lengths, const c10::optional& indices, int64_t axis, diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index b85211b7d9f52f..dcd3cd8461d055 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -124,7 +124,7 @@ void quick_select_template( } QUANTILE_INTERPOLATION_MODE get_quantile_interpolation_mode( - const std::string& interpolation) { + const c10::string_view interpolation) { if (interpolation == "linear") { return QUANTILE_INTERPOLATION_MODE::LINEAR; } else if (interpolation == "lower") { @@ -472,7 +472,7 @@ Tensor& quantile_out( const Tensor& q, optional dim, bool keepdim, - const std::string interpolation, + const c10::string_view interpolation, Tensor& out) { quantile_impl( out, @@ -491,7 +491,7 @@ Tensor& quantile_out( double q, optional dim, bool keepdim, - const std::string interpolation, + const c10::string_view interpolation, Tensor& out) { TORCH_CHECK( q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q); @@ -510,7 +510,7 @@ Tensor quantile( const Tensor& q, optional dim, bool keepdim, - const std::string interpolation) { + const c10::string_view interpolation) { Tensor out = at::empty({0}, self.options()); quantile_impl( out, @@ -529,7 +529,7 @@ Tensor quantile( double q, optional dim, bool keepdim, - const std::string interpolation) { + const c10::string_view interpolation) { TORCH_CHECK( q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q); return at::native::quantile( @@ -542,7 +542,7 @@ Tensor& nanquantile_out( const Tensor& q, optional dim, bool keepdim, - const std::string interpolation, + const c10::string_view interpolation, Tensor& out) { quantile_impl( out, @@ -561,7 +561,7 @@ Tensor& nanquantile_out( double q, optional dim, bool keepdim, - const std::string interpolation, + const c10::string_view interpolation, Tensor& out) { TORCH_CHECK( q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q); @@ -580,7 +580,7 @@ Tensor nanquantile( const Tensor& q, optional dim, bool keepdim, - const std::string interpolation) { + const c10::string_view interpolation) { Tensor out = at::empty({0}, self.options()); quantile_impl( out, @@ -599,7 +599,7 @@ Tensor nanquantile( double q, optional dim, bool keepdim, - const std::string interpolation) { + const c10::string_view interpolation) { TORCH_CHECK( q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q); return at::native::nanquantile( diff --git a/aten/src/ATen/native/SortingUtils.h b/aten/src/ATen/native/SortingUtils.h index 3223fd3a779d19..c8a42135ceb2e6 100644 --- a/aten/src/ATen/native/SortingUtils.h +++ b/aten/src/ATen/native/SortingUtils.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace at { namespace native { @@ -61,7 +62,7 @@ inline void _reduction_with_indices_allocate_or_resize_output( // unsqueeze to preserve passed in noncontiguous tensor in resize values.unsqueeze_(dim); } - values.resize_(result_sizes); + resize_output(values, result_sizes); } else { values = at::empty(result_sizes, self.options()); } @@ -75,7 +76,7 @@ inline void _reduction_with_indices_allocate_or_resize_output( // unsqueeze to preserve passed in noncontiguous tensor in resize indices.unsqueeze_(dim); } - indices.resize_(result_sizes); + resize_output(indices, result_sizes); } else { indices = at::empty(result_sizes, self.options().dtype(kLong)); } diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index aca797d3360cd9..af7d8771021848 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -50,7 +50,7 @@ Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) { // Convert NumPy compatible normalization mode string to enum values // NOTE: NumPy's normalization modes have direction-specific meanings. For example, // "forward" translates to `by_n` for a forward transform and `none` for backward. -fft_norm_mode norm_from_string(c10::optional norm, bool forward) { +fft_norm_mode norm_from_string(c10::optional norm, bool forward) { if (!norm || *norm == "backward") { return forward ? fft_norm_mode::none : fft_norm_mode::by_n; } @@ -97,7 +97,7 @@ Tensor resize_fft_input(Tensor x, IntArrayRef dims, IntArrayRef sizes) { // Complex to real FFT Tensor fft_c2r(c10::string_view function_name, Tensor out, Tensor input, c10::optional n_opt, - int64_t unwrapped_dim, c10::optional norm_str, + int64_t unwrapped_dim, c10::optional norm_str, bool forward) { TORCH_CHECK(!out.defined() || out.is_floating_point(), function_name, " expects a floating point output tensor, but got ", out.scalar_type()); @@ -124,7 +124,7 @@ Tensor fft_c2r(c10::string_view function_name, // Real to complex FFT Tensor fft_r2c(c10::string_view function_name, Tensor out, Tensor input, c10::optional n_opt, - int64_t unwrapped_dim, c10::optional norm_str, + int64_t unwrapped_dim, c10::optional norm_str, bool forward, bool onesided) { TORCH_CHECK(!input.is_complex(), function_name, " expects a real input tensor, but got ", input.scalar_type()); @@ -159,7 +159,7 @@ Tensor fft_r2c(c10::string_view function_name, // Complex to complex FFT Tensor fft_c2c(c10::string_view function_name, Tensor out, Tensor input, c10::optional n_opt, - int64_t unwrapped_dim, c10::optional norm_str, + int64_t unwrapped_dim, c10::optional norm_str, bool forward) { TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got ", input.scalar_type()); @@ -255,7 +255,7 @@ ShapeAndDims canonicalize_fft_shape_and_dim_args( Tensor fftn_c2c( c10::string_view function_name, Tensor out, const Tensor& input, IntArrayRef shape, - IntArrayRef dim, c10::optional norm_str, bool forward) { + IntArrayRef dim, c10::optional norm_str, bool forward) { TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got", input.scalar_type()); Tensor x = resize_fft_input(input, dim, shape); const auto norm = norm_from_string(norm_str, forward); @@ -271,14 +271,14 @@ Tensor fftn_c2c( // torch.fft.fft, analogous to NumPy's numpy.fft.fft Tensor fft_fft(const Tensor& self, c10::optional n, int64_t dim, - c10::optional norm) { + c10::optional norm) { return self.is_complex() ? fft_c2c("fft", {}, self, n, dim, norm, /*forward=*/true) : fft_r2c("fft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/false); } Tensor& fft_fft_out(const Tensor& self, c10::optional n, - int64_t dim, c10::optional norm, Tensor& out) { + int64_t dim, c10::optional norm, Tensor& out) { if (self.is_complex()) { fft_c2c("fft", out, self, n, dim, norm, /*forward=*/true); } else { @@ -288,14 +288,14 @@ Tensor& fft_fft_out(const Tensor& self, c10::optional n, } Tensor fft_ifft(const Tensor& self, c10::optional n, int64_t dim, - c10::optional norm) { + c10::optional norm) { return self.is_complex() ? fft_c2c("ifft", {}, self, n, dim, norm, /*forward=*/false) : fft_r2c("ifft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/false); } Tensor& fft_ifft_out(const Tensor& self, c10::optional n, - int64_t dim, c10::optional norm, Tensor& out) { + int64_t dim, c10::optional norm, Tensor& out) { if (self.is_complex()) { fft_c2c("ifft", out, self, n, dim, norm, /*forward=*/false); } else { @@ -305,52 +305,52 @@ Tensor& fft_ifft_out(const Tensor& self, c10::optional n, } Tensor fft_rfft(const Tensor& self, c10::optional n, int64_t dim, - c10::optional norm) { + c10::optional norm) { return fft_r2c("rfft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/true); } Tensor& fft_rfft_out(const Tensor& self, c10::optional n, - int64_t dim, c10::optional norm, Tensor& out) { + int64_t dim, c10::optional norm, Tensor& out) { fft_r2c("rfft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/true); return out; } Tensor fft_irfft(const Tensor& self, c10::optional n, int64_t dim, - c10::optional norm) { + c10::optional norm) { return fft_c2r("irfft", {}, self, n, dim, norm, /*forward=*/false); } Tensor& fft_irfft_out(const Tensor& self, c10::optional n, - int64_t dim, c10::optional norm, Tensor& out) { + int64_t dim, c10::optional norm, Tensor& out) { fft_c2r("irfft", out, self, n, dim, norm, /*forward=*/false); return out; } Tensor fft_hfft(const Tensor& self, c10::optional n, int64_t dim, - c10::optional norm) { + c10::optional norm) { return fft_c2r("hfft", {}, self, n, dim, norm, /*forward=*/true); } Tensor& fft_hfft_out(const Tensor& self, c10::optional n, - int64_t dim, c10::optional norm, Tensor& out) { + int64_t dim, c10::optional norm, Tensor& out) { fft_c2r("hfft", out, self, n, dim, norm, /*forward=*/true); return out; } Tensor fft_ihfft(const Tensor& self, c10::optional n, int64_t dim, - c10::optional norm) { + c10::optional norm) { return fft_r2c("ihfft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/true); } Tensor& fft_ihfft_out(const Tensor& self, c10::optional n, - int64_t dim, c10::optional norm, Tensor& out) { + int64_t dim, c10::optional norm, Tensor& out) { fft_r2c("ihfft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/true); return out; } Tensor fft_fftn(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm) { + c10::optional norm) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); // TODO: For real input, perform rfftn then mirror with conjugate symmetry Tensor input = promote_tensor_fft(self, /*require_complex=*/true); @@ -360,7 +360,7 @@ Tensor fft_fftn(const Tensor& self, c10::optional s, Tensor& fft_fftn_out(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm, Tensor& out) { + c10::optional norm, Tensor& out) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); // TODO: For real input, perform rfftn then mirror with conjugate symmetry Tensor input = promote_tensor_fft(self, /*require_complex=*/true); @@ -370,7 +370,7 @@ Tensor& fft_fftn_out(const Tensor& self, Tensor fft_ifftn(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm) { + c10::optional norm) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); Tensor input = promote_tensor_fft(self, /*require_complex=*/true); return fftn_c2c("ifftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/false); @@ -379,7 +379,7 @@ Tensor fft_ifftn(const Tensor& self, c10::optional s, Tensor& fft_ifftn_out(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm, Tensor& out) { + c10::optional norm, Tensor& out) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); Tensor input = promote_tensor_fft(self, /*require_complex=*/true); fftn_c2c("ifftn", out, input, desc.shape, desc.dim, norm, /*forward=*/false); @@ -389,7 +389,7 @@ Tensor& fft_ifftn_out(const Tensor& self, static Tensor fft_rfftn_impl(Tensor out, const Tensor& self, c10::optional s, c10::optional dim, - const c10::optional& norm_str) { + const c10::optional& norm_str) { TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type()); auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis"); @@ -406,14 +406,14 @@ static Tensor fft_rfftn_impl(Tensor out, const Tensor& self, Tensor fft_rfftn(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm_str) { + c10::optional norm_str) { return fft_rfftn_impl({}, self, s, dim, norm_str); } Tensor& fft_rfftn_out(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm_str, Tensor& out) { + c10::optional norm_str, Tensor& out) { fft_rfftn_impl(out, self, s, dim, norm_str); return out; } @@ -421,7 +421,7 @@ Tensor& fft_rfftn_out(const Tensor& self, static Tensor fft_irfftn_impl(Tensor out, const Tensor& self, c10::optional s, c10::optional dim, - const c10::optional& norm_str) { + const c10::optional& norm_str) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis"); @@ -449,55 +449,55 @@ static Tensor fft_irfftn_impl(Tensor out, const Tensor& self, Tensor fft_irfftn(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm_str) { + c10::optional norm_str) { return fft_irfftn_impl({}, self, s, dim, norm_str); } Tensor& fft_irfftn_out(const Tensor& self, c10::optional s, c10::optional dim, - c10::optional norm_str, Tensor& out) { + c10::optional norm_str, Tensor& out) { fft_irfftn_impl(out, self, s, dim, norm_str); return out; } Tensor fft_fft2(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm) { + IntArrayRef dim, c10::optional norm) { return native::fft_fftn(self, s, dim, std::move(norm)); } Tensor& fft_fft2_out(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm, Tensor& out) { + IntArrayRef dim, c10::optional norm, Tensor& out) { return native::fft_fftn_out(self, s, dim, std::move(norm), out); } Tensor fft_ifft2(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm) { + IntArrayRef dim, c10::optional norm) { return native::fft_ifftn(self, s, dim, std::move(norm)); } Tensor& fft_ifft2_out(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm, Tensor& out) { + IntArrayRef dim, c10::optional norm, Tensor& out) { return native::fft_ifftn_out(self, s, dim, std::move(norm), out); } Tensor fft_rfft2(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm) { + IntArrayRef dim, c10::optional norm) { return native::fft_rfftn(self, s, dim, std::move(norm)); } Tensor& fft_rfft2_out(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm, Tensor& out) { + IntArrayRef dim, c10::optional norm, Tensor& out) { return native::fft_rfftn_out(self, s, dim, std::move(norm), out); } Tensor fft_irfft2(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm) { + IntArrayRef dim, c10::optional norm) { return native::fft_irfftn(self, s, dim, std::move(norm)); } Tensor& fft_irfft2_out(const Tensor& self, c10::optional s, - IntArrayRef dim, c10::optional norm, Tensor& out) { + IntArrayRef dim, c10::optional norm, Tensor& out) { return native::fft_irfftn_out(self, s, dim, std::move(norm), out); } @@ -984,8 +984,8 @@ void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { // Use TensorIterator to coalesce batch dimensions // NOTE: Can't use TensorIterator loops because we need negative strides auto iter = TensorIteratorConfig() - .add_output(input) - .add_input(input) + .add_borrowed_output(input) + .add_borrowed_input(input) .resize_outputs(false) .declare_static_shape(input_sizes, dim) .build(); diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 04fb4b24769129..b1b1e83b4bf9fb 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -999,14 +999,6 @@ Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, cons if (sparse_grad) { return at::_gather_sparse_backward(self, dim, index, grad); } - if (globalContext().deterministicAlgorithms() && index.dim() == 1 && self.dim() == 1){ - TORCH_CHECK(index.numel() == grad.numel(), "index and grad should have same number of elements, " - "but got ", index.numel(), " versus ", grad.numel()); - torch::List> indices; - indices.reserve(1); - indices.push_back(index); - return at::zeros(self.sizes(), grad.options()).index_put_(indices, grad, true); - } return at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad); } @@ -1029,7 +1021,7 @@ Tensor & scatter_fill_(Tensor & self, int64_t dim, const Tensor & index, const S return self; } -SCATTER_GATHER_OP get_operator_enum(const std::string& reduce) { +SCATTER_GATHER_OP get_operator_enum(const c10::string_view reduce) { if (reduce == "add") { return SCATTER_GATHER_OP::REDUCE_ADD; } @@ -1043,7 +1035,7 @@ SCATTER_GATHER_OP get_operator_enum(const std::string& reduce) { } Tensor& scatter_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& index, - const Scalar& value, const std::string reduce) { + const Scalar& value, c10::string_view reduce) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), @@ -1056,7 +1048,7 @@ Tensor& scatter_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& in } Tensor & scatter_reduce_(Tensor & self, const int64_t dim, const Tensor & index, - const Tensor & src, const std::string reduce) { + const Tensor & src, c10::string_view reduce) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index"); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), @@ -1083,6 +1075,17 @@ Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Te at::assert_no_internal_overlap(self); at::assert_no_overlap(self, index); at::assert_no_overlap(self, src); + if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA && self.dim() == 1){ + TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, " + "but their dims are ", index.dim(), " and ", src.dim(), ", respectively"); + TORCH_CHECK(index.numel() == src.numel(), "index and src should have same number of elements for 1D tensors, " + "but got ", index.numel(), " versus ", src.numel()); + TORCH_CHECK(dim == 0, "dim should be zero for 1D self tensor, but got ", dim); + torch::List> indices; + indices.reserve(1); + indices.push_back(index); + return self.index_put_(indices, src, true); + } scatter_add_stub(self.device().type(), self, dim, index, src); return self; } diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index edc4ae22f5aeb1..0ad5ab70382024 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -134,8 +134,8 @@ Tensor& isposinf_out(const Tensor& self, Tensor& result) { } else { auto iter = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(result) - .add_input(self) + .add_borrowed_output(result) + .add_borrowed_input(self) .build(); isposinf_stub(iter.device_type(), iter); } @@ -158,8 +158,8 @@ Tensor& isneginf_out(const Tensor& self, Tensor& result) { } else { auto iter = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(result) - .add_input(self) + .add_borrowed_output(result) + .add_borrowed_input(self) .build(); isneginf_stub(iter.device_type(), iter); } @@ -286,10 +286,10 @@ Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other Tensor ret = at::empty(self.sizes(), self.options()); auto iter = at::TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(ret) - .add_input(condition) - .add_input(self) - .add_input(other) + .add_borrowed_output(ret) + .add_borrowed_input(condition) + .add_borrowed_input(self) + .add_borrowed_input(other) .build(); where_kernel(iter.device_type(), iter, condition.scalar_type()); return ret; @@ -503,10 +503,10 @@ Tensor& clamp_out(const Tensor& self, const c10::optional& min, "torch.clamp only supports strided layout, got: ", self.layout()); auto iter = TensorIteratorConfig() .set_check_mem_overlap(true) - .add_output(result) - .add_input(self) - .add_input(*min) - .add_input(*max) + .add_borrowed_output(result) + .add_borrowed_input(self) + .add_borrowed_input(*min) + .add_borrowed_input(*max) .promote_inputs_to_common_dtype(true) .cast_common_dtype_to_outputs(true) .enforce_safe_casting_to_output(true) @@ -549,7 +549,7 @@ Tensor& clamp_max_out(const Tensor& self, const Scalar& max, Tensor& result) { Tensor& clamp_max_out(const Tensor& self, const Tensor& max, Tensor& result) { TORCH_CHECK(self.layout() == Layout::Strided, "torch.clamp only supports strided layout, got: ", self.layout()); - auto iter = TensorIterator::binary_op(result, self, max); + auto iter = TensorIterator::borrowing_binary_op(result, self, max); clamp_max_stub(iter.device_type(), iter); return result; } @@ -581,7 +581,7 @@ Tensor& clamp_min_out(const Tensor& self, const Scalar& min, Tensor& result) { Tensor& clamp_min_out(const Tensor& self, const Tensor& min, Tensor& result) { TORCH_CHECK(self.layout() == Layout::Strided, "torch.clamp only supports strided layout, got: ", self.layout()); - auto iter = TensorIterator::binary_op(result, self, min); + auto iter = TensorIterator::borrowing_binary_op(result, self, min); clamp_min_stub(iter.device_type(), iter); return result; } diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index b66070a6776d73..57a0a2dfc840c9 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -133,9 +133,9 @@ void complex_check_dtype( Tensor& complex_out(const Tensor& real, const Tensor& imag, Tensor& result) { complex_check_dtype(result, real, imag); auto iter = TensorIteratorConfig() - .add_output(result) - .add_input(real) - .add_input(imag) + .add_borrowed_output(result) + .add_borrowed_input(real) + .add_borrowed_input(imag) .check_all_same_dtype(false) .build(); complex_stub(iter.device_type(), iter); @@ -153,9 +153,9 @@ Tensor complex(const Tensor& real, const Tensor& imag) { Tensor& polar_out(const Tensor& abs, const Tensor& angle, Tensor& result) { complex_check_dtype(result, abs, angle); auto iter = TensorIteratorConfig() - .add_output(result) - .add_input(abs) - .add_input(angle) + .add_borrowed_output(result) + .add_borrowed_input(abs) + .add_borrowed_input(angle) .check_all_same_dtype(false) .build(); polar_stub(iter.device_type(), iter); @@ -1383,7 +1383,7 @@ Tensor tensor_complex_backend(ArrayRef values, const TensorOptions& options) return at::detail::tensor_complex_backend(values, options); } -Tensor from_file(std::string filename, c10::optional shared, c10::optional size, +Tensor from_file(c10::string_view filename, c10::optional shared, c10::optional size, c10::optional dtype, c10::optional layout, c10::optional device, @@ -1400,7 +1400,7 @@ Tensor from_file(std::string filename, c10::optional shared, c10::optional c10::StorageImpl::use_byte_size_t(), size_bytes, THMapAllocator::makeDataPtr( - filename.c_str(), flags, size_bytes, nullptr), + std::string(filename), flags, size_bytes, nullptr), /*allocator=*/nullptr, /*resizable=*/false); auto tensor = detail::make_tensor( diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index abdadc73dca361..28ba89255e9c57 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -226,8 +226,8 @@ Tensor & _cat_out_cpu(TensorList tensors, int64_t dim, Tensor& result) { auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) // Already checked above .resize_outputs(false) - .add_output(result_slice) - .add_input(source_slice) + .add_borrowed_output(result_slice) + .add_borrowed_input(source_slice) .enforce_safe_casting_to_output(true) .build(); @@ -253,8 +253,8 @@ Tensor & _cat_out_cpu(TensorList tensors, int64_t dim, Tensor& result) { auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) // Already checked above .resize_outputs(false) - .add_output(result_slice) - .add_input(tensor) + .add_borrowed_output(result_slice) + .add_borrowed_input(tensor) .promote_inputs_to_common_dtype(true) .cast_common_dtype_to_outputs(true) .enforce_safe_casting_to_output(true) diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index d5b00e4f8cb414..cd9ba41ada7383 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -46,7 +46,7 @@ std::vector build_indices_loop(Tensor input, IntArrayRef flip_dims) { return indices; } -static TensorIterator make_index_iterator(const Tensor input, const std::vector indices) { +static TensorIterator make_index_iterator(const Tensor& input, const std::vector indices) { TensorIteratorConfig config; auto output_tensor = Tensor(); @@ -61,9 +61,9 @@ static TensorIterator make_index_iterator(const Tensor input, const std::vector< .check_all_same_dtype(false) .declare_static_dtype_and_device(input.scalar_type(), input.device()) .add_output(output_tensor) - .add_input(input); + .add_borrowed_input(input); for (auto& index : indices) { - config.add_input(index); + config.add_borrowed_input(index); } return config.build(); } diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index 53e5bc4ca1b60d..0b7a046637bf1b 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -46,7 +46,7 @@ Tensor _test_optional_floatlist( } // Test default strings can handle escape sequences properly (although commas are broken) -Tensor _test_string_default(const Tensor& dummy, std::string a, std::string b) { +Tensor _test_string_default(const Tensor& dummy, c10::string_view a, c10::string_view b) { const c10::string_view expect = "\"'\\"; TORCH_CHECK(a == expect, "Default A failed"); TORCH_CHECK(b == expect, "Default B failed"); @@ -64,7 +64,7 @@ Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) { } // Overload b -Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b) { +Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, c10::string_view b) { TORCH_CHECK(a == 2); TORCH_CHECK(b == "2"); return c10::scalar_to_tensor(2); diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index bf20acb8fcd47e..6a13af13b37f31 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -49,6 +49,7 @@ CREATE_UNARY_FLOAT_META_FUNC(erfinv) CREATE_UNARY_FLOAT_META_FUNC(exp) CREATE_UNARY_FLOAT_META_FUNC(exp2) CREATE_UNARY_FLOAT_META_FUNC(expm1) +CREATE_UNARY_FLOAT_META_FUNC(i0) CREATE_UNARY_FLOAT_META_FUNC(lgamma) CREATE_UNARY_FLOAT_META_FUNC(log) CREATE_UNARY_FLOAT_META_FUNC(log10) @@ -78,7 +79,6 @@ TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) { } CREATE_UNARY_META_FUNC(bitwise_not) CREATE_UNARY_META_FUNC(frac) -CREATE_UNARY_META_FUNC(i0) CREATE_UNARY_META_FUNC(round) CREATE_UNARY_META_FUNC(sgn) @@ -531,8 +531,8 @@ Tensor& logical_not_(Tensor& self) { Tensor& logical_not_out(const Tensor& self, Tensor& result) { TensorIterator iter = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(result) - .add_input(self) + .add_borrowed_output(result) + .add_borrowed_input(self) .build(); logical_not_stub(iter.device_type(), iter); return result; @@ -548,8 +548,8 @@ Tensor& signbit_out(const Tensor& self, Tensor& result) { } else { TensorIterator iter = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(result) - .add_input(self) + .add_borrowed_output(result) + .add_borrowed_input(self) .build(); signbit_stub(iter.device_type(), iter); } @@ -627,9 +627,9 @@ std::tuple frexp_out(const Tensor& self, "but got ", exponent.dtype()); auto iter = TensorIteratorConfig() - .add_output(mantissa) - .add_output(exponent) - .add_input(self) + .add_borrowed_output(mantissa) + .add_borrowed_output(exponent) + .add_borrowed_input(self) .check_all_same_dtype(false) .set_check_mem_overlap(true) .build(); diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 5f79519a610cee..e26768d5081986 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -226,7 +226,7 @@ void elu_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale }); } -void elu_backward_kernel(TensorIterator& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) { +void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) { AT_DISPATCH_FLOATING_TYPES(it.dtype(), "elu_backward_cpu", [&]() { using Vec = Vectorized; auto negcoef = alpha.to() * scale.to(); @@ -346,7 +346,7 @@ void hardsigmoid_kernel(TensorIteratorBase& iter) { }); } -void hardsigmoid_backward_kernel(TensorIterator& iter) { +void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardsigmoid_backward", [&] { const scalar_t zero(0.0f); const scalar_t three(3.0f); @@ -503,7 +503,7 @@ static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { }); } -static void leaky_relu_backward_kernel(TensorIterator& iter, const Scalar& negval_) { +static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&] { using Vec = Vectorized; auto zero_vec = Vec((scalar_t)(0)); @@ -542,7 +542,7 @@ void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar }); } -void softplus_backward_kernel(TensorIterator& iter, const Scalar& beta_, const Scalar& threshold_) { +void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softplus_backward_cpu", [&]() { using Vec = Vectorized; auto beta = beta_.to(); @@ -632,6 +632,40 @@ void silu_backward_kernel(TensorIterator& iter) { }); } +void mish_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() { + using Vec = Vectorized; + cpu_kernel_vec( + iter, + [](scalar_t x) -> scalar_t{ + return static_cast(x * std::tanh(std::log1p(std::exp(x)))); + }, + [](Vec x_vec) -> Vec { + return x_vec * x_vec.exp().log1p().tanh(); + }); + }); +} + +void mish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() { + using Vec = Vectorized; + const Vec kOneVec(scalar_t(1)); + cpu_kernel_vec( + iter, + [](scalar_t dy, scalar_t x) -> scalar_t { + const scalar_t sigmoid = + scalar_t(1) / (scalar_t(1) + std::exp(-x)); + const scalar_t tanh_softplus = std::tanh(std::log1p(std::exp(x))); + return dy * (tanh_softplus + x * sigmoid * (scalar_t(1) - tanh_softplus * tanh_softplus)); + }, + [kOneVec](Vec dy_vec, Vec x_vec) -> Vec { + const Vec sigmoid = kOneVec / (kOneVec + x_vec.neg().exp()); + const Vec tanh_softplus = x_vec.exp().log1p().tanh(); + return dy_vec * (tanh_softplus + x_vec * sigmoid * (kOneVec - tanh_softplus * tanh_softplus)); + }); + }); +} + } // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -680,6 +714,10 @@ REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel); REGISTER_DISPATCH(silu_stub, &silu_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_DISPATCH(mish_stub, &mish_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 3e2a359d4757b9..4199dee6911853 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -210,7 +210,7 @@ void div_floor_kernel(TensorIteratorBase& iter) { } } -void remainder_kernel(TensorIterator& iter) { +void remainder_kernel(TensorIteratorBase& iter) { if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h index fdbb67227913e3..f3ead0b9eff595 100644 --- a/aten/src/ATen/native/cpu/DistributionTemplates.h +++ b/aten/src/ATen/native/cpu/DistributionTemplates.h @@ -184,7 +184,7 @@ void normal_kernel(Tensor& self, double mean, double std, RNG generator) { if (size >= 16 && self.is_contiguous()) { normal_fill(self, static_cast(mean), static_cast(std), generator); } else { - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); std::lock_guard lock(generator->mutex_); cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { at::normal_distribution normal(mean, std); @@ -315,8 +315,8 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) { auto p_cpu = p_.to(kCPU); c10::MaybeOwned p = expand_inplace(self, p_cpu); auto iter = TensorIteratorConfig() - .add_output(self) - .add_input(*p) + .add_borrowed_output(self) + .add_borrowed_input(*p) .check_all_same_dtype(false) .build(); if (p_.scalar_type() == kDouble) { @@ -341,7 +341,7 @@ void bernoulli_kernel(Tensor& self, double p, RNG generator) { AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] { // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); cpu_serial_kernel(iter, [p, generator]() -> scalar_t { at::bernoulli_distribution bernoulli(p); return static_cast(bernoulli(generator)); diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp index 480dca26e15c65..02ef8150fe1bdc 100644 --- a/aten/src/ATen/native/cpu/LerpKernel.cpp +++ b/aten/src/ATen/native/cpu/LerpKernel.cpp @@ -15,7 +15,7 @@ static void lerp_kernel_scalar( const Tensor& end, const Scalar& weight) { TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); - auto iter = TensorIterator::binary_op(ret, self, end); + auto iter = TensorIterator::borrowing_binary_op(ret, self, end); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] { using value_t = typename c10::scalar_value_type::type; scalar_t weight_val = weight.to(); @@ -37,10 +37,10 @@ static void lerp_kernel_tensor( TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); TORCH_CHECK(self.dtype() == weights.dtype(), "expected dtype ", self.dtype(), " for `weights` but got dtype ", weights.dtype()); auto iter = TensorIteratorConfig() - .add_output(ret) - .add_input(self) - .add_input(end) - .add_input(weights) + .add_borrowed_output(ret) + .add_borrowed_input(self) + .add_borrowed_input(end) + .add_borrowed_input(weights) .build(); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(ret.scalar_type(), "lerp_kernel_tensor", [&] { using value_t = typename c10::scalar_value_type::type; diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index a642a361b0ad27..9e9700da018dbb 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -258,14 +258,9 @@ void cpu_kernel(TensorIteratorBase& iter, func_t&& op) { TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); iter.for_each([&](char** data, const int64_t* strides, int64_t n) { - if (is_contiguous(strides)) { + // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that + // iter.for_each is ever sending to the loop lambda basic_loop(data, strides, 0, n, std::forward(op)); - } else { - using Indices = std::make_index_sequence; - unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t _idx) { - basic_loop(data, strides, 0, n, std::forward(op)); - }); - } }); iter.cast_outputs(); } @@ -285,14 +280,7 @@ void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op) { TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); iter.for_each([&](char** data, const int64_t* strides, int64_t n) { - if (is_contiguous(strides)) { - multiple_outputs_loop(data, strides, 0, n, std::forward(op)); - } else { - using Indices = std::make_index_sequence; - unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t _idx) { - multiple_outputs_loop(data, strides, 0, n, std::forward(op)); - }); - } + multiple_outputs_loop(data, strides, 0, n, std::forward(op)); }); iter.cast_outputs(); } @@ -336,14 +324,7 @@ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) { - if (is_contiguous(strides)) { - basic_loop(data, strides, 0, n, std::forward(op)); - } else { - using Indices = std::make_index_sequence; - unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t _idx) { - basic_loop(data, strides, 0, n, std::forward(op)); - }); - } + basic_loop(data, strides, 0, n, std::forward(op)); }, range); iter.cast_outputs(); } diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp index 8f0a5d7705f2aa..66b92ecde1d087 100644 --- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp @@ -61,7 +61,7 @@ inline void reduce_all_impl( static void min_all_kernel_impl(Tensor& result, const Tensor& input) { if (input.scalar_type() == ScalarType::Bool) { TensorIterator iter = TensorIteratorConfig() - .add_input(input) + .add_borrowed_input(input) .build(); bool result_data = true; cpu_serial_kernel(iter, [&](const bool a) -> void { @@ -86,7 +86,7 @@ static void min_all_kernel_impl(Tensor& result, const Tensor& input) { static void max_all_kernel_impl(Tensor& result, const Tensor& input) { if (input.scalar_type() == ScalarType::Bool) { TensorIterator iter = TensorIteratorConfig() - .add_input(input) + .add_borrowed_input(input) .build(); bool result_data = false; cpu_serial_kernel(iter, [&](const bool a) -> void { @@ -167,7 +167,7 @@ static void _aminmax_all_kernel_impl(Tensor& min_result, Tensor& max_result, const Tensor& input) { if (input.scalar_type() == ScalarType::Bool) { TensorIterator iter = TensorIteratorConfig() - .add_input(input) + .add_borrowed_input(input) .build(); bool min_result_data = true; bool max_result_data = false; diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 2604c9f1b9aedd..0ea06caf8e7b66 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -43,8 +43,8 @@ static inline void cpu_cum_base_kernel(Tensor& result, .resize_outputs(false) // NOLINTNEXTLINE(bugprone-argument-comment) .declare_static_shape(self.sizes(), /*squash_dim=*/dim) - .add_output(result) - .add_input(self) + .add_borrowed_output(result) + .add_borrowed_input(self) .build(); auto result_dim_stride = ensure_nonempty_stride(result, dim); diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index 2be36547f78b08..3a4ec42394ac85 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -135,8 +135,8 @@ struct cpu_scatter_gather_base_kernel { .resize_outputs(false) // NOLINTNEXTLINE(bugprone-argument-comment) .declare_static_shape(index.sizes(), /*squash_dim=*/dim) - .add_output(self) - .add_input(index) + .add_borrowed_output(self) + .add_borrowed_input(index) .build(); auto self_dim_stride = ensure_nonempty_stride(self, dim); @@ -227,9 +227,9 @@ struct cpu_scatter_gather_base_kernel { .resize_outputs(false) // NOLINTNEXTLINE(bugprone-argument-comment) .declare_static_shape(index.sizes(), /*squash_dim=*/dim) - .add_output(self) - .add_input(src) - .add_input(index) + .add_borrowed_output(self) + .add_borrowed_input(src) + .add_borrowed_input(index) .build(); auto self_dim_stride = ensure_nonempty_stride(self, dim); diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index ffb52fe75cb3b6..c3f63ae60fdf7c 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -41,8 +41,8 @@ void _dim_apply( .resize_outputs(false) // NOLINTNEXTLINE(bugprone-argument-comment) .declare_static_shape(values.sizes(), /*squash_dim=*/dim) - .add_output(values) - .add_output(indices) + .add_borrowed_output(values) + .add_borrowed_output(indices) .build(); auto values_dim_stride = values.stride(dim); diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 48eacf6533184b..29005243a0ad13 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -46,9 +46,9 @@ static inline void compare_base_kernel_core( .check_all_same_dtype(false) .resize_outputs(false) .declare_static_shape(self.sizes(), /*squash_dims=*/dim) - .add_output(result1) - .add_output(result2) - .add_input(self) + .add_borrowed_output(result1) + .add_borrowed_output(result2) + .add_borrowed_input(self) .build(); iter.for_each(loop, /* grain_size */ 1); diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index 949437db0c85cf..0205cb4c6bd928 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -758,12 +758,12 @@ void upsample_generic_Nd_kernel_impl( TensorIteratorConfig config; config.check_all_same_dtype(false) .declare_static_dtype_and_device(input.scalar_type(), input.device()) - .add_output(output) - .add_input(restrided_input); + .add_borrowed_output(output) + .add_borrowed_input(restrided_input); for (auto & idx_weight: indices_weights) { for (auto& tensor : idx_weight) { - config.add_input(tensor); + config.add_borrowed_input(tensor); } } diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu index 84be955baeb5bc..4ecf7fe00d7127 100644 --- a/aten/src/ATen/native/cuda/Activation.cu +++ b/aten/src/ATen/native/cuda/Activation.cu @@ -293,7 +293,7 @@ void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar }); } -void softplus_backward_kernel(TensorIterator& iter, const Scalar& beta_, const Scalar& threshold_) { +void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_backward_cuda", [&]() { auto beta = beta_.to(); auto threshold = threshold_.to(); @@ -328,7 +328,7 @@ void elu_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& sca }); } -void elu_backward_kernel(TensorIterator& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) { +void elu_backward_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() { auto negcoef = alpha.to() * scale.to(); auto poscoef = scale.to(); @@ -380,7 +380,7 @@ void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { }); } -void leaky_relu_backward_kernel(TensorIterator& iter, const Scalar& negval_) { +void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_backward_cuda", [&]() { auto negval = negval_.to(); gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { @@ -440,7 +440,7 @@ void hardsigmoid_kernel(TensorIteratorBase& iter) { }); } -void hardsigmoid_backward_kernel(TensorIterator& iter) { +void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_backward_cuda", [&]() { using T_ACC = acc_type; const T_ACC zero(0.0f); @@ -496,6 +496,45 @@ void silu_backward_kernel(TensorIterator& iter) { }); } +void mish_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_cuda", + [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using T_ACC = acc_type; + const T_ACC x_acc = static_cast(x); + return x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); + }); + }); +} + +void mish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_backward_cuda", + [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using T_ACC = acc_type; + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + const T_ACC s_acc = + T_ACC(1) / (T_ACC(1) + c10::cuda::compat::exp(-x_acc)); + const T_ACC t_acc = + c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); + return dy_acc * (t_acc + x_acc * s_acc * (T_ACC(1) - t_acc * t_acc)); + }); + }); +} + } // namespace Tensor gelu_cuda(const Tensor& self) { @@ -540,6 +579,8 @@ REGISTER_DISPATCH(softplus_stub, &softplus_kernel); REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); REGISTER_DISPATCH(silu_stub, &silu_kernel); REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); +REGISTER_DISPATCH(mish_stub, &mish_kernel); +REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda); } // namespace native diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu index 908a8566bd7d62..a5d8a643648e74 100644 --- a/aten/src/ATen/native/cuda/AmpKernels.cu +++ b/aten/src/ATen/native/cuda/AmpKernels.cu @@ -113,6 +113,7 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads, // - all scaled_grads are strided // - all scaled_grads are non overlapping and dense // - all scaled_grads are on the same device + // - all scaled_grads are of the same dtype TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors."); // Sets up MTA launch to use scaled_grads as-is. tensor_lists.emplace_back(scaled_grads.vec()); @@ -126,12 +127,13 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads, tensor_lists.resize(1); tensor_lists[0].reserve(scaled_grads.size()); auto expected_device = scaled_grads[0].device(); + const auto expected_dtype = scaled_grads[0].scalar_type(); for (const Tensor& t : scaled_grads) { // Ensures GradScaler filtered scaled_grads by device. TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor."); TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); - if (!t.is_non_overlapping_and_dense()) { + if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) { // t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel. _amp_non_finite_check_and_unscale_cuda_(const_cast(t), found_inf, diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 0170c25fc94d7d..6a585907f7df0e 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -2188,7 +2188,7 @@ AT_ERROR("qr: MAGMA library not found in " #endif } -std::tuple linalg_qr_helper_magma(const Tensor& self, std::string mode) { +std::tuple linalg_qr_helper_magma(const Tensor& self, c10::string_view mode) { bool compute_q, reduced; std::tie(compute_q, reduced) = _parse_qr_mode(mode); @@ -2238,7 +2238,7 @@ std::tuple linalg_qr_helper_magma(const Tensor& self, std::strin return std::make_tuple(q_working_copy, r_working_copy); } -std::tuple _linalg_qr_helper_cuda(const Tensor& input, std::string mode) { +std::tuple _linalg_qr_helper_cuda(const Tensor& input, c10::string_view mode) { #if defined(USE_CUSOLVER) // _linalg_qr_helper_default is a generic function that is implemented using // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined diff --git a/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu b/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu index 9b5cc2ce6ad1df..90fb4c53de61c4 100644 --- a/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu @@ -11,7 +11,7 @@ namespace at { namespace native { -void remainder_kernel_cuda(TensorIterator& iter) { +void remainder_kernel_cuda(TensorIteratorBase& iter) { if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cuda", [&]() { gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 8134fb87a4d276..b0fe0ac7a05b6f 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -260,14 +260,8 @@ TORCH_IMPL_FUNC(addmm_out_cuda)(const Tensor& self, const Tensor& mat1, const Te addmm_out_cuda_impl(const_cast(result), self, mat1, mat2, beta, alpha); } -Tensor& mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) { - result.resize_({ self.size(0), mat2.size(1) }); - return addmm_out_cuda_impl(result, result, self, mat2, 0, 1); -} - -Tensor mm_cuda(const Tensor& self, const Tensor& mat2) { - Tensor result = at::empty({ self.size(0), mat2.size(1) }, self.options()); - return addmm_out_cuda_impl(result, result, self, mat2, 0, 1); +TORCH_IMPL_FUNC(mm_out_cuda)(const Tensor& self, const Tensor& mat2, const Tensor& result) { + addmm_out_cuda_impl(const_cast(result), result, self, mat2, 0, 1); } Tensor& baddbmm_out_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor &result) { diff --git a/aten/src/ATen/native/cuda/DistributionBernoulli.cu b/aten/src/ATen/native/cuda/DistributionBernoulli.cu index 1acd09f18a1042..3acf87c3c4b406 100644 --- a/aten/src/ATen/native/cuda/DistributionBernoulli.cu +++ b/aten/src/ATen/native/cuda/DistributionBernoulli.cu @@ -35,7 +35,7 @@ void bernoulli_tensor_kernel(Tensor& self, const Tensor& p_, c10::optional gen) { - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); auto generator = get_generator_or_default(gen, cuda::detail::getDefaultCUDAGenerator()); at::native::templates::cuda::bernoulli_kernel(iter, p, generator); } diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h index 953a834887a491..481b5589684484 100644 --- a/aten/src/ATen/native/cuda/DistributionTemplates.h +++ b/aten/src/ATen/native/cuda/DistributionTemplates.h @@ -430,7 +430,7 @@ void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transfo template void normal_kernel(Tensor& self, double mean_, double std_, RNG gen) { - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIterator::borrowing_nullary_op(self); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] { using accscalar_t = at::acc_type; auto mean = static_cast(mean_); diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index 7c56368b08a2cb..51fe5c496bd91a 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -1,3 +1,4 @@ +#pragma once #include #include diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu index 0de012e8ecca66..e89a6e1ba0760d 100644 --- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu +++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu @@ -105,7 +105,7 @@ std::vector foreach_pointwise_op(TensorList input, TensorList tensors1, std::vector foreach_tensor_##NAME##_scalar_cuda(TensorList input, TensorList tensors1, TensorList tensors2, const Scalar& scalar) { \ check_foreach_api_restrictions(input, tensors1, tensors2); \ \ - if (!can_use_fast_route({input, tensors1, tensors2}, scalar)) { \ + if (!can_use_fast_route({input, tensors1, tensors2}, scalar) || has_integral_tensor(input, /* includeBool */ true)) { \ return at::native::foreach_tensor_##NAME##_scalar_slow(input, tensors1, tensors2, scalar); \ } \ \ @@ -115,7 +115,7 @@ std::vector foreach_tensor_##NAME##_scalar_cuda(TensorList input, Tensor void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, const Scalar& scalar) { \ check_foreach_api_restrictions(input, tensors1, tensors2); \ \ - if (!can_use_fast_route({input, tensors1, tensors2}, scalar)) { \ + if (!can_use_fast_route({input, tensors1, tensors2}, scalar) || has_integral_tensor(input, /* includeBool */ true)) { \ return at::native::foreach_tensor_##NAME##_scalar_slow_(input, tensors1, tensors2, scalar); \ } \ \ @@ -127,7 +127,7 @@ void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1, std::vector foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { \ check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ \ - if (!can_use_fast_route({input, tensors1, tensors2}, scalars)) { \ + if (!can_use_fast_route({input, tensors1, tensors2}, scalars) || has_integral_tensor(input, /* includeBool */ true)) { \ return at::native::foreach_tensor_##NAME##_scalarlist_slow(input, tensors1, tensors2, scalars); \ } \ \ @@ -137,7 +137,7 @@ std::vector foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, Te void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { \ check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ \ - if (!can_use_fast_route({input, tensors1, tensors2}, scalars)) { \ + if (!can_use_fast_route({input, tensors1, tensors2}, scalars) || has_integral_tensor(input, /* includeBool */ true)) { \ return at::native::foreach_tensor_##NAME##_scalarlist_slow_(input, tensors1, tensors2, scalars); \ } \ \ @@ -149,10 +149,14 @@ FOREACH_POINTWISE_OP_SCALAR(addcdiv, std::divides); FOREACH_POINTWISE_OP_SCALARLIST(addcmul, std::multiplies); FOREACH_POINTWISE_OP_SCALARLIST(addcdiv, std::divides); + +// Why bool tensors are pushed to slowpath? +// Because `AT_DISPATCH_ALL_TYPES_AND` is used below. +// TODO(mkozuki): Check whether it's possible to handle bool tensors in fastpath. #define FOREACH_MAXIMUM_MINIMUM_OP(NAME, OP) \ std::vector foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList tensors2) { \ check_foreach_api_restrictions(tensors1, tensors2); \ - if (!can_use_fast_route({tensors1, tensors2})) { \ + if (!can_use_fast_route({tensors1, tensors2}) || has_bool_tensor(tensors1)) { \ return at::native::foreach_tensor_##NAME##_slow(tensors1, tensors2); \ } \ \ diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index a037d2c0c47d2f..4c331c61512861 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -133,14 +133,14 @@ struct functor_name { \ #define OP_CUSTOM_FUNCTOR(function, op_name, functor_name) \ std::vector foreach_tensor_##op_name##_cuda(TensorList tensors) { \ check_foreach_api_restrictions(tensors); \ - if (!can_use_fast_route(tensors)) { \ + if (!can_use_fast_route(tensors) || has_integral_tensor(tensors, /* includeBool */ true)) { \ return at::native::foreach_tensor_##op_name##_slow(tensors); \ } \ return function(tensors); \ } \ void foreach_tensor_##op_name##_cuda_(TensorList tensors) { \ check_foreach_api_restrictions(tensors); \ - if (!can_use_fast_route(tensors)) { \ + if (!can_use_fast_route(tensors) || has_integral_tensor(tensors, /* includeBool */ true)) { \ return at::native::foreach_tensor_##op_name##_slow_(tensors); \ } \ \ @@ -155,9 +155,9 @@ OP_CUSTOM_FUNCTOR(function, op_name, functor_name); \ OP(floating_half_bfloat16, erfc, Erfc); OP(floating_half_bfloat16, expm1, Expm1); OP(floating_half, lgamma, Lgamma); -OP(floating_half, trunc, Truncf); -OP(floating_half, floor, Floor); -OP(floating_half, ceil, Ceil); +OP(floating_half_bfloat16, trunc, Truncf); +OP(floating_half_bfloat16, floor, Floor); +OP(floating_half_bfloat16, ceil, Ceil); OP(floating_complex_half_bfloat16, acos, Acos); OP(floating_complex_half_bfloat16, asin, Asin); @@ -205,8 +205,8 @@ struct Reciprocal { }; OP_CUSTOM_FUNCTOR(floating_half_bfloat16, sigmoid, Sigmoid) -OP_CUSTOM_FUNCTOR(floating_half, round, Round) -OP_CUSTOM_FUNCTOR(floating_half, frac, Trunc) +OP_CUSTOM_FUNCTOR(floating_half_bfloat16, round, Round) +OP_CUSTOM_FUNCTOR(floating_half_bfloat16, frac, Trunc) OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, reciprocal, Reciprocal) std::vector foreach_tensor_neg_cuda(TensorList tensors) { @@ -247,13 +247,9 @@ struct Abs { std::vector foreach_tensor_abs_cuda(TensorList tensors) { check_foreach_api_restrictions(tensors); - bool has_complex = false; - for (auto t : tensors) { - if (at::isComplexType(t.scalar_type())) { - has_complex = true; - } - } - + const bool has_complex = std::any_of( + tensors.begin(), tensors.end(), + [](const auto & t) { return at::isComplexType(t.scalar_type()); }); if (!can_use_fast_route(tensors) || has_complex) { return at::native::foreach_tensor_abs_slow(tensors); } @@ -263,13 +259,9 @@ std::vector foreach_tensor_abs_cuda(TensorList tensors) { void foreach_tensor_abs_cuda_(TensorList tensors) { check_foreach_api_restrictions(tensors); - bool has_complex = false; - for (auto t : tensors) { - if (at::isComplexType(t.scalar_type())) { - has_complex = true; - } - } - + const bool has_complex = std::any_of( + tensors.begin(), tensors.end(), + [](const auto & t) { return at::isComplexType(t.scalar_type()); }); if (!can_use_fast_route(tensors) || has_complex) { return at::native::foreach_tensor_abs_slow_(tensors); } diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 8343fdf2035a8d..91f4096ee6cb26 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -373,6 +373,10 @@ template void masked_scatter_cuda_impl(Tensor& self, const Tensor& mask, const Tensor& source){ auto srcSize = source.numel(); + if (self.numel() == 0) { + return; + } + auto mask_cont = mask.contiguous(); // Use a prefix sum to determine the output locations of the masked elements @@ -436,11 +440,6 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& so " and ", source.scalar_type()); - TensorArg self_arg{self, "self", 1}; - TensorArg mask_arg{mask, "mask", 2}; - TensorArg source_arg{source, "source", 3}; - checkAllSameGPU(__func__, {self_arg, mask_arg, source_arg}); - c10::MaybeOwned b_mask = expand_inplace(self, mask, "masked_scatter_"); if (b_mask->dtype() == ScalarType::Byte) { diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index ab54c11d212a23..cb89e14faa0e2e 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -1,3 +1,4 @@ +#pragma once #include #include #include diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index df3d2fe0d5622c..e6dc0f98b266ee 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -166,14 +166,14 @@ Tensor batch_norm_elementwise_backward_train( Tensor grad_input = at::empty(input.sizes(), grad_out.options()); auto iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(grad_out) - .add_input(input) - .add_input(weight_nd) - .add_input(mean_nd) - .add_input(invstd_nd) - .add_input(sum_dy_xmu_nd) - .add_input(sum_dy_nd) + .add_borrowed_output(grad_input) + .add_borrowed_input(grad_out) + .add_borrowed_input(input) + .add_borrowed_input(weight_nd) + .add_borrowed_input(mean_nd) + .add_borrowed_input(invstd_nd) + .add_borrowed_input(sum_dy_xmu_nd) + .add_borrowed_input(sum_dy_nd) .check_all_same_dtype(false) .promote_inputs_to_common_dtype(false) .build(); @@ -211,10 +211,10 @@ Tensor batch_norm_elementwise_backward_eval( strides[1] = weight.strides()[0]; auto weight_nd = weight.as_strided(shape, strides); auto iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(grad_out) - .add_input(invstd_nd) - .add_input(weight_nd) + .add_borrowed_output(grad_input) + .add_borrowed_input(grad_out) + .add_borrowed_input(invstd_nd) + .add_borrowed_input(weight_nd) .check_all_same_dtype(false) .promote_inputs_to_common_dtype(false) .build(); @@ -229,9 +229,9 @@ Tensor batch_norm_elementwise_backward_eval( }); } else { auto iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(grad_out) - .add_input(invstd_nd) + .add_borrowed_output(grad_input) + .add_borrowed_input(grad_out) + .add_borrowed_input(invstd_nd) .check_all_same_dtype(false) .promote_inputs_to_common_dtype(false) .build(); diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index 261e6aae69ec8e..43d6602ea8e2ae 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -7,7 +7,9 @@ __global__ static void compute_cuda_kernel( index_t* repeat_ptr, int64_t* cumsum_ptr, index_t* result_ptr, - int64_t size) { + int64_t size, + int64_t result_size) { + CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]); int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE; int warp_id = idx / C10_WARP_SIZE; @@ -15,6 +17,7 @@ __global__ static void compute_cuda_kernel( for (int64_t i = warp_id; i < size; i += stride) { int64_t end = cumsum_ptr[i]; index_t repeat = repeat_ptr[i]; + CUDA_KERNEL_ASSERT(repeat >= 0); int64_t start = end - repeat; for (int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) { result_ptr[j] = i; @@ -27,26 +30,29 @@ static void compute_cuda( index_t* repeat_ptr, int64_t* cumsum_ptr, index_t* result_ptr, - int64_t size) { + int64_t size, + int64_t result_size) { int64_t block = 512; int64_t warps_per_block = block / C10_WARP_SIZE; int64_t grid = std::min((size + warps_per_block - 1) / warps_per_block, 2048L); compute_cuda_kernel<<>>( - repeat_ptr, cumsum_ptr, result_ptr, size); + repeat_ptr, cumsum_ptr, result_ptr, size, result_size); C10_CUDA_KERNEL_LAUNCH_CHECK(); } namespace at { namespace native { -Tensor repeat_interleave_cuda(const Tensor& repeat) { +Tensor repeat_interleave_cuda( + const Tensor& repeat, + c10::optional output_size) { Tensor output; AT_DISPATCH_INDEX_TYPES( repeat.scalar_type(), "repeat_interleave_cuda", [&]() { - output = - repeat_interleave_common>(repeat); + output = repeat_interleave_common>( + repeat, output_size); }); return output; } diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu index 34cc78b67348ad..bbfb949038ba54 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cu +++ b/aten/src/ATen/native/cuda/ScanKernels.cu @@ -234,7 +234,7 @@ void cummax_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int TensorArg indices_arg{ indices, "indices", 2 }; TensorArg input_arg{ self, "input", 3 }; checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg}); - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "cummax_cuda", [&]() { scalar_t init = self.is_floating_point() ? (-1*std::numeric_limits::infinity()) : std::numeric_limits::lowest(); scan_dim_with_indices(self, values, indices, dim, init, std::greater_equal()); @@ -246,7 +246,7 @@ void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int TensorArg indices_arg{ indices, "indices", 2 }; TensorArg input_arg{ self, "input", 3 }; checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg}); - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "cummin_cuda", [&]() { scalar_t init = self.is_floating_point() ? std::numeric_limits::infinity() : std::numeric_limits::max(); scan_dim_with_indices(self, values, indices, dim, init, std::less_equal()); @@ -500,23 +500,25 @@ Tensor& _logcumsumexp_out_cuda(const Tensor& self, int64_t dim, Tensor& result) TensorArg input_arg{ self, "input", 2 }; checkAllSameGPU(__func__, {output_arg, input_arg}); - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, - self.scalar_type(), "logcumsumexp_cuda", [&]() { - using accscalar_t = acc_type; - scalar_t init = -std::numeric_limits::infinity(); - auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x, const scalar_t y) -> scalar_t { - scalar_t min = at::_isnan(y) ? y : std::min(x,y); //std::min returns first arg if one of the args is nan - scalar_t max = at::_isnan(y) ? y : std::max(x,y); //std::max returns first arg if one of the args is nan - if (min != max || ::isfinite(static_cast(min))) { - // nan will be propagated here - return ::log1p(std::exp(min - max)) + max; - } else { - // special case to correctly handle infinite inputs - return x; - } - }; - scan_dim(self, result, wrap_dim, init, log_add_exp); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + self.scalar_type(), "logcumsumexp_cuda", + [&]() { + using accscalar_t = acc_type; + scalar_t init = -std::numeric_limits::infinity(); + auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x, const scalar_t y) -> scalar_t { + scalar_t min = at::_isnan(y) ? y : std::min(x,y); //std::min returns first arg if one of the args is nan + scalar_t max = at::_isnan(y) ? y : std::max(x,y); //std::max returns first arg if one of the args is nan + if (min != max || ::isfinite(static_cast(min))) { + // nan will be propagated here + return ::log1p(std::exp(min - max)) + max; + } else { + // special case to correctly handle infinite inputs + return x; + } + }; + scan_dim(self, result, wrap_dim, init, log_add_exp); + }); return result; } @@ -543,8 +545,10 @@ Tensor& _cumsum_out_cuda(const Tensor& self, int64_t dim, Tensor& result) { } auto wrap_dim = maybe_wrap_dim(dim, self.dim()); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND( - at::ScalarType::Half, self.scalar_type(), "cumsum_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + ScalarType::Half, ScalarType::BFloat16, + self.scalar_type(), "cumsum_cuda", + [&]() { scalar_t init = 0; scan_dim( self, @@ -579,8 +583,8 @@ Tensor& _cumprod_out_cuda(const Tensor& self, int64_t dim, Tensor& result) { } auto wrap_dim = maybe_wrap_dim(dim, self.dim()); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND( - at::ScalarType::Half, self.scalar_type(), "cumprod_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "cumprod_cuda", [&]() { scalar_t init = 1; scan_dim( self, diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu index cf23671f6d5410..b19e0a98bddb58 100644 --- a/aten/src/ATen/native/cuda/Sort.cu +++ b/aten/src/ATen/native/cuda/Sort.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -267,12 +268,14 @@ std::tuple sort_out_stable_cuda(const Tensor & self, c10::opt } Tensor self_; + bool newself = false; if (is_non_overlapping_and_dense && self.stride(dim) == 1) { self_ = self; } else { auto new_strides_unsort = infer_dense_strides_dim_last(self, dim); self_ = at::empty_strided(self.sizes(), new_strides_unsort, self.options()); self_.copy_(self); + newself = true; } Tensor values_tmp, indices_tmp; @@ -290,11 +293,12 @@ std::tuple sort_out_stable_cuda(const Tensor & self, c10::opt "Unexpected dtype for values, expect ", self_.scalar_type(), ", got ", values.scalar_type()); values.resize_as_(self); } - if (values.strides() != self_.strides()) { + + if (values.strides() == self_.strides() && (newself || get_overlap_status(self, values) == MemOverlapStatus::NO)) { + values_ptr_ = values.data_ptr(); + } else { values_tmp = at::empty_strided(self_.sizes(), self_.strides(), self_.options()); values_ptr_ = values_tmp.data_ptr(); - } else { - values_ptr_ = values.data_ptr(); } if (!indices.defined()) { diff --git a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu index 06655d1bba0e3f..3f640d70c0d7d7 100644 --- a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu @@ -22,19 +22,25 @@ __host__ __device__ static inline std::complex ceil_wrapper(std::complex v } void ceil_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "ceil_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ceil_wrapper(a); - }); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + iter.dtype(), "ceil_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ceil_wrapper(a); + }); + }); } void frac_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "frac_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return a - ::trunc(a); - }); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + iter.dtype(), "frac_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return a - ::trunc(a); + }); + }); } // We manually overload floor because std::floor does not work with std::complex types. @@ -49,11 +55,14 @@ __host__ __device__ static inline std::complex floor_wrapper(std::complex } void floor_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "floor_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return floor_wrapper(a); - }); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + iter.dtype(), "floor_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return floor_wrapper(a); + }); + }); } template @@ -88,11 +97,14 @@ __host__ __device__ static inline c10::complex reciprocal_wrapper(c10::comple } void reciprocal_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "reciprocal_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return reciprocal_wrapper(a); - }); - }); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + iter.common_dtype(), "reciprocal_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return reciprocal_wrapper(a); + }); + }); } // We manually overload nearbyint because std::nearbyint does not work with std::complex types and ROCm. @@ -117,12 +129,15 @@ __host__ __device__ static inline c10::complex nearbyint_wrapper(c10::co #pragma pop void round_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "round_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - // We do not use std::round because we would like to round midway numbers to the nearest even integer. - return nearbyint_wrapper(a); - }); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + iter.dtype(), "round_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + // We do not use std::round because we would like to round midway numbers to the nearest even integer. + return nearbyint_wrapper(a); + }); + }); } // We manually overload trunc because std::trunc does not work with std::complex types and ROCm. @@ -144,11 +159,14 @@ __host__ __device__ static inline c10::complex trunc_wrapper(c10::comple } void trunc_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "trunc_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return trunc_wrapper(a); - }); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + iter.dtype(), "trunc_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return trunc_wrapper(a); + }); + }); } REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu index b6218b9f558173..85108c980c156b 100644 --- a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu @@ -31,7 +31,7 @@ void exp2_kernel_cuda(TensorIteratorBase& iter) { } void i0_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_i0(a); }); diff --git a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp index bbc1380dc0068f..bf84d583dbde56 100644 --- a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp @@ -100,6 +100,7 @@ class SparseCsrMKLInterface { retval); } + // res(nrows, dense_ncols) = (sparse(nrows * ncols) @ dense(ncols x dense_ncols)) inline void sparse_mm( float* res, float* dense, @@ -108,19 +109,32 @@ class SparseCsrMKLInterface { MKL_INT nrows, MKL_INT ncols, MKL_INT dense_ncols) { - int stat = mkl_sparse_s_mm( + int stat; + if (dense_ncols == 1) { + stat = mkl_sparse_s_mv( + SPARSE_OPERATION_NON_TRANSPOSE, + alpha, + A, + desc, + dense, + beta, + res); + TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat); + } else { + stat = mkl_sparse_s_mm( SPARSE_OPERATION_NON_TRANSPOSE, alpha, A, desc, SPARSE_LAYOUT_ROW_MAJOR, dense, - dense_ncols, - dense_ncols, + nrows, + ncols, beta, res, dense_ncols); - TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat); + TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat); + } } inline void sparse_mm( @@ -131,19 +145,33 @@ class SparseCsrMKLInterface { MKL_INT nrows, MKL_INT ncols, MKL_INT dense_ncols) { - int stat = mkl_sparse_d_mm( + int stat; + if (dense_ncols == 1) { + stat = mkl_sparse_d_mv( + SPARSE_OPERATION_NON_TRANSPOSE, + alpha, + A, + desc, + dense, + beta, + res); + TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat); + } + else { + stat = mkl_sparse_d_mm( SPARSE_OPERATION_NON_TRANSPOSE, alpha, A, desc, SPARSE_LAYOUT_ROW_MAJOR, dense, - dense_ncols, - dense_ncols, + nrows, + ncols, beta, res, dense_ncols); - TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat); + TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat); + } } ~SparseCsrMKLInterface() { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index bd54bc57c3dc2c..baa130a67b0838 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2156,6 +2156,7 @@ - func: _cufft_clear_plan_cache(int device_index) -> () - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + device_check: NoCheck # TensorIterator variants: function, method dispatch: CPU, CUDA: index @@ -2180,6 +2181,7 @@ variants: function, method - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) + device_check: NoCheck # delegate to _index_put_impl_, which leverages TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: index_put_ @@ -2190,6 +2192,7 @@ # - Tensor & Tensor::index_put_(std::initializer_list indices, Scalar v) - func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + device_check: NoCheck # delegate to _index_put_impl_ after clone, which leverages TensorIterator variants: function, method - func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) @@ -2877,15 +2880,15 @@ CUDA: miopen_rnn_backward - func: mm(Tensor self, Tensor mat2) -> Tensor + structured_delegate: mm.out variants: function, method dispatch: - CPU: mm_cpu - CUDA: mm_cuda SparseCPU, SparseCUDA, SparseCsrCPU: _sparse_mm - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + structured: True dispatch: - CPU: mm_cpu_out + CPU: mm_out_cpu CUDA: mm_out_cuda SparseCPU, SparseCUDA: _sparse_mm_out SparseCsrCPU: _sparse_csr_mm_out @@ -3331,16 +3334,16 @@ dispatch: CompositeExplicitAutograd: repeat -- func: repeat_interleave.Tensor(Tensor repeats) -> Tensor +- func: repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor variants: function dispatch: CPU: repeat_interleave_cpu CUDA: repeat_interleave_cuda -- func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None) -> Tensor +- func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, int? output_size=None) -> Tensor variants: function, method -- func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None) -> Tensor +- func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> Tensor variants: function, method - func: reshape(Tensor(a) self, int[] shape) -> Tensor(a) @@ -3523,6 +3526,31 @@ CPU, CUDA: silu_backward CompositeImplicitAutograd: math_silu_backward +- func: mish(Tensor self) -> Tensor + structured_delegate: mish.out + python_module: nn + dispatch: + CompositeExplicitAutograd: mish + +- func: mish_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: mish.out + python_module: nn + dispatch: + CompositeExplicitAutograd: mish_ + +- func: mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: mish_out + +- func: mish_backward(Tensor grad_output, Tensor self) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: mish_backward + CompositeImplicitAutograd: math_mish_backward + - func: sigmoid(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: sigmoid.out @@ -4109,7 +4137,7 @@ dispatch: MkldnnCPU: mkldnn_transpose_ -- func: one_hot(Tensor self, int num_classes=-1) -> Tensor +- func: one_hot(Tensor self, int num_classes=-1, ScalarType dtype=long) -> Tensor python_module: nn variants: function @@ -5701,18 +5729,6 @@ dispatch: CPU, CUDA: fmod_ -- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - device_check: NoCheck # TensorIterator - variants: method - dispatch: - CPU, CUDA: remainder_ - -- func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - device_check: NoCheck # TensorIterator - variants: method - dispatch: - CPU, CUDA: remainder_ - - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) variants: method dispatch: @@ -6711,26 +6727,35 @@ CompositeExplicitAutograd: nextafter_ - func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) - device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: remainder_out + CompositeExplicitAutograd: remainder_out - func: remainder.Scalar(Tensor self, Scalar other) -> Tensor - device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: remainder + CompositeExplicitAutograd: remainder + +- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + dispatch: + CompositeExplicitAutograd: remainder_ - func: remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: remainder_out - func: remainder.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator + structured_delegate: remainder.Tensor_out variants: method, function - dispatch: - CPU, CUDA: remainder + +- func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: remainder.Tensor_out + variants: method - func: min(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -8020,10 +8045,16 @@ device_check: NoCheck # TensorIterator python_module: nn -- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor +- func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: elu_backward + CPU, CUDA: elu_backward_out + +- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + structured_delegate: elu_backward.grad_input + python_module: nn - func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) structured_delegate: elu.out @@ -8076,10 +8107,16 @@ device_check: NoCheck # TensorIterator python_module: nn -- func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor +- func: hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: hardsigmoid_backward + CPU, CUDA: hardsigmoid_backward_out + +- func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor + structured_delegate: hardsigmoid_backward.grad_input + python_module: nn - func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8151,10 +8188,16 @@ dispatch: QuantizedCPU: leaky_relu_quantized_cpu -- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor +- func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: leaky_relu_backward + CPU, CUDA: leaky_relu_backward_out + +- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + structured_delegate: leaky_relu_backward.grad_input + python_module: nn - func: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) structured_delegate: leaky_relu.out @@ -8234,14 +8277,15 @@ python_module: nn - func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase python_module: nn dispatch: CPU, CUDA: softplus_backward_out - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor + structured_delegate: softplus_backward.grad_input python_module: nn - dispatch: - CPU, CUDA: softplus_backward - func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) structured: True diff --git a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp index 011e2e4777e37a..3abb1553179795 100644 --- a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp +++ b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp @@ -31,8 +31,8 @@ Tensor int_repr_quantized_cpu(const Tensor& self) { self.suggest_memory_format()); auto iter = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(dst) - .add_input(self) + .add_borrowed_output(dst) + .add_borrowed_input(self) .build(); cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; }); } diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 92be102346c76f..485cc0e36b4f24 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -1045,7 +1045,7 @@ void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) { auto self_scale_neg_zp_premul_vec = self_scale_vec * self_zero_point_vec.neg(); auto other_scale_zp_premul_vec = other_scale_vec * other_zero_point_vec.neg(); - auto iter = TensorIterator::binary_op(out, self, other); + auto iter = TensorIterator::borrowing_binary_op(out, self, other); AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qadd", [&]() { using Vec = Vectorized; @@ -1105,7 +1105,7 @@ void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) { float multiplier = self_scale * other_scale * inv_scale; - auto iter = TensorIterator::binary_op(out, self, other); + auto iter = TensorIterator::borrowing_binary_op(out, self, other); AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul", [&]() { using Vec = Vectorized; @@ -2119,9 +2119,9 @@ void fake_quantize_tensor_cachemask_kernel( auto iter_combined = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(output) - .add_output(mask) - .add_input(input) + .add_borrowed_output(output) + .add_borrowed_output(mask) + .add_borrowed_input(input) .build(); AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&] { @@ -2157,10 +2157,10 @@ void fake_quantize_learnable_tensor_grad_kernel_cpu( the order they are accessed follows the order they are built within the iterator. For example, if an iterator is built in the following order: auto iter = TensorIteratorConfig(). - .add_output(firstOutput) - .add_output(secondOutput) - .add_input(firstInput) - .add_input(secondInput) + .add_borrowed_output(firstOutput) + .add_borrowed_output(secondOutput) + .add_borrowed_input(firstInput) + .add_borrowed_input(secondInput) .build() data will contain 4 pointers to pointers to values in the following order: firstOutput, secondOutput, firstInput, secondInput. diff --git a/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp b/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp index 0464467989dd0c..2708d125088b81 100644 --- a/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp +++ b/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp @@ -83,8 +83,8 @@ std::tuple fake_quantize_per_channel_affine_cachemask( TensorIterator iter = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(Y) - .add_input(self) + .add_borrowed_output(Y) + .add_borrowed_input(self) .add_input(native::_unsafe_view(scale, expected_shape)) .add_input(native::_unsafe_view(zero_point, expected_shape)) .build(); @@ -93,8 +93,8 @@ std::tuple fake_quantize_per_channel_affine_cachemask( // for simplicity, as we do not expect this to be a bottleneck. TensorIterator iter_mask = TensorIteratorConfig() .check_all_same_dtype(false) - .add_output(mask) - .add_input(self) + .add_borrowed_output(mask) + .add_borrowed_input(self) .add_input(native::_unsafe_view(scale, expected_shape)) .add_input(native::_unsafe_view(zero_point, expected_shape)) .build(); @@ -227,13 +227,13 @@ std::tuple _fake_quantize_learnable_per_channel_affine_b auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask, numDimensions)).expand(X_shape); auto iter = TensorIteratorConfig() - .add_output(dX) - .add_output(dScale_vec) - .add_output(dZeroPoint_vec) - .add_input(X) - .add_input(dY) - .add_input(scale_vectorized) - .add_input(zero_point_vectorized) + .add_borrowed_output(dX) + .add_borrowed_output(dScale_vec) + .add_borrowed_output(dZeroPoint_vec) + .add_borrowed_input(X) + .add_borrowed_input(dY) + .add_borrowed_input(scale_vectorized) + .add_borrowed_input(zero_point_vectorized) .build(); fake_quant_grad_learnable_channel_stub( diff --git a/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp b/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp index c42336be9bcfc9..5b0f23a9d738f3 100644 --- a/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp +++ b/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp @@ -179,11 +179,11 @@ std::tuple _fake_quantize_learnable_per_tensor_affine_ba auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve); auto iter = TensorIteratorConfig() - .add_output(dX) - .add_output(dScale_vec) - .add_output(dZeroPoint_vec) - .add_input(X) - .add_input(dY) + .add_borrowed_output(dX) + .add_borrowed_output(dScale_vec) + .add_borrowed_output(dZeroPoint_vec) + .add_borrowed_input(X) + .add_borrowed_input(dY) .build(); fake_quant_grad_learnable_tensor_stub( diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp index 92a10115b27a10..750440fc7a84e9 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -20,165 +21,155 @@ using namespace at::sparse_csr; // certain utiliy functions are usable from sparse COO. using namespace at::sparse; -static constexpr bool is_msvc() { +static constexpr bool is_mkl_supported() { #ifdef _MSC_VER - return true; -#else return false; +#elif __APPLE__ || __MACH__ + return false; +#else + return true; #endif } +// Only accept squares sparse matrices or dense input as a vector +// TODO: Check what happens with MKL, the output error reported with non square matrices tends to be high +// See: https://github.com/pytorch/pytorch/issues/58770 +bool is_square_or_vec(int64_t dim_i, int64_t dim_j, int64_t dim_k) { + return (dim_i == dim_k && dim_k == dim_j) || (dim_i == dim_j && dim_k == 1); +} + +template +void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, Scalar beta, const Tensor& t, Scalar alpha, const Tensor& csr, const Tensor& col_indices, const Tensor& values, const Tensor& dense) { + + scalar_t cast_alpha = alpha.to(); + scalar_t cast_beta = beta.to(); + if (cast_beta == 0) { + r.zero_(); + } else if (cast_beta == 1) { + if (!is_same_tensor(r, t)) { + r.copy_(t); + } + } else { + at::mul_out(r, t, scalar_to_tensor(beta)); + } + AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() { + auto csr_accessor = csr.accessor(); + auto col_indices_accessor = col_indices.accessor(); + + auto values_accessor = values.accessor(); + scalar_t* dense_ptr = dense.data(); + scalar_t* r_ptr = r.data(); + + int64_t dense_stride0 = dense.stride(0); + int64_t dense_stride1 = dense.stride(1); + int64_t r_stride0 = r.stride(0); + int64_t r_stride1 = r.stride(1); + + at::parallel_for( + 0, + dim_i, + internal::GRAIN_SIZE, + [&](int64_t irow_start, int64_t irow_end) { + for (index_t h = irow_start; h < irow_end; ++h) { + index_t i_start = csr_accessor[h]; + index_t i_end = csr_accessor[h+1]; + for (index_t i = i_start; i < i_end; i++) { + scalar_t val = values_accessor[i]; + index_t col = col_indices_accessor[i]; + at::native::cpublas::axpy(dim_k, + cast_alpha * val, + dense_ptr + col * dense_stride0, dense_stride1, + r_ptr + h * r_stride0, r_stride1); + } + } + }); + }); +} + // Functions for matrix multiplication. Tensor& addmm_out_sparse_csr_dense_cpu( const Tensor& self, - const SparseCsrTensor& op1, - const Tensor& op2, + const SparseCsrTensor& sparse, + const Tensor& dense, const Scalar& beta, const Scalar& alpha, - Tensor& out) { - AT_ASSERT(op1.is_sparse_csr()); - Tensor expand_self = *expand_size(self, {op1.size(0), op2.size(1)}, "addmm_out_sparse_csr"); + Tensor& r) { + TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr()); + Tensor t = *expand_size(self, {sparse.size(0), dense.size(1)}, "addmm_out_sparse_csr"); - AT_ASSERT(expand_self.device().type() == kCPU); + TORCH_INTERNAL_ASSERT(t.device().type() == kCPU); TORCH_CHECK( - out.device().type() == kCPU, + r.device().type() == kCPU, "addmm: expected 'out' to be CPU tensor, but got CUDA tensor"); TORCH_CHECK( - op1.device().type() == kCPU, + sparse.device().type() == kCPU, "addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor"); TORCH_CHECK( - op2.device().type() == kCPU, + dense.device().type() == kCPU, "addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor"); TORCH_CHECK( - op1.dim() == 2, + sparse.dim() == 2, "addmm: 2-D matrices expected, got ", - op1.dim(), + sparse.dim(), "D tensor"); TORCH_CHECK( - op2.dim() == 2, + dense.dim() == 2, "addmm: 2-D matrices expected, got ", - op2.dim(), + dense.dim(), "D tensor"); TORCH_CHECK( - out.is_contiguous(), + r.is_contiguous(), "out argument must be contiguous, but got: ", - out.suggest_memory_format()); + r.suggest_memory_format()); - // ixk * kxj = ixj - int64_t dim_i = op1.size(0); - int64_t dim_j = op2.size(1); - int64_t dim_k = op1.size(1); + // ixj * jxk = ixk + int64_t dim_i = sparse.size(0); + int64_t dim_j = sparse.size(1); + int64_t dim_k = dense.size(1); TORCH_CHECK( - op2.size(0) == dim_k, + dense.size(0) == dim_j, "addmm: Expected dense matrix (op2) size(0)=", - dim_k, + dim_j, ", got ", - op2.size(0)); + dense.size(0)); TORCH_CHECK( - op1.size(1) == dim_k, + sparse.size(1) == dim_j, "addmm: Expected sparse matrix (op1) size(1)=", - dim_k, + dim_j, ", got ", - op1.size(1)); - out.resize_({dim_i, dim_j}); + sparse.size(1)); + resize_output(r, {dim_i, dim_k}); + auto col_indices = sparse.col_indices(); + auto crow_indices = sparse.crow_indices(); + auto values = sparse.values(); + int64_t nnz = sparse._nnz(); - auto col_indices = op1.col_indices(); - auto crow_indices = op1.crow_indices(); - auto values = op1.values(); - - AT_DISPATCH_FLOATING_TYPES( - values.scalar_type(), "addmm_sparse_csr_dense", [&] { + // Do not use MKL for Windows due to linking issues with sparse MKL routines. + if (at::hasMKL() && is_mkl_supported() && is_square_or_vec(dim_i, dim_j, dim_k)) { + AT_DISPATCH_FLOATING_TYPES(values.type(), "addmm_sparse_dense", [&] { scalar_t cast_beta = beta.to(); - if (!is_same_tensor(out, expand_self)) { - out.copy_(expand_self); - } if (cast_beta == 0) { - out.zero_(); + r.zero_(); + } else if (cast_beta == 1) { + if (!is_same_tensor(r, t)) { + r.copy_(t); + } } else { - at::mul_out(out, expand_self, scalar_to_tensor(beta)); + at::mul_out(r, t, scalar_to_tensor(beta)); } - }); - - // Do not use MKL for Windows due to linking issues with sparse MKL routines. - if (at::hasMKL() && !is_msvc()) { - _sparse_mm_mkl_(out, op1, op2, expand_self, alpha, beta); + // r = r + alpha * sparse * dense + _sparse_mm_mkl_(r, sparse, dense, t, alpha, Scalar(static_cast(1.0))); + }); } else { - int64_t dense_stride0 = op1.stride(0); - int64_t dense_stride1 = op1.stride(1); - int64_t out_stride0 = out.stride(0); - int64_t out_stride1 = out.stride(1); - - AT_DISPATCH_FLOATING_TYPES( - values.scalar_type(), - "sparse_csr_mm_cpu", - [&alpha, - &beta, - &op1, - &out, - &values, - &crow_indices, - &col_indices, - &dense_stride0, - &dense_stride1, - &out_stride0, - &out_stride1, - &dim_k]() { - AT_DISPATCH_INDEX_TYPES( - crow_indices.scalar_type(), - "csr_mm_crow_indices", - [&alpha, - &beta, - &op1, - &out, - &values, - &crow_indices, - &col_indices, - &dense_stride0, - &dense_stride1, - &out_stride0, - &out_stride1, - &dim_k]() { - scalar_t cast_alpha = alpha.to(); - // NOLINTNEXTLINE(clang-diagnostic-unused-variable) - scalar_t cast_beta = beta.to(); - scalar_t* dense_ptr = op1.data_ptr(); - scalar_t* out_ptr = out.data_ptr(); - - auto col_indices_accessor = col_indices.accessor(); - auto crow_indices_accessor = - crow_indices.accessor(); - auto values_accessor = values.accessor(); - - at::parallel_for( - 0, - crow_indices.size(0) - 1, - internal::GRAIN_SIZE, - [&](int64_t irow_start, int64_t irow_end) { - for (int irow = irow_start; irow < irow_end; ++irow) { - int start_index = crow_indices_accessor[irow]; - int end_index = crow_indices_accessor[irow + 1]; - - for (int i = start_index; i < end_index; ++i) { - auto val = values_accessor[i]; - auto icol = col_indices_accessor[i]; - - at::native::cpublas::axpy( - dim_k, - cast_alpha * val, - dense_ptr + icol * dense_stride0, - dense_stride1, - out_ptr + irow * out_stride0, - out_stride1); - } - } - }); - }); - }); + // r = beta * t + alpha * sparse * dense + AT_DISPATCH_FLOATING_TYPES(values.type(), "addmm_sparse_dense", [&] { + s_addmm_out_sparse_dense_worker(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, crow_indices, col_indices, values, dense); + }); } - return out; + return r; } Tensor addmm_sparse_csr_dense_cpu( @@ -229,9 +220,9 @@ Tensor& add_out_dense_sparse_csr_cpu( const Tensor& dense, const SparseCsrTensor& src, const Scalar& alpha) { - AT_ASSERT(dense.layout() == kStrided); - AT_ASSERT(src.is_sparse_csr()); - AT_ASSERT(dense.device() == kCPU); + TORCH_INTERNAL_ASSERT(dense.layout() == kStrided); + TORCH_INTERNAL_ASSERT(src.is_sparse_csr()); + TORCH_INTERNAL_ASSERT(dense.device() == kCPU); TORCH_CHECK( out.is_contiguous(), @@ -263,11 +254,12 @@ Tensor& add_out_dense_sparse_csr_cpu( out.scalar_type(), " in add operation"); - auto src_values = src.values().to(commonDtype); + auto src_values = src.values(); auto src_crow_indices = src.crow_indices(); auto src_col_indices = src.col_indices(); - out.resize_as_(dense); + resize_output(out, dense.sizes()); + Tensor resultBuffer = out; Tensor valuesBuffer = src_values.to(commonDtype); @@ -280,21 +272,21 @@ Tensor& add_out_dense_sparse_csr_cpu( AT_DISPATCH_ALL_TYPES( commonDtype, "add_out_op2_sparse_csr", - [&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() { + [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() { AT_DISPATCH_INDEX_TYPES( src_crow_indices.scalar_type(), "csr_add_out_crow_indices", - [&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() { - auto values_accessor = src_values.accessor(); - scalar_t* out_ptr = out.data_ptr(); + [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() { + auto values_accessor = valuesBuffer.accessor(); + scalar_t* out_ptr = resultBuffer.data_ptr(); scalar_t cast_value = alpha.to(); auto crow_indices_accessor = src_crow_indices.accessor(); auto col_indices_accessor = src_col_indices.accessor(); - auto out_strides0 = out.strides()[0]; - auto out_strides1 = out.strides()[1]; + auto out_strides0 = resultBuffer.strides()[0]; + auto out_strides1 = resultBuffer.strides()[1]; for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1; ++irow) { @@ -303,13 +295,16 @@ Tensor& add_out_dense_sparse_csr_cpu( for (int i = start_index; i < end_index; ++i) { auto icol = col_indices_accessor[i]; - auto index = out.storage_offset() + irow * out_strides0 + + auto index = resultBuffer.storage_offset() + irow * out_strides0 + icol * out_strides1; out_ptr[index] += cast_value * values_accessor[i]; } } }); }); + if (out.scalar_type() != commonDtype) { + out.copy_(resultBuffer); + } return out; } diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 9344f5a755c497..52899d1ca2ff59 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -255,7 +255,7 @@ static SparseTensor& coalesce_(SparseTensor& tensor) { // values=[1., 1.] (after truncation), which sum to 2.f instead of 3.f. // To perform floor division the sparse tensor must be coalesced first. -SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, c10::optional rounding_mode, SparseTensor& r) { +SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, c10::optional rounding_mode, SparseTensor& r) { TORCH_CHECK(value.dim() == 0, "Sparse division requires a scalar or ", "zero-dim dense tensor divisor (got shape ", value.sizes(), " for divisor)"); TORCH_CHECK(!value.is_sparse(), "Sparse division requires a scalar or ", @@ -309,7 +309,7 @@ SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseT return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); } -Tensor div_sparse(const Tensor& self, const Tensor& value, c10::optional rounding_mode) { +Tensor div_sparse(const Tensor& self, const Tensor& value, c10::optional rounding_mode) { auto commonDtype = at::result_type(self, value); if (c10::isIntegralType(commonDtype, /*include_bool=*/true) && !rounding_mode.has_value()) { commonDtype = typeMetaToScalarType(at::get_default_dtype()); @@ -318,11 +318,11 @@ Tensor div_sparse(const Tensor& self, const Tensor& value, c10::optional rounding_mode) { +Tensor& div_sparse_(Tensor& self, const Tensor& value, c10::optional rounding_mode) { return div_out_sparse_zerodim(self, value, std::move(rounding_mode), self); } -SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, c10::optional rounding_mode, SparseTensor& r) { +SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, c10::optional rounding_mode, SparseTensor& r) { return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), std::move(rounding_mode), r); } diff --git a/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl index 88373605d010ac..615e9ebb28750b 100644 --- a/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl @@ -1,41 +1,46 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 inputSize; - ivec4 outputSize; - ivec2 kernelSize; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; ivec2 stride; ivec2 padding; ivec2 dilate; -} -uConstBlock; - -#define UP_DIV(x, y) (((x) + (y)-1) / (y)) -#define FLT_MAX 3.402823466e+38 +} uBlock; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#define FLT_MIN -3.402823466e+38 + void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - ivec3 outputSize = uConstBlock.outputSize.xyz; - if (all(lessThan(pos, outputSize))) { - ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding; - ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate))); - ivec2 efxy = - min(uConstBlock.kernelSize, - UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate)); - - vec4 v = vec4(-FLT_MAX); - for (int kyi = sfxy.y; kyi < efxy.y; ++kyi) { - for (int kxi = sfxy.x; kxi < efxy.x; ++kxi) { - ivec2 ixy = s0 + ivec2(kxi, kyi); - v = max(texelFetch(uInput, ivec3(ixy.x, ixy.y, pos.z), 0), v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + const ivec2 start = ipos; + const ivec2 end = ipos + uBlock.kernel.xy * uBlock.dilate.xy; + + vec4 outtex = vec4(FLT_MIN); + + for (int y = start.y; y < end.y; y += uBlock.dilate.y) { + for (int x = start.x; x < end.x; x += uBlock.dilate.x) { + if ((x >= 0 && x < uBlock.kernel.z) && (y >= 0 && y < uBlock.kernel.w)) { + outtex = max(texelFetch(uInput, ivec3(x, y, pos.z), 0), outtex); + } + else { + outtex = max(vec4(0), outtex); + } } } - imageStore(uOutput, pos, v); + imageStore(uOutput, pos, outtex); } } diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp index 9144b102888f7f..27b4e7181003cd 100644 --- a/aten/src/ATen/native/vulkan/ops/Pool.cpp +++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp @@ -94,14 +94,14 @@ Tensor adaptive_avg_pool2d( return convert(v_output); } -Tensor avg_pool2d( +Tensor pool2d( const Tensor& self_arg, const IntArrayRef kernel_arg, IntArrayRef stride_arg, const IntArrayRef padding_arg, + const IntArrayRef dilation_arg, const bool ceil_mode, - const bool /* count_include_pad */, - const c10::optional /* divisor_override */) { + const api::Shader::Descriptor& shader_descriptor) { if (stride_arg.empty()) { stride_arg = kernel_arg; } @@ -121,7 +121,7 @@ Tensor avg_pool2d( const auto kernel = normalize(kernel_arg); const auto stride = normalize(stride_arg); const auto padding = normalize(padding_arg); - const auto dilation = std::array{1, 1}; + const auto dilation = normalize(dilation_arg); const int64_t output_height = pooling_output_shape( input_size[Layout::Activation4D::height], @@ -182,6 +182,7 @@ Tensor avg_pool2d( ivec4 kernel; ivec2 stride; ivec2 padding; + ivec2 dilation; } block { v_output.extents(), safe_downcast( @@ -201,6 +202,10 @@ Tensor avg_pool2d( safe_downcast(padding[Layout::Parameter::width]), safe_downcast(padding[Layout::Parameter::height]), }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, }; context->dispatch( @@ -210,7 +215,7 @@ Tensor avg_pool2d( VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }, - VK_KERNEL(avg_pool2d), + shader_descriptor, v_output.extents(), context->gpu().adapter->local_work_group_size(), // Write-only access bypasses synchronization but inserts appropriate @@ -237,11 +242,49 @@ Tensor avg_pool2d( return convert(v_output); } +Tensor avg_pool2d( + const Tensor& self_arg, + const IntArrayRef kernel_arg, + IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const bool ceil_mode, + const bool /* count_include_pad */, + const c10::optional /* divisor_override */) { + return pool2d( + self_arg, + kernel_arg, + stride_arg, + padding_arg, + {1,1}, + ceil_mode, + VK_KERNEL(avg_pool2d) + ); +} + +Tensor max_pool2d( + const Tensor& self_arg, + const IntArrayRef kernel_arg, + IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const IntArrayRef dilation_arg, + const bool ceil_mode) { + return pool2d( + self_arg, + kernel_arg, + stride_arg, + padding_arg, + dilation_arg, + ceil_mode, + VK_KERNEL(max_pool2d) + ); +} + #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl("_adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d)); m.impl("avg_pool2d", TORCH_FN(avg_pool2d)); + m.impl("max_pool2d", TORCH_FN(max_pool2d)); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/templates/DispatchKeyNativeFunctions.h b/aten/src/ATen/templates/DispatchKeyNativeFunctions.h new file mode 100644 index 00000000000000..f872fbcc7df1bf --- /dev/null +++ b/aten/src/ATen/templates/DispatchKeyNativeFunctions.h @@ -0,0 +1,8 @@ +#pragma once +// ${generated_comment} + +#include + +namespace ${cpp_namespace} { +${dispatch_declarations} +} // namespace ${cpp_namespace} diff --git a/aten/src/ATen/templates/aten_xla_type.h b/aten/src/ATen/templates/aten_xla_type.h deleted file mode 100644 index 4dc34bc1a7760b..00000000000000 --- a/aten/src/ATen/templates/aten_xla_type.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once -// ${generated_comment} - -#include - -namespace ${cpp_namespace} { - -// Base ATEN Type class where the XLA specific overrides should be defined. -class AtenXlaType { - public: - static void InitializeAtenBindings(); - - ////////////////////////////////////////////////////////////////////////////// - // ATEN API ovverrides in alphabetical order. - // Note: The C++ signatures must match the ones listed within the following - // pytorch folder file: - // torch/csrc/autograd/generated/RegistrationDeclarations.h - ///////////////////////////////////////////////////////////////////////////// -${dispatch_xla_declarations} -}; - -} // namespace torch_xla diff --git a/aten/src/ATen/templates/aten_xla_type_default.cpp b/aten/src/ATen/templates/aten_xla_type_default.cpp index 78503cf8448c73..040a752156eac9 100644 --- a/aten/src/ATen/templates/aten_xla_type_default.cpp +++ b/aten/src/ATen/templates/aten_xla_type_default.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include namespace ${cpp_namespace} { diff --git a/aten/src/ATen/test/math_kernel_test.cpp b/aten/src/ATen/test/math_kernel_test.cpp index 005c11cb0eaa86..8c01688825cd16 100644 --- a/aten/src/ATen/test/math_kernel_test.cpp +++ b/aten/src/ATen/test/math_kernel_test.cpp @@ -110,6 +110,15 @@ TEST(MathKernelTest, SiluBackward) { ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(MathKernelTest, MishBackward) { + const auto input = rand({20, 10}); + const auto grad_output = rand({20, 10}); + auto out = at::native::mish_backward(grad_output, input); + auto math_out = at::native::math_mish_backward(grad_output, input); + ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6); +} + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(MathKernelTest, NarrowCopy) { auto x = rand({5, 8, 7}); diff --git a/aten/src/ATen/test/test_thread_pool_guard.cpp b/aten/src/ATen/test/test_thread_pool_guard.cpp index 24575fb381ff5e..33e4144c141712 100644 --- a/aten/src/ATen/test/test_thread_pool_guard.cpp +++ b/aten/src/ATen/test/test_thread_pool_guard.cpp @@ -3,7 +3,6 @@ #include #include - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(TestThreadPoolGuard, TestThreadPoolGuard) { auto threadpool_ptr = caffe2::pthreadpool_(); @@ -30,3 +29,33 @@ TEST(TestThreadPoolGuard, TestThreadPoolGuard) { ASSERT_NE(threadpool_ptr4, nullptr); ASSERT_EQ(threadpool_ptr4, threadpool_ptr); } + +TEST(TestThreadPoolGuard, TestRunWithGuard) { + const std::vector array = {1, 2, 3}; + + // Run via pthreadpool_parallelize_1d + int64_t outer = 0; + auto fn1 = [&array, &outer](const size_t task_id) { + outer += array[task_id]; + }; + auto pool = caffe2::pthreadpool(); + pool->run(fn1, 3); + + int64_t inner = 0; + { + // Run on same thread + caffe2::_NoPThreadPoolGuard g1; + auto fn2 = [&array, &inner](const size_t task_id) { + inner += array[task_id]; + }; + pool->run(fn2, 3); + + // confirm the guard is on + auto threadpool_ptr1 = caffe2::pthreadpool_(); + ASSERT_EQ(threadpool_ptr1, nullptr); + } + ASSERT_NE(outer, 0); + ASSERT_NE(inner, 0); + ASSERT_EQ(outer, 6); + ASSERT_EQ(inner, 6); +} diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 347a00a3e7a3d1..cac1336ec2d509 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -974,6 +974,24 @@ TEST(VulkanAPITest, hardswish_) { ASSERT_TRUE(check); } +TEST(VulkanAPITest, max_pool2d) { + if (!at::is_vulkan_available()) { + return; + } + c10::InferenceMode mode; + + const auto in_cpu = at::rand({5, 13, 55, 68}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::max_pool2d(in_cpu, {3, 4}, {2, 1}, {1, 1}, {1, 1}, false); + const auto out_vulkan = at::max_pool2d(in_cpu.vulkan(), {3, 4}, {2, 1}, {1, 1}, {1,1}, false); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + TEST(VulkanAPITest, mean) { const auto in_cpu = at::rand({17, 3, 79, 53}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); const auto out_cpu = at::mean(in_cpu, {-1, -2}, true); diff --git a/aten/src/TH/THAllocator.cpp b/aten/src/TH/THAllocator.cpp index e07677aea76a97..a45e9142ddf7d1 100644 --- a/aten/src/TH/THAllocator.cpp +++ b/aten/src/TH/THAllocator.cpp @@ -35,19 +35,19 @@ typedef struct { } THMapInfo; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -const char * unknown_filename = "filename not specified"; +const std::string unknown_filename = "filename not specified"; #ifdef _WIN32 -const char * unknown_eventname = "eventname not specified"; +const std::string unknown_eventname = "eventname not specified"; #endif -THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, size_t size) - : filename_(filename ? filename : unknown_filename) +THMapAllocator::THMapAllocator(WithFd, std::string filename, int fd, int flags, size_t size) + : filename_(filename.empty() ? unknown_filename : std::move(filename)) , flags_(0) // to be filled later , size_(0) // to be filled later #ifdef _WIN32 , handle_(INVALID_HANDLE_VALUE) // to be filled later , event_(INVALID_HANDLE_VALUE) // to be filled later - , eventname_(filename ? std::string(filename) + "_event" : unknown_eventname) + , eventname_(filename.empty() ? unknown_eventname : (filename + "_event")) #else , fd_(fd) #endif @@ -338,8 +338,8 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, c10::reportMemoryUsageToProfiler(base_ptr_, size_, c10::Device(c10::DeviceType::CPU)); } -THMapAllocator::THMapAllocator(const char *filename, int flags, size_t size) - : THMapAllocator(WITH_FD, filename, -1, flags, size) +THMapAllocator::THMapAllocator(std::string filename, int flags, size_t size) + : THMapAllocator(WITH_FD, std::move(filename), -1, flags, size) {} #ifdef _WIN32 @@ -404,11 +404,11 @@ void THMapAllocator::close() { #else /* defined(_WIN32) || defined(HAVE_MMAP) */ -THMapAllocator::THMapAllocator(const char *filename, int flags, size_t size) { +THMapAllocator::THMapAllocator(std::string filename, int flags, size_t size) { AT_ERROR("file mapping not supported on your system"); } -THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, size_t size) { +THMapAllocator::THMapAllocator(WithFd, std::string filename, int fd, int flags, size_t size) { AT_ERROR("file mapping not supported on your system"); } @@ -554,8 +554,8 @@ THRefcountedMapAllocator* THRefcountedMapAllocator::fromDataPtr(const at::DataPt return dptr.cast_context(&deleteTHRefcountedMapAllocator); } -at::DataPtr THMapAllocator::makeDataPtr(const char *filename, int flags, size_t size, size_t* actual_size_out) { - auto* context = new THMapAllocator(filename, flags, size); +at::DataPtr THMapAllocator::makeDataPtr(std::string filename, int flags, size_t size, size_t* actual_size_out) { + auto* context = new THMapAllocator(std::move(filename), flags, size); if (actual_size_out) *actual_size_out = context->size(); return {context->data(), context, &deleteTHMapAllocator, at::DeviceType::CPU}; } diff --git a/aten/src/TH/THAllocator.h b/aten/src/TH/THAllocator.h index 4a4e385281e759..db34ffd2c49d87 100644 --- a/aten/src/TH/THAllocator.h +++ b/aten/src/TH/THAllocator.h @@ -23,8 +23,8 @@ enum WithFd { WITH_FD }; class TORCH_API THMapAllocator { public: - THMapAllocator(const char *filename, int flags, size_t size); - THMapAllocator(WithFd, const char *filename, int fd, int flags, size_t size); + THMapAllocator(std::string filename, int flags, size_t size); + THMapAllocator(WithFd, std::string filename, int fd, int flags, size_t size); THMapAllocator(const THMapAllocator&) = delete; THMapAllocator& operator=(const THMapAllocator&) = delete; THMapAllocator(THMapAllocator&&) = delete; @@ -45,7 +45,7 @@ class TORCH_API THMapAllocator { virtual void* data() const { return base_ptr_; } static THMapAllocator* fromDataPtr(const at::DataPtr&); - static at::DataPtr makeDataPtr(const char *filename, int flags, size_t size, size_t* actual_size_out); + static at::DataPtr makeDataPtr(std::string filename, int flags, size_t size, size_t* actual_size_out); static at::DataPtr makeDataPtr(WithFd, const char *filename, int fd, int flags, size_t size, size_t* actual_size_out); // Closes the data. Helps us avoid destructor shenanigans diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp index 55e48601673e18..1f09b1dbac5c1e 100644 --- a/benchmarks/cpp/tensorexpr/bench_approx.cpp +++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp @@ -12,17 +12,18 @@ using namespace torch::jit::tensorexpr; void vectorize(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target, int width) { auto loops = ln->getLoopStmtsFor(target); - For *outer, *inner, *tail; - ln->splitWithTail(loops[0], width, &outer, &inner, &tail); + For *inner, *tail; + ln->splitWithTail(loops[0], width, &inner, &tail); ln->vectorize(inner); } void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target) { std::vector loops = ln->getLoopStmtsFor(target); - For *outer, *inner, *tail; - ln->splitWithTail(loops[0], 16 * 8, &outer, &inner, &tail); + For *inner, *tail; + ln->splitWithTail(loops[0], 16 * 8, &inner, &tail); + For* outer = loops[0]; ln->vectorize(inner); - ln->splitWithTail(outer, 8, &outer, &inner, &tail); + ln->splitWithTail(outer, 8, &inner, &tail); Stmt* unrolled; LoopNest::unroll(inner, &unrolled); } diff --git a/benchmarks/cpp/tensorexpr/bench_gemm.cpp b/benchmarks/cpp/tensorexpr/bench_gemm.cpp index 78855264a5b46b..792d457c2f23ac 100644 --- a/benchmarks/cpp/tensorexpr/bench_gemm.cpp +++ b/benchmarks/cpp/tensorexpr/bench_gemm.cpp @@ -81,16 +81,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); te::For* m = loops[0]; - te::For* mo; - te::For* mi; - loop.splitWithMask(m, 32, &mo, &mi); + loop.splitWithMask(m, 32); } { auto const& loops = loop.getLoopStmtsFor(CT); te::For* n = loops[2]; - te::For* no; - te::For* ni; - loop.splitWithMask(n, 32, &no, &ni); + loop.splitWithMask(n, 32); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k @@ -145,16 +141,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); te::For* m = loops[0]; - te::For* mo; - te::For* mi; - loop.splitWithMask(m, 4, &mo, &mi); + loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); te::For* n = loops[2]; - te::For* no; - te::For* ni; - loop.splitWithMask(n, 16, &no, &ni); + loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k @@ -209,16 +201,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); te::For* m = loops[0]; - te::For* mo; - te::For* mi; - loop.splitWithMask(m, 4, &mo, &mi); + loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); te::For* n = loops[2]; - te::For* no; - te::For* ni; - loop.splitWithMask(n, 16, &no, &ni); + loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k @@ -281,16 +269,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); te::For* m = loops[0]; - te::For* mo; - te::For* mi; - loop.splitWithMask(m, 4, &mo, &mi); + loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); te::For* n = loops[2]; - te::For* no; - te::For* ni; - loop.splitWithMask(n, 16, &no, &ni); + loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k diff --git a/benchmarks/cpp/tensorexpr/bench_reduce.cpp b/benchmarks/cpp/tensorexpr/bench_reduce.cpp index 39462de17ff74a..d04681391760a5 100644 --- a/benchmarks/cpp/tensorexpr/bench_reduce.cpp +++ b/benchmarks/cpp/tensorexpr/bench_reduce.cpp @@ -266,10 +266,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(BT); te::For* m = loops[1]; - te::For* mo; - te::For* mi; - te::For* tail; - loop.splitWithTail(m, kChunkSize, &mo, &mi, &tail); + loop.splitWithTail(m, kChunkSize); } loop.prepareForCodegen(); @@ -310,9 +307,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(BT); te::For* m = loops[1]; - te::For* mo; - te::For* mi; - loop.splitWithMask(m, kChunkSize, &mo, &mi); + loop.splitWithMask(m, kChunkSize); } loop.prepareForCodegen(); @@ -354,9 +349,9 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { auto loops = loop.getLoopStmtsFor(BT); TORCH_CHECK(loops.size() == 1); - te::For* mo; te::For* mi; - loop.splitWithMask(loops.at(0), kChunkSize, &mo, &mi); + loop.splitWithMask(loops.at(0), kChunkSize, &mi); + te::For* mo = loops.at(0); loop.reorderAxis(mo, mi); loops = loop.getLoopStmtsFor(BT); diff --git a/benchmarks/distributed/rpc/parameter_server/launcher.py b/benchmarks/distributed/rpc/parameter_server/launcher.py index 2360b63ced5ea9..ddfbb8fd5952fd 100644 --- a/benchmarks/distributed/rpc/parameter_server/launcher.py +++ b/benchmarks/distributed/rpc/parameter_server/launcher.py @@ -215,14 +215,9 @@ def run_benchmark(rank, model, data, config): def get_json_config(file_name, id): - f = open( - os.path.join( - Path(__file__).parent, file_name - ), - "r" - ) - json_config = json.load(f)[id] - f.close() + with open(os.path.join(Path(__file__).parent, file_name), "r") as f: + json_config = json.load(f)[id] + return json_config diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index 79e72365c4ee89..6a53c9c97b3c49 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -1,7 +1,7 @@ import time import json import torch -import cpp_extension # noqa: F401 +import benchmark_cpp_extension # noqa: F401 """PyTorch performance microbenchmarks. diff --git a/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py b/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py index 4b8f5d7fb7bc20..9ada10b3d8b0d0 100644 --- a/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py +++ b/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py @@ -1,6 +1,6 @@ import unittest -import cpp_extension # noqa: F401 +import benchmark_cpp_extension # noqa: F401 import torch diff --git a/benchmarks/operator_benchmark/pt_extension/extension.cpp b/benchmarks/operator_benchmark/pt_extension/extension.cpp index 2dbdfdd8b3e66a..22a45275b4f0ed 100644 --- a/benchmarks/operator_benchmark/pt_extension/extension.cpp +++ b/benchmarks/operator_benchmark/pt_extension/extension.cpp @@ -22,7 +22,7 @@ TORCH_LIBRARY_FRAGMENT(operator_benchmark, m) { m.def("_consume.list", &consume_list); } -PYBIND11_MODULE(cpp_extension, m) { +PYBIND11_MODULE(benchmark_cpp_extension, m) { m.def("_consume", &consume, "consume"); m.def("_consume_list", &consume_list, "consume_list"); } diff --git a/benchmarks/operator_benchmark/pt_extension/setup.py b/benchmarks/operator_benchmark/pt_extension/setup.py index 8177364e485bae..ff2d182236a779 100644 --- a/benchmarks/operator_benchmark/pt_extension/setup.py +++ b/benchmarks/operator_benchmark/pt_extension/setup.py @@ -1,6 +1,6 @@ from setuptools import setup from torch.utils.cpp_extension import CppExtension, BuildExtension -setup(name='cpp_extension', - ext_modules=[CppExtension('cpp_extension', ['extension.cpp'])], +setup(name='benchmark_cpp_extension', + ext_modules=[CppExtension('benchmark_cpp_extension', ['extension.cpp'])], cmdclass={'build_ext': BuildExtension}) diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index 72eb3c09a5fb03..67c81c5dbce5d5 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -349,3 +349,21 @@ const auto clamp_script_2 = R"JIT( a = torch.clamp(inp, min, max) return (a) )JIT"; + +const auto full_like_script = R"JIT( + def forward(self, + a: Tensor, + fill_value: int, + dtype: Optional[int], + layout: Optional[int], + device: Optional[Device], + pin_memory: Optional[bool], + memory_format: Optional[int]): + return torch.full_like(a, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format) +)JIT"; diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 48e64839d816bd..8190b73e18cbbf 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -352,6 +352,15 @@ TEST(StaticRuntime, IndividualOps_to) { test_to(at::ScalarType::Half, false, true, c10::MemoryFormat::Preserve); } +TEST(StaticRuntime, IndividualOps_FullLike) { + auto a = at::randn({2, 3}); + auto dtype = at::ScalarType::Int; + auto cpu = at::Device(DeviceType::CPU); + std::vector args {a, 4, dtype, at::kStrided, cpu, false, + c10::MemoryFormat::Contiguous}; + testStaticRuntime(full_like_script, args); +} + TEST(StaticRuntime, LongModel) { torch::jit::Module mod = getLongScriptModel(); auto a = torch::randn({2, 2}); diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index ff6a84ebbef0d0..b3d31531694cef 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -228,7 +228,7 @@ enum class DispatchKey : uint8_t { // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed // and inputs are saved for backward in the post-autocast type. - // AutocastCPU, + AutocastCPU, AutocastCUDA, // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 272cf331182849..8358e931f0937f 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -78,8 +78,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) { switch (t) { - // case DispatchKey::CPU: - // return DispatchKeySet(DispatchKey::AutocastCPU); + case DispatchKey::CPU: + return DispatchKeySet(DispatchKey::AutocastCPU); case DispatchKey::CUDA: return DispatchKeySet(DispatchKey::AutocastCUDA); default: diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 223355203cf13f..e11572f23a7abc 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -223,7 +223,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ }); constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ - // DispatchKey::AutocastCPU, + DispatchKey::AutocastCPU, DispatchKey::AutocastCUDA, }); @@ -234,7 +234,7 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({ }); constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ - // DispatchKey::AutocastCPU, + DispatchKey::AutocastCPU, DispatchKey::AutocastCUDA, }); diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 802bf17e0411ac..4c0baa431d53c6 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -63,8 +63,9 @@ class C10_API Scalar { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR) // also support scalar.to(); + // Deleted for unsupported types, but specialized below for supported types template - T to() const; + T to() const = delete; #undef DEFINE_ACCESSOR bool isFloatingPoint() const { @@ -186,11 +187,6 @@ class C10_API Scalar { }; // define the scalar.to() specializations -template -inline T Scalar::to() const { - throw std::runtime_error("to() cast to unexpected type."); -} - #define DEFINE_TO(T, name) \ template <> \ inline T Scalar::to() const { \ diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 3cabbbeed5b32c..d652db5a215c6f 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -130,31 +130,31 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) _(float, Float) \ _(double, Double) -#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype( \ - ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), \ +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE>::t), \ SCALARTYPE) -#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype( \ - ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \ - SCALARTYPE1) \ - _(decltype( \ - ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \ +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ SCALARTYPE2) #define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ @@ -165,14 +165,14 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) _(int64_t, Long) \ _(float, Float) \ _(double, Double) \ - _(decltype( \ - ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ SCALARTYPE1) \ - _(decltype( \ - ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ SCALARTYPE2) \ - _(decltype( \ - ::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ SCALARTYPE3) #define AT_FORALL_QINT_TYPES(_) \ @@ -428,4 +428,8 @@ inline std::ostream& operator<<( return stream << toString(scalar_type); } +#define AT_FORAUTOCAST_SCALAR_TYPES(_) \ + _(half, Half) /* 0 */ \ + _(bfloat16, BFloat16) /* 1 */ + } // namespace c10 diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index c3cd4e6b17befa..a0c7673641effe 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -21,6 +21,18 @@ C10_DEFINE_int64( namespace c10 { +namespace impl { + +static std::string noop_name_fn(const PyInterpreter*) { + return ""; +} + +void PyInterpreter::disarm() noexcept { + name_fn_ = &noop_name_fn; +} + +} // namespace impl + const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed = "is not allowed on a Tensor created from .data or .detach().\n" "If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n" @@ -85,6 +97,8 @@ TensorImpl::TensorImpl( DispatchKeySet key_set, const caffe2::TypeMeta data_type) : storage_(std::move(storage)), + pyobj_interpreter_(nullptr), + pyobj_(nullptr), storage_offset_(0), numel_(0), data_type_(data_type), @@ -111,6 +125,8 @@ TensorImpl::TensorImpl( const caffe2::TypeMeta data_type, c10::optional device_opt) : storage_(std::move(storage)), + pyobj_interpreter_(nullptr), + pyobj_(nullptr), storage_offset_(0), numel_(0), data_type_(data_type), diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 59c9397c8a6277..e383ffb4c57a53 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -175,6 +175,110 @@ struct C10_API AutogradMetaFactoryRegisterer { } }; +// Note [Python interpreter tag] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// We store a PyObject on TensorImpl so that we can efficiently translate +// tensors into the Python representations. However, in some situations +// (torchdeploy) there may be multiple Python interpreters in a single process +// and we must take care not to accidentally mix up PyObjects with the wrong +// interpreters. Thus, we also tag every TensorImpl with the Python interpreter +// it corresponds to. +// +// With torchdeploy, we have these invariants: +// - Any given TensorImpl can be associated with AT MOST one Python +// interpreter. +// We represent the interpreter tag as a memory address to an instance of +// a virtual class that is allocated once per interpreter (this is so that +// we can request the interpreter to perform operations for us, if +// necessary). +// - A given TensorImpl's interpreter tag can only go from uninitialized to +// tagged; once tagged, this is a quiescent state (once tagged to an +// interpreter, ALWAYS tagged to that interpreter) +// - A thread may mutate the PyObject field of a TensorImpl if and only if it +// holds the GIL for the interpreter tagged on the TensorImpl. (If the +// TensorImpl is not tagged, it must first atomically claim its tag before it +// can validly write) + +// The PyInterpreter object itself is a class that contains some function +// pointers for interacting with the interpreter. For now this is just for +// debugging, but if a Tensor can own a PyObject, the interpreter can be used to +// free it. +// +// WARNING: This class has to be written very carefully, because it may be +// possible for a Tensor to have a reference an interpreter corresponding to +// a shared library that has ALREADY BEEN UNLOADED. This makes blindly calling +// virtual methods very dangerous, because the vtable may be garbage at that +// point (on a good day, you might get "pure virtual method called"). +// +// The idea to solve this problem is we always leak PyInterpreters (so they +// always stay live even after dlclose), and disarm the "virtual methods" by +// replacing them with function pointers that just no-op. This can't be done +// with a traditional C++ vtable, so we have to roll our own. +// +// NB: The downside with representing PyInterpreter tags as full objects is that +// it takes an extra word on TensorImpl. If tags were instead just integer +// indices, on 64-bit architectures we could pack the tag and PyObject together +// into a single atomic word. On 32-bit architectures we could simply say that +// only one Python interpreter is supported (erroring if a nontrivial +// interpreter tag is attempted to be set). +// +// The difficulty with this scheme is we need to maintain an out-of-line table +// to get at the PyInterpreters so that we can do virtual method calls on them, +// and registration/deregistration to this table must be done in a thread safe +// manner. This can be easily done if the number of possible PyInterpreters is +// small enough (e.g., 8-bit integer) by simply preallocating an array of +// sufficient size to hold all possible interpreters. Surely 128 threads is +// more than enough for anyone! +// +// I didn't decide to do this technique at the moment, because the extra word +// added by the PyInterpreter tag takes us to 24 words, which means that we +// still fit inside three eight word cache lines. If you need to penny pinch +// another word consider doing this! + +struct PyInterpreter; +struct C10_API PyInterpreter { + using name_sig = std::string(const PyInterpreter*); + + PyInterpreter(name_sig* name_fn) : name_fn_(name_fn) {} + + // For debugging purposes only + name_sig* name_fn_; + + std::string name() const { + return (*name_fn_)(this); + } + + // Disarm this PyInterpreter, making all of its methods noops. + // Because the function pointers are raw pointers (not atomics), + // a disarm() invocation that is concurrent with active destructors + // is not thread safe and will trigger TSAN. My hope is that this + // situations doesn't ever actually happen; tensor destruction should + // quiesce when a dlclose happens, and any long lived tensors whose + // destructors would be disarmed here only begin the destruction process + // on process shutdown (long after the dlclose has occurred). + void disarm() noexcept; +}; + +// PyInterpreterStatus describes what the state of its interpreter tag +// is, relative to the thread currently holding the GIL. +enum class PyInterpreterStatus { + // We just allocated the Tensor, it hasn't escaped to other threads, + // we know that it definitely hasn't been tagged to be associated + // with an interpreter. + DEFINITELY_UNINITIALIZED, + // We queried the interpreter field and it looked uninitialized. But + // another thread may have raced with us to tag it with some other + // interpreter id. So we will have to do a CEX to make sure we can + // actually nab it. + MAYBE_UNINITIALIZED, + // We queried the interpreter field and it was tagged to belong to us. + // This means we have sole write access (as we hold the GIL for this + // interpreter) + TAGGED_BY_US, + // Someone else tagged this. We can't use this TensorImpl from Python. + TAGGED_BY_OTHER, +}; + } // namespace impl struct C10_API NamedTensorMetaInterface { @@ -1301,12 +1405,96 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { version_counter_.bump(); } - inline void set_pyobj(PyObject* pyobj) noexcept { + // Associate the TensorImpl with the specified PyObject, and, if necessary, + // also tag the interpreter. + // + // NB: This lives in a header so that we can inline away the switch on status + // + // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after + // PyObject if necessary! + void init_pyobj( + impl::PyInterpreter* self_interpreter, + PyObject* pyobj, + c10::impl::PyInterpreterStatus status) { + impl::PyInterpreter* expected = nullptr; + switch (status) { + case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED: + // caller guarantees there is no multithreaded access; if there is + // no data race OK to do a relaxed store + pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + break; + case impl::PyInterpreterStatus::TAGGED_BY_US: + // no tagging is necessary, the tag is already correct + break; + case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED: + // attempt to claim this TensorImpl with the specified interpreter + // tag + if (pyobj_interpreter_.compare_exchange_strong( + expected, self_interpreter, std::memory_order_acq_rel)) { + break; + } + // test if, actually, it was already tagged by us! this situation can't + // be caused by a race, but it could be caused by a situation + // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED + // (because they didn't pre-check the tag) when actually it was + // owned by the interpreter + if (expected == self_interpreter) { + break; + } + // fallthrough, we lost the race. We are guaranteed not to lose the + // race with ourself, as calls to init_pyobj with the same interpreter + // ID must be sequentialized by the GIL + C10_FALLTHROUGH; + case impl::PyInterpreterStatus::TAGGED_BY_OTHER: + TORCH_CHECK( + false, + "cannot allocate PyObject for Tensor on interpreter ", + self_interpreter, + " that has already been used by another torch deploy interpreter ", + pyobj_interpreter_.load()); + } + + // we are the ONLY thread that can have gotten to this point. It is not + // possible to conflict with another zero interpreter as access is protected + // by GIL pyobj_ = pyobj; } - inline PyObject* pyobj() const noexcept { - return pyobj_; + // Test the interpreter tag. If tagged for the current interpreter, return + // a non-nullopt (but possibly null) PyObject. If (possibly) untagged, + // returns a nullopt. If it is definitely invalid, raises an error. + // + // NB: this lives in header so that we can avoid actually creating the + // c10::optional + c10::optional check_pyobj(impl::PyInterpreter* self_interpreter) { + // Note [Memory ordering on Python interpreter tag] + impl::PyInterpreter* interpreter = + pyobj_interpreter_.load(std::memory_order_acquire); + if (interpreter == nullptr) { + // NB: This never returns DEFINITELY_UNINITIALIZED because there is + // always the possibility that another thread races to initialize + // after we query here. The only time when we can conclude a tensor + // is definitely uninitialized is when we have just allocated it and + // it cannot have escaped to other threads yet + return c10::nullopt; + } else if (interpreter == self_interpreter) { + // NB: pyobj_ could still be null! + return c10::make_optional(pyobj_); + } else { + TORCH_CHECK( + false, + "cannot access PyObject for Tensor on interpreter ", + self_interpreter->name(), + " that has already been used by another torch deploy interpreter ", + pyobj_interpreter_.load()->name()); + } + } + + // Clear the PyObject field for an interpreter, in situations where we + // statically know the tensor is tagged with our interpreter. + void unchecked_clear_pyobj(impl::PyInterpreter* interpreter) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load()); + pyobj_ = nullptr; } private: @@ -1835,6 +2023,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return n; } + /** + * Compute the number of elements based on the sizes of a + * tensor. Catches integer overflow that may occur when a tensor + * using a sparse layout has multiple dimensions with large sizes. + */ + int64_t safe_compute_numel() const { + int64_t n = 1; + for (auto s : sizes()) { + TORCH_CHECK( + s == 0 || n <= std::numeric_limits::max() / s, + "numel: integer multiplication overflow"); + n *= s; + } + return n; + } + /** * Compute whether or not a tensor is contiguous based on the sizes and * strides of a tensor. @@ -1853,12 +2057,27 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { protected: /** - * Recompute the cached numel of a tensor. Call this if you modify sizes. + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. + * + * For tensors with sparse layouts, use safe_refresh_numel() instead + * because it will catch integer overflow that may occur for tensors + * with sparse layouts and large dimensions. */ void refresh_numel() { numel_ = compute_numel(); } + /** + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. Use only for tensors with sparse layouts because only + * sparse tensor are likely to have sizes that may lead to integer + * overflow when computing numel. + */ + void safe_refresh_numel() { + numel_ = safe_compute_numel(); + } + /** * Recompute the cached contiguity of a tensor. Call this if you modify sizes * or strides. @@ -2002,19 +2221,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { c10::VariableVersion version_counter_; + // This field contains the interpreter tag for this object. See + // Note [Python interpreter tag] for general context + // + // Note [Memory ordering on Python interpreter tag] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // What memory_order do we need when accessing this atomic? We don't + // need a single total modification order (as provided by + // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only + // transition from -1 to some positive integer and never changes afterwards. + // Because there is only one modification, it trivially already has a total + // modification order (e.g., we don't need fences or locked instructions on + // x86) + // + // In fact, one could make a reasonable argument that relaxed reads are OK, + // due to the presence of external locking (GIL) to ensure that interactions + // with other data structures are still correctly synchronized, so that + // we fall in the "Single-Location Data Structures" case as described in + // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf + // However, on x86, it doesn't matter if I use acquire or relaxed on the load + // as I get the same assembly in both cases. So I just use the more + // conservative acquire (which will impede compiler optimizations but I don't + // care) + std::atomic pyobj_interpreter_; + // This field contains a weak reference to a PyObject representing // this Tensor. It MUST NOT be a strong reference, as that would // create a reference cycle between Tensor and the PyObject. If // pyobj is nullptr, when we transfer Tensor to Python, we allocate - // a new PyObject for it and set this field. This is thread safe - // because all Python code is protected under the GIL. This design does - // NOT WORK for Tensors which are shared across multiple Python - // subinterpreters (introduced in Python 3.8) since you don't have - // enough space to store the separate PyObject per subinterpreter. + // a new PyObject for it and set this field. This field does not + // have to be protected by an atomic as it is only allowed to be + // accessed when you hold the GIL. + // // When a PyObject dies, you are obligated to clear this field // (otherwise, you will try to use-after-free the pyobj); this currently // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp - PyObject* pyobj_ = nullptr; + PyObject* pyobj_; c10::impl::SizesAndStrides sizes_and_strides_; @@ -2159,6 +2401,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // autograd metadata pointer // named tensor metadata pointer // version counter pointer +// Python interpreter pointer // PyObject pointer // SizesAndStrides size/pointer // SizesAndStrides sizes (pre-allocated 0) @@ -2178,7 +2421,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // static_assert( sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * 23, + sizeof(TensorImpl) == sizeof(int64_t) * 24, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/thread_pool.cpp b/c10/core/thread_pool.cpp index 122732fd1b2913..845989a2bf4a5a 100644 --- a/c10/core/thread_pool.cpp +++ b/c10/core/thread_pool.cpp @@ -43,6 +43,7 @@ size_t ThreadPool::size() const { } size_t ThreadPool::numAvailable() const { + std::unique_lock lock(mutex_); return available_; } diff --git a/c10/core/thread_pool.h b/c10/core/thread_pool.h index 3fadd2add895d7..af06b4ba679f7c 100644 --- a/c10/core/thread_pool.h +++ b/c10/core/thread_pool.h @@ -57,7 +57,7 @@ class C10_API ThreadPool : public c10::TaskThreadPoolBase { std::queue tasks_; std::vector threads_; - std::mutex mutex_; + mutable std::mutex mutex_; std::condition_variable condition_; std::condition_variable completed_; std::atomic_bool running_; diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 90356425167d52..90eb7599610f62 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -333,6 +333,13 @@ __host__ __device__ #else // __APPLE__, _MSC_VER #if defined(NDEBUG) extern "C" { +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void __assert_fail( + const char* expr, + const char* file, + unsigned int line, + const char* func); +#else // __SYCL_DEVICE_ONLY__ #if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) || \ defined(__HIP_ARCH__) || defined(__HIP__) __host__ __device__ @@ -343,6 +350,7 @@ __host__ __device__ const char* file, unsigned int line, const char* function) throw(); +#endif } #endif // NDEBUG #define CUDA_KERNEL_ASSERT(cond) \ diff --git a/c10/util/FunctionRef.h b/c10/util/FunctionRef.h index 929fd6552c29f1..f0ba874a8c0d73 100644 --- a/c10/util/FunctionRef.h +++ b/c10/util/FunctionRef.h @@ -40,8 +40,8 @@ class function_ref { template static Ret callback_fn(intptr_t callable, Params... params) { - return (*reinterpret_cast(callable))( - std::forward(params)...); + return (*reinterpret_cast(callable))(std::forward( + params)...); } public: diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index 3b2bc04eda981d..38b26534552701 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -966,12 +966,14 @@ class SmallVector : public SmallVectorImpl { typename Container, std::enable_if_t< std::is_convertible< - typename std::iterator_traits().begin())>::iterator_category, + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, std::input_iterator_tag>::value && std::is_convertible< - typename std::iterator_traits().end())>::iterator_category, + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, std::input_iterator_tag>::value, int> = 0> explicit SmallVector(Container&& c) : SmallVectorImpl(N) { @@ -1003,12 +1005,14 @@ class SmallVector : public SmallVectorImpl { typename Container, std::enable_if_t< std::is_convertible< - typename std::iterator_traits().begin())>::iterator_category, + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, std::input_iterator_tag>::value && std::is_convertible< - typename std::iterator_traits().end())>::iterator_category, + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, std::input_iterator_tag>::value, int> = 0> const SmallVector& operator=(const Container& RHS) { @@ -1037,12 +1041,14 @@ class SmallVector : public SmallVectorImpl { typename Container, std::enable_if_t< std::is_convertible< - typename std::iterator_traits().begin())>::iterator_category, + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, std::input_iterator_tag>::value && std::is_convertible< - typename std::iterator_traits().end())>::iterator_category, + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, std::input_iterator_tag>::value, int> = 0> const SmallVector& operator=(Container&& C) { diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index 69dcf99dc90a75..59e4509ca43d74 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -134,7 +135,7 @@ inline static bool isPrint(char s) { return s > 0x1f && s < 0x7f; } -inline void printQuotedString(std::ostream& stmt, const std::string& str) { +inline void printQuotedString(std::ostream& stmt, const string_view str) { stmt << "\""; for (auto s : str) { switch (s) { diff --git a/c10/util/thread_name.cpp b/c10/util/thread_name.cpp index c503df112492be..27cbb98de89432 100644 --- a/c10/util/thread_name.cpp +++ b/c10/util/thread_name.cpp @@ -2,7 +2,12 @@ #include -#if defined(__GLIBC__) && !defined(__APPLE__) && !defined(__ANDROID__) +#ifndef __GLIBC_PREREQ +#define __GLIBC_PREREQ(x, y) 0 +#endif + +#if defined(__GLIBC__) && __GLIBC_PREREQ(2, 12) && !defined(__APPLE__) && \ + !defined(__ANDROID__) #define C10_HAS_PTHREAD_SETNAME_NP #endif diff --git a/c10/util/variant.h b/c10/util/variant.h index 8f571e62c08476..080ceb4eae491a 100644 --- a/c10/util/variant.h +++ b/c10/util/variant.h @@ -677,8 +677,8 @@ inline constexpr auto invoke(F&& f, Args&&... args) MPARK_RETURN( void_t(), std::declval()...))>, F, Args...> - : identity(), std::declval()...))> {}; + : identity(), std::declval()...))> {}; } // namespace detail_ diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index 73fbddbb371ef4..51d6a164ee6aa3 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -36,7 +36,7 @@ C10_DEFINE_bool( C10_DEFINE_bool( caffe2_serialize_using_bytes_as_holder, false, - "Serialize BOOL, UINT8, INT8, UINT16, INT16, INT64, FLOAT16 tensors using byte_data field instead of int32"); + "Serialize BOOL, UINT8, INT8, UINT16, INT16, FLOAT16 tensors using byte_data field instead of int32"); namespace caffe2 { namespace { diff --git a/caffe2/operators/batch_sparse_to_dense_op.cc b/caffe2/operators/batch_sparse_to_dense_op.cc index 8af1932519c83d..89642ea34d5b6e 100644 --- a/caffe2/operators/batch_sparse_to_dense_op.cc +++ b/caffe2/operators/batch_sparse_to_dense_op.cc @@ -3,16 +3,17 @@ namespace caffe2 { template <> +template void BatchSparseToDenseOp::FillInDenseValues( const int64_t batch_size, const int64_t indice_lengths, - const int64_t* lengths_data, - const int64_t* indices_data, + const TLen* lengths_data, + const TInd* indices_data, const float* values_data, float* output_data, CPUContext* /*context*/) { - int64_t lengths_sum = 0; - math::Sum( + TLen lengths_sum = 0; + math::Sum( batch_size, lengths_data, &lengths_sum, &context_); CAFFE_ENFORCE_EQ(lengths_sum, indice_lengths); @@ -33,16 +34,17 @@ void BatchSparseToDenseOp::FillInDenseValues( } template <> +template void BatchDenseToSparseOp::FillInSparseValues( const int64_t batch_size, const int64_t indice_lengths, - const int64_t* lengths_data, - const int64_t* indices_data, + const TLen* lengths_data, + const TInd* indices_data, const float* dense_data, float* output_data, CPUContext* /*context*/) { - int64_t lengths_sum = 0; - math::Sum( + TLen lengths_sum = 0; + math::Sum( batch_size, lengths_data, &lengths_sum, &context_); CAFFE_ENFORCE_EQ(lengths_sum, indice_lengths); diff --git a/caffe2/operators/batch_sparse_to_dense_op.cu b/caffe2/operators/batch_sparse_to_dense_op.cu index 1711504ca5e178..aea2035a5d3d45 100644 --- a/caffe2/operators/batch_sparse_to_dense_op.cu +++ b/caffe2/operators/batch_sparse_to_dense_op.cu @@ -8,8 +8,9 @@ namespace caffe2 { namespace { +template void array_prefix_sum_inclusive( - const int64_t* dev_array, + const TLen* dev_array, const int num_items, Tensor& prefix_buffer, Tensor& prefix_sum, @@ -21,32 +22,33 @@ void array_prefix_sum_inclusive( nullptr, temp_storage_bytes, dev_array, - prefix_sum.mutable_data(), + prefix_sum.mutable_data(), num_items, context.cuda_stream()); // Allocate temporary storage - auto buffer_size = (temp_storage_bytes + sizeof(int64_t)) / sizeof(int64_t); + auto buffer_size = (temp_storage_bytes + sizeof(TLen)) / sizeof(TLen); prefix_buffer.Resize(buffer_size); void* dev_temp_storage = - static_cast(prefix_buffer.mutable_data()); + static_cast(prefix_buffer.mutable_data()); // Inclusive sum cub::DeviceScan::InclusiveSum( dev_temp_storage, temp_storage_bytes, dev_array, - prefix_sum.mutable_data(), + prefix_sum.mutable_data(), num_items, context.cuda_stream()); } +template __global__ void FillInDenseValuesKernel( const int64_t batch_size, const int64_t dense_last_dim, - const int64_t* indices_data, + const TInd* indices_data, const float* values_data, - const int64_t* L_cum_sum_data, + const TLen* L_cum_sum_data, float* output_data) { CUDA_1D_KERNEL_LOOP(idx, batch_size) { int offset_start = idx == 0 ? 0 : L_cum_sum_data[idx - 1]; @@ -60,12 +62,13 @@ __global__ void FillInDenseValuesKernel( } } +template __global__ void FillInSparseValuesKernel( const int64_t batch_size, const int64_t dense_last_dim, - const int64_t* indices_data, + const TInd* indices_data, const float* dense_data, - const int64_t* L_cum_sum_data, + const TLen* L_cum_sum_data, float* output_data) { CUDA_1D_KERNEL_LOOP(idx, batch_size) { int offset_start = idx == 0 ? 0 : L_cum_sum_data[idx - 1]; @@ -81,21 +84,22 @@ __global__ void FillInSparseValuesKernel( } // namespace template <> +template void BatchSparseToDenseOp::FillInDenseValues( const int64_t batch_size, const int64_t indice_lengths, - const int64_t* lengths_data, - const int64_t* indices_data, + const TLen* lengths_data, + const TInd* indices_data, const float* values_data, float* output_data, CUDAContext* context) { // calculate the prefix sum of the length array - array_prefix_sum_inclusive( + array_prefix_sum_inclusive( lengths_data, batch_size, len_prefix_tmp_, len_prefix_sum_, context_); // launch the gpu kernel for to fill in dense values const int64_t min_size = 1; - FillInDenseValuesKernel<<< + FillInDenseValuesKernel<<< CAFFE_GET_BLOCKS(std::max(batch_size, min_size)), CAFFE_CUDA_NUM_THREADS, 0, @@ -104,27 +108,28 @@ void BatchSparseToDenseOp::FillInDenseValues( dense_last_dim_, indices_data, values_data, - len_prefix_sum_.data(), + len_prefix_sum_.data(), output_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } template <> +template void BatchDenseToSparseOp::FillInSparseValues( const int64_t batch_size, const int64_t indice_lengths, - const int64_t* lengths_data, - const int64_t* indices_data, + const TLen* lengths_data, + const TInd* indices_data, const float* dense_data, float* output_data, CUDAContext* context) { // calculate the prefix sum of the length array - array_prefix_sum_inclusive( + array_prefix_sum_inclusive( lengths_data, batch_size, len_prefix_tmp_, len_prefix_sum_, context_); // launch the gpu kernel for to fill in sparse values const int64_t min_size = 1; - FillInSparseValuesKernel<<< + FillInSparseValuesKernel<<< CAFFE_GET_BLOCKS(std::max(batch_size, min_size)), CAFFE_CUDA_NUM_THREADS, 0, @@ -133,7 +138,7 @@ void BatchDenseToSparseOp::FillInSparseValues( dense_last_dim_, indices_data, dense_data, - len_prefix_sum_.data(), + len_prefix_sum_.data(), output_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/caffe2/operators/batch_sparse_to_dense_op.h b/caffe2/operators/batch_sparse_to_dense_op.h index db488a9411282b..b6e51a9ce79745 100644 --- a/caffe2/operators/batch_sparse_to_dense_op.h +++ b/caffe2/operators/batch_sparse_to_dense_op.h @@ -13,13 +13,42 @@ template class BatchSparseToDenseOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + USE_DISPATCH_HELPER; template explicit BatchSparseToDenseOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int64_t, "dense_last_dim", dense_last_dim_, -1), OP_SINGLE_ARG(T, "default_value", default_value_, static_cast(0)) {} - bool RunOnDevice() { + + bool RunOnDevice() override { + return DispatchHelper>::call( + this, Input(LENGTHS)); + } + + private: + template + void FillInDenseValues( + const int64_t batch_size, + const int64_t indice_lengths, + const TLen* lengths_data, + const TInd* indices_data, + const T* values_data, + T* output_data, + Context* context); + + template + bool DoRunWithType() { + return DispatchHelper< + TensorTypes2< + int32_t, + int64_t, + GenericTensorImplementation>, + TLen>::call(this, Input(INDICES)); + } + + template + bool DoRunWithType2() { auto& lengths = Input(LENGTHS); auto& indices = Input(INDICES); auto& values = Input(VALUES); @@ -27,12 +56,10 @@ class BatchSparseToDenseOp : public Operator { CAFFE_ENFORCE_EQ(indices.numel(), values.numel()); CAFFE_ENFORCE_EQ(lengths.dim(), 1); CAFFE_ENFORCE_EQ(indices.dim(), 1); - - const int64_t* lengths_data = lengths.template data(); - const int64_t* indices_data = indices.template data(); + const TLen* lengths_data = lengths.template data(); + const TInd* indices_data = indices.template data(); const T* values_data = values.template data(); int64_t batch_size = lengths.numel(); - vector output_shape = {batch_size}; if (InputSize() == 4) { auto& shaper = Input(3); @@ -68,15 +95,16 @@ class BatchSparseToDenseOp : public Operator { return true; } - private: - void FillInDenseValues( - const int64_t batch_size, - const int64_t indice_lengths, - const int64_t* lengths_data, - const int64_t* indices_data, - const T* values_data, - T* output_data, - Context* context); + template + bool DoRunWithOtherType2() { + CAFFE_THROW( + "BatchSparseToDense is not implemented on values of type ", + Input(VALUES).dtype().name(), + " with lengths of type ", + Input(LENGTHS).dtype().name(), + " and indices of type ", + Input(INDICES).dtype().name()); + } int64_t dense_last_dim_; T default_value_; @@ -92,10 +120,39 @@ template class BatchDenseToSparseOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + USE_DISPATCH_HELPER; + template explicit BatchDenseToSparseOp(Args&&... args) : Operator(std::forward(args)...) {} bool RunOnDevice() { + return DispatchHelper>::call( + this, Input(LENGTHS)); + } + + private: + template + void FillInSparseValues( + const int64_t batch_size, + const int64_t indice_lengths, + const TLen* lengths_data, + const TInd* indices_data, + const T* dense_data, + T* output_data, + Context* context); + + template + bool DoRunWithType() { + return DispatchHelper< + TensorTypes2< + int32_t, + int64_t, + GenericTensorImplementation>, + TLen>::call(this, Input(INDICES)); + } + + template + bool DoRunWithType2() { auto& lengths = Input(LENGTHS); auto& indices = Input(INDICES); auto& dense = Input(DENSE); @@ -103,12 +160,11 @@ class BatchDenseToSparseOp : public Operator { CAFFE_ENFORCE_EQ(lengths.dim(), 1); CAFFE_ENFORCE_EQ(indices.dim(), 1); CAFFE_ENFORCE_EQ(dense.dim(), 2); - const int64_t* lengths_data = lengths.template data(); - const int64_t* indices_data = indices.template data(); + const TLen* lengths_data = lengths.template data(); + const TInd* indices_data = indices.template data(); const T* dense_data = dense.template data(); int64_t batch_size = lengths.numel(); - CAFFE_ENFORCE_EQ(batch_size, dense.size(0)); dense_last_dim_ = dense.size(1); vector output_shape = indices.sizes().vec(); @@ -127,17 +183,18 @@ class BatchDenseToSparseOp : public Operator { return true; } - private: - void FillInSparseValues( - const int64_t batch_size, - const int64_t indice_lengths, - const int64_t* lengths_data, - const int64_t* indices_data, - const T* dense_data, - T* output_data, - Context* context); + template + bool DoRunWithOtherType2() { + CAFFE_THROW( + "BatchDenseToSparse is not implemented on values of type ", + Input(DENSE).dtype().name(), + " with lengths of type ", + Input(LENGTHS).dtype().name(), + " and indices of type ", + Input(INDICES).dtype().name()); + } - int64_t dense_last_dim_; + int64_t dense_last_dim_{}; INPUT_TAGS(LENGTHS, INDICES, DENSE); // len_prefix_sum_ and len_prefix_tmp_ are buffers on the GPU. It is not used diff --git a/caffe2/python/operator_test/activation_ops_test.py b/caffe2/python/operator_test/activation_ops_test.py index 7e5c5f423606d2..47216d51500c68 100644 --- a/caffe2/python/operator_test/activation_ops_test.py +++ b/caffe2/python/operator_test/activation_ops_test.py @@ -243,7 +243,7 @@ def leaky_relu_ref(X): @given(X=hu.tensor(), fast_gelu=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_gelu(self, X, fast_gelu, gc, dc): op = core.CreateOperator( "Gelu", diff --git a/caffe2/python/operator_test/adadelta_test.py b/caffe2/python/operator_test/adadelta_test.py index 930f74ecd99eb8..6c40c379697fd2 100644 --- a/caffe2/python/operator_test/adadelta_test.py +++ b/caffe2/python/operator_test/adadelta_test.py @@ -53,7 +53,7 @@ def ref_adadelta(param_in, decay=hu.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_adadelta(self, inputs, lr, epsilon, decay, gc, dc): param, moment, moment_delta, grad = inputs moment = np.abs(moment) diff --git a/caffe2/python/operator_test/adagrad_test.py b/caffe2/python/operator_test/adagrad_test.py index 309c54a25cb1e9..3172026df1bf93 100644 --- a/caffe2/python/operator_test/adagrad_test.py +++ b/caffe2/python/operator_test/adagrad_test.py @@ -26,7 +26,7 @@ class TestAdagrad(serial.SerializedTestCase): weight_decay=st.sampled_from([0.0, 0.1]), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc): param, momentum, grad = inputs momentum = np.abs(momentum) @@ -98,7 +98,7 @@ def test_adagrad_output_effective_lr( ), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_adagrad_output_effective_lr_and_update(self, inputs, lr, epsilon, gc, dc): param, momentum, grad = inputs momentum = np.abs(momentum) @@ -158,7 +158,7 @@ def test_sparse_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc): ), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc): param, momentum = inputs grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32) @@ -190,7 +190,7 @@ def test_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc): # Suppress filter_too_much health check. # Likely caused by `assume` call falling through too often. - @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=1000) + @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=10000) @given( inputs=hu.tensors(n=3), lr=st.floats( diff --git a/caffe2/python/operator_test/assert_test.py b/caffe2/python/operator_test/assert_test.py index 2bbca5ab7376b2..eef33bc22bc077 100644 --- a/caffe2/python/operator_test/assert_test.py +++ b/caffe2/python/operator_test/assert_test.py @@ -14,7 +14,7 @@ class TestAssert(hu.HypothesisTestCase): dtype=st.sampled_from(['bool_', 'int32', 'int64']), shape=st.lists(elements=st.integers(1, 10), min_size=1, max_size=4), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_assert(self, dtype, shape, gc, dc): test_tensor = np.random.rand(*shape).astype(np.dtype(dtype)) diff --git a/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py b/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py index adfc735c66fde9..96132e8f5e93e5 100644 --- a/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py +++ b/caffe2/python/operator_test/batch_sparse_to_dense_op_test.py @@ -19,7 +19,7 @@ class TestBatchSparseToDense(serial.SerializedTestCase): default_value=st.floats(min_value=2.0, max_value=3.0), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=None) def test_batch_sparse_to_dense( self, batch_size, dense_last_dim, default_value, gc, dc ): @@ -69,13 +69,16 @@ def batch_sparse_to_dense_ref(L, I, V, S=None): self.assertDeviceChecks(dc, op2, [L, I, V, S], [0]) self.assertReferenceChecks(gc, op2, [L, I, V, S], batch_sparse_to_dense_ref) self.assertGradientChecks(gc, op2, [L, I, V, S], 2, [0]) + self.assertDeviceChecks(dc, op, [L.astype(np.int32), I, V], [0]) + self.assertReferenceChecks(gc, op, [L.astype(np.int32), I, V], batch_sparse_to_dense_ref) + self.assertGradientChecks(gc, op, [L.astype(np.int32), I, V], 2, [0]) @given( batch_size=st.integers(5, 10), dense_last_dim=st.integers(5, 10), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=None) def test_batch_dense_to_sparse(self, batch_size, dense_last_dim, gc, dc): L = np.random.randint(1, dense_last_dim + 1, size=(batch_size)) # The following logic ensure that indices in each batch will not be duplicated @@ -106,3 +109,6 @@ def batch_dense_to_sparse_ref(L, I, D): self.assertDeviceChecks(dc, op, [L, I, D], [0]) self.assertReferenceChecks(gc, op, [L, I, D], batch_dense_to_sparse_ref) self.assertGradientChecks(gc, op, [L, I, D], 2, [0]) + self.assertDeviceChecks(dc, op, [L.astype(np.int32), I, D], [0]) + self.assertReferenceChecks(gc, op, [L.astype(np.int32), I, D], batch_dense_to_sparse_ref) + self.assertGradientChecks(gc, op, [L.astype(np.int32), I, D], 2, [0]) diff --git a/caffe2/python/operator_test/bbox_transform_test.py b/caffe2/python/operator_test/bbox_transform_test.py index d2584f18af4078..adcc2f8723d234 100644 --- a/caffe2/python/operator_test/bbox_transform_test.py +++ b/caffe2/python/operator_test/bbox_transform_test.py @@ -214,7 +214,7 @@ class TestBBoxTransformOp(serial.SerializedTestCase): clip_angle_thresh=st.sampled_from([-1.0, 1.0]), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_bbox_transform( self, num_rois, @@ -282,7 +282,7 @@ def bbox_transform_ref(rois, deltas, im_info): clip_angle_thresh=st.sampled_from([-1.0, 1.0]), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_bbox_transform_batch( self, roi_counts, diff --git a/caffe2/python/operator_test/boolean_mask_test.py b/caffe2/python/operator_test/boolean_mask_test.py index 38fe43899990af..0ccdbd92851214 100644 --- a/caffe2/python/operator_test/boolean_mask_test.py +++ b/caffe2/python/operator_test/boolean_mask_test.py @@ -15,7 +15,7 @@ class TestBooleanMaskOp(serial.SerializedTestCase): max_len=100, elements=hu.floats(min_value=0.5, max_value=1.0)), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_boolean_mask_gradient(self, x, gc, dc): op = core.CreateOperator("BooleanMask", ["data", "mask"], @@ -30,7 +30,7 @@ def test_boolean_mask_gradient(self, x, gc, dc): max_len=5, elements=hu.floats(min_value=0.5, max_value=1.0)), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_boolean_mask(self, x, gc, dc): op = core.CreateOperator("BooleanMask", ["data", "mask"], diff --git a/caffe2/python/operator_test/box_with_nms_limit_op_test.py b/caffe2/python/operator_test/box_with_nms_limit_op_test.py index 3131316feefda1..e459edb57de387 100644 --- a/caffe2/python/operator_test/box_with_nms_limit_op_test.py +++ b/caffe2/python/operator_test/box_with_nms_limit_op_test.py @@ -83,7 +83,7 @@ def ref(*args, **kwargs): self.assertReferenceChecks(gc, op, [scores, boxes], ref) @given(**HU_CONFIG) - @settings(deadline=1000) + @settings(deadline=10000) def test_score_thresh(self, gc): in_centers = [(0, 0), (20, 20), (50, 50)] in_scores = [0.7, 0.85, 0.6] @@ -102,7 +102,7 @@ def ref(*args, **kwargs): self.assertReferenceChecks(gc, op, [scores, boxes], ref) @given(det_per_im=st.integers(1, 3), **HU_CONFIG) - @settings(deadline=1000) + @settings(deadline=10000) def test_detections_per_im(self, det_per_im, gc): in_centers = [(0, 0), (20, 20), (50, 50)] in_scores = [0.7, 0.85, 0.6] @@ -131,7 +131,7 @@ def ref(*args, **kwargs): output_classes_include_bg_cls=st.booleans(), **HU_CONFIG ) - @settings(deadline=1000) + @settings(deadline=10000) def test_multiclass( self, num_classes, diff --git a/caffe2/python/operator_test/clip_op_test.py b/caffe2/python/operator_test/clip_op_test.py index 3304121aab08c3..0e800dafe01a54 100644 --- a/caffe2/python/operator_test/clip_op_test.py +++ b/caffe2/python/operator_test/clip_op_test.py @@ -19,7 +19,7 @@ class TestClip(serial.SerializedTestCase): max_=st.floats(min_value=0, max_value=2), inplace=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_clip(self, X, min_, max_, inplace, gc, dc): # go away from the origin point to avoid kink problems if np.isscalar(X): diff --git a/caffe2/python/operator_test/clip_tensor_op_test.py b/caffe2/python/operator_test/clip_tensor_op_test.py index efc86815bc49d6..c90c38234c8e07 100644 --- a/caffe2/python/operator_test/clip_tensor_op_test.py +++ b/caffe2/python/operator_test/clip_tensor_op_test.py @@ -19,7 +19,7 @@ class TestClipTensorByScalingOp(serial.SerializedTestCase): use_additional_threshold=st.booleans(), inplace=st.booleans(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_clip_tensor_by_scaling(self, n, d, threshold, additional_threshold, use_additional_threshold, inplace, gc, dc): diff --git a/caffe2/python/operator_test/conv_test.py b/caffe2/python/operator_test/conv_test.py index e600aa2c9ee96b..23217b15b82dea 100644 --- a/caffe2/python/operator_test/conv_test.py +++ b/caffe2/python/operator_test/conv_test.py @@ -164,7 +164,7 @@ def test_convolution_separate_stride_pad_gradients( use_bias=st.booleans(), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=None) def test_convolution_separate_stride_pad_layout( self, op_type, @@ -761,7 +761,7 @@ def canonical(o): engine=st.sampled_from(["CUDNN", ""]), **hu.gcs_no_hip ) - @settings(deadline=1000) + @settings(deadline=None) def test_convolution_sync(self, net_type, num_workers, engine, gc, dc): m = ModelHelper(name="test_model") n = 1 diff --git a/caffe2/python/operator_test/crf_test.py b/caffe2/python/operator_test/crf_test.py index 4d7b90c431a6a8..a4447fa3f36404 100644 --- a/caffe2/python/operator_test/crf_test.py +++ b/caffe2/python/operator_test/crf_test.py @@ -15,7 +15,7 @@ class TestCRFOp(hu.HypothesisTestCase): @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15)) - @settings(deadline=1000) + @settings(deadline=10000) def test_crf_with_loss_op(self, num_tags, num_words): model = ModelHelper(name='external') embeddings_dim = 200 diff --git a/caffe2/python/operator_test/dropout_op_test.py b/caffe2/python/operator_test/dropout_op_test.py index 84c2f7e35f5656..d3a5c831d875d3 100644 --- a/caffe2/python/operator_test/dropout_op_test.py +++ b/caffe2/python/operator_test/dropout_op_test.py @@ -48,7 +48,7 @@ def reference_dropout_test(x): output_mask=st.booleans(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_dropout_ratio0(self, X, in_place, output_mask, engine, gc, dc): """Test with ratio=0 for a deterministic reference impl.""" # TODO(lukeyeager): enable this path when the op is fixed diff --git a/caffe2/python/operator_test/elementwise_op_broadcast_test.py b/caffe2/python/operator_test/elementwise_op_broadcast_test.py index 605c1d74127104..bd19ebc6ed974c 100644 --- a/caffe2/python/operator_test/elementwise_op_broadcast_test.py +++ b/caffe2/python/operator_test/elementwise_op_broadcast_test.py @@ -75,22 +75,22 @@ def __test_binary_op(self, gc, dc, caffe2_op, op_function): self.assertGradientChecks(gc, op, [X, Y], 1, [0]) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_Add(self, gc, dc): self.__test_binary_op(gc, dc, "Add", operator.add) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_Mul(self, gc, dc): self.__test_binary_op(gc, dc, "Mul", operator.mul) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_Sub(self, gc, dc): self.__test_binary_op(gc, dc, "Sub", operator.sub) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_broadcast_powt(self, gc, dc): np.random.seed(101) diff --git a/caffe2/python/operator_test/elementwise_ops_test.py b/caffe2/python/operator_test/elementwise_ops_test.py index 922e4554e9a805..130ebade010ba7 100644 --- a/caffe2/python/operator_test/elementwise_ops_test.py +++ b/caffe2/python/operator_test/elementwise_ops_test.py @@ -59,7 +59,7 @@ def exp_ref(X): @given(n=st.integers(0, 6), m=st.integers(4, 6), seed=st.integers(0, 1000), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_log(self, n, m, gc, dc, seed): np.random.seed(seed) X = np.random.rand(n, m).astype(np.float32) + 1.0 @@ -326,7 +326,7 @@ def swish(X): @given(n=st.integers(0, 6), m=st.integers(4, 6), seed=st.integers(0, 1000), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_swish_gradient_inplace(self, n, m, gc, dc, seed): np.random.seed(seed) @@ -354,7 +354,7 @@ def swish_gradient(X, Y, dY): @given(X=hu.tensor(dtype=np.float32), inplace=st.booleans(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_sigmoid(self, X, inplace, engine, gc, dc): op = core.CreateOperator( "Sigmoid", diff --git a/caffe2/python/operator_test/erf_op_test.py b/caffe2/python/operator_test/erf_op_test.py index 64714db4315cf1..a4ed0d5fb23e95 100644 --- a/caffe2/python/operator_test/erf_op_test.py +++ b/caffe2/python/operator_test/erf_op_test.py @@ -18,7 +18,7 @@ class TestErfOp(serial.SerializedTestCase): @given( X=hu.tensor(elements=hu.floats(min_value=-0.7, max_value=0.7)), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_erf(self, X, gc, dc): op = core.CreateOperator('Erf', ["X"], ["Y"]) self.assertReferenceChecks(gc, op, [X], lambda x: (np.vectorize(math.erf)(X),)) diff --git a/caffe2/python/operator_test/expand_op_test.py b/caffe2/python/operator_test/expand_op_test.py index aba2c1106da37c..bd608f6fcc24c1 100644 --- a/caffe2/python/operator_test/expand_op_test.py +++ b/caffe2/python/operator_test/expand_op_test.py @@ -59,7 +59,7 @@ def test_expand_nonrand_shape1(self, X, gc, dc): np.ones([1, 4, 1, 2]), np.ones([4, 1, 2])]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_expand_nonrand_shape2(self, X, gc, dc): self._run_expand_op_test(X, [4, 1, 2, 2], gc, dc) self._run_expand_op_test(X, [4, -1, 2, 2], gc, dc) diff --git a/caffe2/python/operator_test/fc_operator_test.py b/caffe2/python/operator_test/fc_operator_test.py index 1e8b5522053d36..bd203b7c84a6ac 100644 --- a/caffe2/python/operator_test/fc_operator_test.py +++ b/caffe2/python/operator_test/fc_operator_test.py @@ -61,8 +61,8 @@ def fc_transposed_op(X, W, b): op.arg.extend([a]) # Check against numpy reference - # ReferenceChecks is flaky on rocm with threshold of 1e-4 for fp16. Relaxing to 1e-3. - threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP and dtype == np.float16) else 1e-4 + # ReferenceChecks is flaky, Relaxing to 1e-3. + threshold = 1e-3 self.assertReferenceChecks( device_option=gc, op=op, diff --git a/caffe2/python/operator_test/filler_ops_test.py b/caffe2/python/operator_test/filler_ops_test.py index e080dde3eb5fb0..442f5866cb0916 100644 --- a/caffe2/python/operator_test/filler_ops_test.py +++ b/caffe2/python/operator_test/filler_ops_test.py @@ -22,7 +22,7 @@ def _fill_diagonal(shape, value): class TestFillerOperator(serial.SerializedTestCase): @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_shape_error(self, gc, dc): op = core.CreateOperator( 'GaussianFill', @@ -77,7 +77,7 @@ def test_int64_shape(self, gc, dc): b=st.integers(min_value=0, max_value=100), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_uniform_int_fill_op_blob_input(self, shape, a, b, gc, dc): net = core.Net('test_net') diff --git a/caffe2/python/operator_test/flexible_top_k_test.py b/caffe2/python/operator_test/flexible_top_k_test.py index 3e0e5722b0ceb5..0cccabb5f2e951 100644 --- a/caffe2/python/operator_test/flexible_top_k_test.py +++ b/caffe2/python/operator_test/flexible_top_k_test.py @@ -40,7 +40,7 @@ def flexible_top_k_ref(self, X, k): return (values_ref, indices_ref) @given(X=hu.tensor(min_dim=2), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_flexible_top_k(self, X, gc, dc): X = X.astype(dtype=np.float32) k_shape = (int(X.size / X.shape[-1]), ) diff --git a/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py b/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py index b7cb5f68351f57..d2e794da065140 100644 --- a/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py +++ b/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py @@ -205,7 +205,7 @@ def ErrorThresholdRow(X, bit_rate): class TestNBitFakeFused(hu.HypothesisTestCase): @given(bit_rate=st.sampled_from([2, 4])) - @settings(deadline=1000) + @settings(deadline=10000) def testNBit(self, bit_rate): # uncomment for debugging # np.random.seed(0) diff --git a/caffe2/python/operator_test/gather_ops_test.py b/caffe2/python/operator_test/gather_ops_test.py index fc23be13fdae76..b0d64506e4c792 100644 --- a/caffe2/python/operator_test/gather_ops_test.py +++ b/caffe2/python/operator_test/gather_ops_test.py @@ -209,7 +209,7 @@ class TestGatherFused8BitRowwise(hu.HypothesisTestCase): cols_num=st.integers(1, 128), index_num=st.integers(0, 5000), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_batch_gather_ops(self, rows_num, cols_num, index_num, gc, dc): data = np.random.random((rows_num, cols_num)).astype(np.float32) ind = np.random.randint(rows_num, size=(index_num, )).astype('int32') diff --git a/caffe2/python/operator_test/gather_ranges_op_test.py b/caffe2/python/operator_test/gather_ranges_op_test.py index c0d73af33601ea..b6ec8823f4dd22 100644 --- a/caffe2/python/operator_test/gather_ranges_op_test.py +++ b/caffe2/python/operator_test/gather_ranges_op_test.py @@ -166,7 +166,7 @@ def gather_ranges_to_dense_with_key(data, ranges, key, lengths): class TestGatherRanges(serial.SerializedTestCase): @given(boarders_and_data=batched_boarders_and_data(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_gather_ranges(self, boarders_and_data, gc, dc): boarders, data = boarders_and_data @@ -187,7 +187,7 @@ def boarders_to_range(boarders): ) @given(tensor_splits=_tensor_splits(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_gather_ranges_split(self, tensor_splits, gc, dc): data, ranges, lengths, _ = tensor_splits diff --git a/caffe2/python/operator_test/instance_norm_test.py b/caffe2/python/operator_test/instance_norm_test.py index efce9d7001fe7b..d97385cbe2155d 100644 --- a/caffe2/python/operator_test/instance_norm_test.py +++ b/caffe2/python/operator_test/instance_norm_test.py @@ -60,7 +60,7 @@ def _feed_inputs(self, input_blobs, device_option): store_mean=st.booleans(), seed=st.integers(0, 1000), store_inv_stdev=st.booleans()) - @settings(deadline=1000) + @settings(deadline=10000) def test_instance_norm_gradients( self, gc, dc, N, C, H, W, order, store_mean, store_inv_stdev, epsilon, seed): diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py index 67d7f14bd3367d..32a2511e3e8e39 100644 --- a/caffe2/python/operator_test/layer_norm_op_test.py +++ b/caffe2/python/operator_test/layer_norm_op_test.py @@ -322,7 +322,7 @@ def test_layer_norm_op_pytorch_cuda(self, X, eps, elementwise_affine): eps=st.floats(1e-5, 1e-3), elementwise_affine=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_layer_norm_op_jit(self, X, eps, elementwise_affine, gc, dc): @torch.jit.script def jit_layer_norm( diff --git a/caffe2/python/operator_test/length_split_op_test.py b/caffe2/python/operator_test/length_split_op_test.py index 28d7134ac5e8e7..3f20ff1f458524 100644 --- a/caffe2/python/operator_test/length_split_op_test.py +++ b/caffe2/python/operator_test/length_split_op_test.py @@ -28,7 +28,7 @@ def _length_split_op_ref(self, input_lengths, n_split_array): return [np.array(output).astype(np.int32)] @given(**hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_length_split_edge(self, gc, dc): input_lengths = np.array([3, 4, 5]).astype(np.int32) n_split_ = np.array([5]).astype(np.int32) diff --git a/caffe2/python/operator_test/locally_connected_op_test.py b/caffe2/python/operator_test/locally_connected_op_test.py index 2adc253f4d8809..445c3641573f92 100644 --- a/caffe2/python/operator_test/locally_connected_op_test.py +++ b/caffe2/python/operator_test/locally_connected_op_test.py @@ -103,7 +103,7 @@ def lc_2d_nhwc(X, W, b=None): op_name=st.sampled_from(["LC", "LC1D"]), use_bias=st.booleans(), **hu.gcs) - @settings(deadline=5000) + @settings(deadline=None) # Increased timeout from 1 second to 5 for ROCM def test_lc_1d(self, N, C, size, M, kernel, op_name, use_bias, gc, dc): if size < kernel: @@ -163,7 +163,7 @@ def conv(n, m, yl): op_name=st.sampled_from(["LC", "LC3D"]), use_bias=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_lc_3d(self, N, C, T, H, W, M, kernel, op_name, use_bias, gc, dc): if T < kernel: kernel = T diff --git a/caffe2/python/operator_test/lpnorm_op_test.py b/caffe2/python/operator_test/lpnorm_op_test.py index 3a58cbe6d9604e..e7ab634d0e7c93 100644 --- a/caffe2/python/operator_test/lpnorm_op_test.py +++ b/caffe2/python/operator_test/lpnorm_op_test.py @@ -16,7 +16,7 @@ class LpnormTest(hu.HypothesisTestCase): max_dim=3, dtype=np.float32), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_Lp_Norm(self, inputs, gc, dc): X = inputs[0] # avoid kinks by moving away from 0 diff --git a/caffe2/python/operator_test/margin_ranking_criterion_op_test.py b/caffe2/python/operator_test/margin_ranking_criterion_op_test.py index e28dd1ce28f8dd..a91de60a8c1967 100644 --- a/caffe2/python/operator_test/margin_ranking_criterion_op_test.py +++ b/caffe2/python/operator_test/margin_ranking_criterion_op_test.py @@ -17,7 +17,7 @@ class TestMarginRankingCriterion(serial.SerializedTestCase): seed=st.integers(min_value=0, max_value=65535), margin=st.floats(min_value=-0.5, max_value=0.5), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_margin_ranking_criterion(self, N, seed, margin, gc, dc): np.random.seed(seed) X1 = np.random.randn(N).astype(np.float32) diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py index 8b4001a574ac60..067eeabbe2d980 100644 --- a/caffe2/python/operator_test/matmul_op_test.py +++ b/caffe2/python/operator_test/matmul_op_test.py @@ -60,7 +60,7 @@ def matmul_ref(X, Y, trans_a, trans_b): trans_b=st.booleans(), **hu.gcs ) - @settings(deadline=1000) + @settings(deadline=10000) def test_matmul_axis( self, M, K, N, axis_a, axis_b, trans_a, trans_b, gc, dc ): diff --git a/caffe2/python/operator_test/one_hot_ops_test.py b/caffe2/python/operator_test/one_hot_ops_test.py index 593d5b5aa58c12..e23e04434ab3c7 100644 --- a/caffe2/python/operator_test/one_hot_ops_test.py +++ b/caffe2/python/operator_test/one_hot_ops_test.py @@ -63,7 +63,7 @@ def ref(x, lens, vals): elements=st.integers(min_value=-5, max_value=5)), seed=st.integers(min_value=0, max_value=1000), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_batch_bucketized_one_hot(self, x, seed, gc, dc): np.random.seed(seed) d = x.shape[1] diff --git a/caffe2/python/operator_test/pooling_test.py b/caffe2/python/operator_test/pooling_test.py index 7ef98249bd7930..2954face6b85ef 100644 --- a/caffe2/python/operator_test/pooling_test.py +++ b/caffe2/python/operator_test/pooling_test.py @@ -90,7 +90,7 @@ def test_pooling_big_batch(self, gc, dc): op_type=st.sampled_from(["MaxPool", "AveragePool", "MaxPool1D", "AveragePool1D"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_pooling_1d(self, stride, pad, kernel, size, input_channels, batch_size, order, op_type, gc, dc): assume(pad < kernel) diff --git a/caffe2/python/operator_test/python_op_test.py b/caffe2/python/operator_test/python_op_test.py index b071070151d16b..8f41815585dc65 100644 --- a/caffe2/python/operator_test/python_op_test.py +++ b/caffe2/python/operator_test/python_op_test.py @@ -14,7 +14,7 @@ class PythonOpTest(hu.HypothesisTestCase): @given(x=hu.tensor(), n=st.integers(min_value=1, max_value=20), w=st.integers(min_value=1, max_value=20)) - @settings(deadline=1000) + @settings(deadline=10000) def test_simple_python_op(self, x, n, w): def g(input_, output): output[...] = input_ diff --git a/caffe2/python/operator_test/reduce_ops_test.py b/caffe2/python/operator_test/reduce_ops_test.py index 7b79b3b81aed19..299b373e509daf 100644 --- a/caffe2/python/operator_test/reduce_ops_test.py +++ b/caffe2/python/operator_test/reduce_ops_test.py @@ -96,7 +96,7 @@ def test_reduce_mean(self, X, keepdims, num_axes, gc, dc): @given(n=st.integers(1, 3), m=st.integers(1, 3), k=st.integers(1, 3), keepdims=st.booleans(), num_axes=st.integers(1, 3), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_reduce_l1(self, n, m, k, keepdims, num_axes, gc, dc): X = np.arange(n * m * k, dtype=np.float32) - 0.5 np.random.shuffle(X) @@ -253,7 +253,7 @@ def ref_sum(X): np.testing.assert_allclose(output, ref_sum(X)[0], atol=1e-3) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_reduce_front_sum_with_length(self, dc, gc): num_reduce_dim = 1 X = np.random.rand(2, 3, 4, 5).astype(np.float32) @@ -286,7 +286,7 @@ def ref_mean(X): "ReduceFrontMeanGradient", X, ref_mean, num_reduce_dim) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_reduce_front_mean_with_length(self, dc, gc): num_reduce_dim = 1 X = np.random.rand(2, 3, 4, 5).astype(np.float32) @@ -411,7 +411,7 @@ def ref_mean(X): "ReduceBackMeanGradient", X, ref_mean, num_reduce_dim) @given(**hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_reduce_back_mean_with_length(self, dc, gc): num_reduce_dim = 1 X = np.random.rand(2, 3, 4, 5).astype(np.float32) diff --git a/caffe2/python/operator_test/selu_op_test.py b/caffe2/python/operator_test/selu_op_test.py index 4dd2fa1848bfaa..73cb0736dceee8 100644 --- a/caffe2/python/operator_test/selu_op_test.py +++ b/caffe2/python/operator_test/selu_op_test.py @@ -33,7 +33,7 @@ def test_selu_1(self, X, gc, dc, engine): @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_selu_2(self, X, gc, dc, engine): alpha = 1.6732 scale = 1.0507 @@ -50,7 +50,7 @@ def test_selu_2(self, X, gc, dc, engine): @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_selu_3(self, X, gc, dc, engine): alpha = 1.3 scale = 1.1 diff --git a/caffe2/python/operator_test/sequence_ops_test.py b/caffe2/python/operator_test/sequence_ops_test.py index 65c0669abfb005..524d3c8b4149f3 100644 --- a/caffe2/python/operator_test/sequence_ops_test.py +++ b/caffe2/python/operator_test/sequence_ops_test.py @@ -106,7 +106,7 @@ class TestSequenceOps(serial.SerializedTestCase): args=_gen_test_add_padding(with_pad_data=True), ret_lengths=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_add_padding( self, start_pad_width, end_pad_width, args, ret_lengths, gc, dc ): @@ -278,7 +278,7 @@ def op_ref(data, indices): min_size=0, max_size=10), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_find_duplicate_elements(self, elements, gc, dc): mapping = { 0: "a", diff --git a/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py b/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py index 6e8cae62dbff24..03b50bfc952db2 100644 --- a/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py +++ b/caffe2/python/operator_test/sinusoid_position_encoding_op_test.py @@ -33,7 +33,7 @@ class TestSinusoidPositionEncodingOp(serial.SerializedTestCase): amplitude=st.floats(MIN_TEST_AMPLITUDE, MAX_TEST_AMPLITUDE), **hu.gcs_cpu_only ) - @settings(deadline=1000) + @settings(deadline=10000) def test_sinusoid_embedding( self, positions_vec, embedding_size, batch_size, alpha, amplitude, gc, dc ): diff --git a/caffe2/python/operator_test/softmax_ops_test.py b/caffe2/python/operator_test/softmax_ops_test.py index 533d575ee59fc7..8ec92ae1af9e52 100644 --- a/caffe2/python/operator_test/softmax_ops_test.py +++ b/caffe2/python/operator_test/softmax_ops_test.py @@ -143,7 +143,7 @@ def label_softmax(X): @given(n=st.integers(2, 10), D=st.integers(4, 16), only_loss=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_softmax_with_loss(self, n, D, gc, only_loss, dc): # n = number of examples, D = |labels| # Initialize X and add 1e-2 for numerical stability @@ -301,7 +301,7 @@ def label_softmax_crossent(X, label): ) @given(n=st.integers(2, 10), D=st.integers(4, 16), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_softmax_with_loss_label_prob(self, n, D, gc, dc): # n = number of examples, D = |labels| # Initialize X and add 1e-2 for numerical stability @@ -358,7 +358,7 @@ def label_softmax_crossent(X, label): D=st.integers(4, 16), only_loss=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_softmax_with_loss_weighted(self, n, D, only_loss, gc, dc): # n = number of examples, D = |labels| # Initialize X and add 1e-2 for numerical stability diff --git a/caffe2/python/operator_test/softplus_op_test.py b/caffe2/python/operator_test/softplus_op_test.py index dd183b774f924b..f8ca1817176ebb 100644 --- a/caffe2/python/operator_test/softplus_op_test.py +++ b/caffe2/python/operator_test/softplus_op_test.py @@ -14,7 +14,7 @@ class TestSoftplus(hu.HypothesisTestCase): @given(X=hu.tensor(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_softplus(self, X, gc, dc): op = core.CreateOperator("Softplus", ["X"], ["Y"]) self.assertDeviceChecks(dc, op, [X], [0]) diff --git a/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py b/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py index 41ec8808bb6acb..267babf2145f7e 100644 --- a/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py +++ b/caffe2/python/operator_test/sparse_to_dense_mask_op_test.py @@ -14,7 +14,7 @@ class TestFcOperator(hu.HypothesisTestCase): @given(n=st.integers(1, 10), k=st.integers(1, 5), use_length=st.booleans(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_to_dense_mask(self, n, k, use_length, gc, dc): lengths = np.random.randint(k, size=n).astype(np.int32) + 1 N = sum(lengths) @@ -47,7 +47,7 @@ def test_sparse_to_dense_mask(self, n, k, use_length, gc, dc): @given(n=st.integers(1, 10), k=st.integers(1, 5), use_length=st.booleans(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_to_dense_mask_with_int64(self, n, k, use_length, gc, dc): lengths = np.random.randint(k, size=n).astype(np.int32) + 1 N = sum(lengths) diff --git a/caffe2/python/operator_test/string_ops_test.py b/caffe2/python/operator_test/string_ops_test.py index a0c56a68666645..aa706ad73d7c17 100644 --- a/caffe2/python/operator_test/string_ops_test.py +++ b/caffe2/python/operator_test/string_ops_test.py @@ -20,7 +20,7 @@ def _string_lists(alphabet=None): class TestStringOps(serial.SerializedTestCase): @given(strings=_string_lists()) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_prefix(self, strings): length = 3 # although we are utf-8 encoding below to avoid python exceptions, @@ -48,7 +48,7 @@ def string_prefix_ref(strings): string_prefix_ref) @given(strings=_string_lists()) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_suffix(self, strings): length = 3 strings = np.array( @@ -72,7 +72,7 @@ def string_suffix_ref(strings): string_suffix_ref) @given(strings=st.text(alphabet=['a', 'b'])) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_starts_with(self, strings): prefix = 'a' strings = np.array( @@ -96,7 +96,7 @@ def string_starts_with_ref(strings): string_starts_with_ref) @given(strings=st.text(alphabet=['a', 'b'])) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_ends_with(self, strings): suffix = 'a' strings = np.array( @@ -120,7 +120,7 @@ def string_ends_with_ref(strings): string_ends_with_ref) @given(strings=st.text(alphabet=['a', 'b'])) - @settings(deadline=1000) + @settings(deadline=10000) def test_string_equals(self, strings): text = "" if strings: diff --git a/caffe2/python/operator_test/top_k_test.py b/caffe2/python/operator_test/top_k_test.py index fa628456c3a416..035b1fb3d0992b 100644 --- a/caffe2/python/operator_test/top_k_test.py +++ b/caffe2/python/operator_test/top_k_test.py @@ -140,7 +140,7 @@ def bind_ref(X_loc): @given(bs=st.integers(1, 3), n=st.integers(100, 10000), flatten_indices=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_top_k_4(self, bs, n, flatten_indices, gc, dc): k = np.random.randint(n // 3, 3 * n // 4) X = np.random.rand(bs, n).astype(dtype=np.float32) @@ -177,7 +177,7 @@ def bind_ref(X_loc): @given(bs=st.integers(1, 3), n=st.integers(1, 5000), flatten_indices=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=10000) def test_top_k_6(self, bs, n, flatten_indices, gc, dc): k = n X = np.random.rand(bs, n).astype(dtype=np.float32) diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index e568f8bdff74d3..f99a61688de6e0 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -991,7 +991,7 @@ def test_gather_ranges_to_dense_op(self): np.testing.assert_array_almost_equal(ref_outputs[i], outputs[i].numpy()) @given(lengths_0=st.integers(1, 10), lengths_1=st.integers(1, 10)) - @settings(deadline=1000) + @settings(deadline=10000) def test_merge_id_lists(self, lengths_0, lengths_1): def _merge_id_lists(lengths, values): ref_op = core.CreateOperator( diff --git a/caffe2/python/operator_test/utility_ops_test.py b/caffe2/python/operator_test/utility_ops_test.py index aeefbf596afeb5..187328f9e48416 100644 --- a/caffe2/python/operator_test/utility_ops_test.py +++ b/caffe2/python/operator_test/utility_ops_test.py @@ -332,7 +332,7 @@ def sum_op_ref(*args): ) ), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_lengths_gather(self, inputs, gc, dc): items = inputs[0] lengths = inputs[1] @@ -359,7 +359,7 @@ def lengths_gather_op(items, lengths, indices): @given( inputs=hu.lengths_tensor(), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_lengths_to_ranges(self, inputs, gc, dc): _, lengths = inputs diff --git a/caffe2/python/operator_test/weighted_sum_test.py b/caffe2/python/operator_test/weighted_sum_test.py index 2c7dffe92672d4..fbbe2a6bf6d8f7 100644 --- a/caffe2/python/operator_test/weighted_sum_test.py +++ b/caffe2/python/operator_test/weighted_sum_test.py @@ -61,7 +61,7 @@ def weighted_sum_op_ref(*args): @given(n=st.integers(1, 8), m=st.integers(1, 10), d=st.integers(1, 4), grad_on_w=st.booleans(), seed=st.integers(min_value=0, max_value=65535), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_weighted_sum_grad( self, n, m, d, grad_on_w, seed, gc, dc): input_names = [] diff --git a/caffe2/python/operator_test/wngrad_test.py b/caffe2/python/operator_test/wngrad_test.py index 48fe0f94731e89..0a1f0405e92aec 100644 --- a/caffe2/python/operator_test/wngrad_test.py +++ b/caffe2/python/operator_test/wngrad_test.py @@ -113,7 +113,7 @@ def test_wngrad_dense_base(self, inputs, seq_b, lr, epsilon, gc, dc): epsilon=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_wngrad_dense_output_effective_lr(self, inputs, seq_b, lr, epsilon, gc, dc): param, grad = inputs @@ -142,7 +142,7 @@ def test_wngrad_dense_output_effective_lr(self, inputs, seq_b, epsilon=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_wngrad_dense_output_effective_lr_and_update( self, inputs, seq_b, lr, epsilon, gc, dc): param, grad = inputs @@ -165,7 +165,7 @@ def test_wngrad_dense_output_effective_lr_and_update( # Suppress filter_too_much health check. # Likely caused by `assume` call falling through too often. - @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=1000) + @settings(suppress_health_check=[HealthCheck.filter_too_much], deadline=10000) @given(inputs=hu.tensors(n=2), seq_b=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), @@ -186,7 +186,7 @@ def test_sparse_wngrad(self, inputs, seq_b, lr, epsilon, gc, dc): epsilon=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False), **hu.gcs_cpu_only) - @settings(deadline=1000) + @settings(deadline=10000) def test_sparse_wngrad_empty(self, inputs, seq_b, lr, epsilon, gc, dc): param = inputs[0] seq_b = np.array([seq_b, ], dtype=np.float32) diff --git a/caffe2/utils/threadpool/pthreadpool-cpp.cc b/caffe2/utils/threadpool/pthreadpool-cpp.cc index d18206c99ce8bd..6737935d13abd2 100644 --- a/caffe2/utils/threadpool/pthreadpool-cpp.cc +++ b/caffe2/utils/threadpool/pthreadpool-cpp.cc @@ -45,8 +45,17 @@ void PThreadPool::set_thread_count(const size_t thread_count) { void PThreadPool::run( const std::function& fn, const size_t range) { + // Run on same thread if _NoPThreadPoolGuard guard is enabled + if (caffe2::_NoPThreadPoolGuard::is_enabled()) { + for (size_t i = 0; i < range; ++i) { + fn(i); + } + return; + } + std::lock_guard lock{mutex_}; + TORCH_INTERNAL_ASSERT(!caffe2::_NoPThreadPoolGuard::is_enabled(), "Inside a threadpool guard!"); TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!"); struct Context final { diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index c7fe9b7d4bdea4..6d9c3ac3ab9084 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -999,24 +999,20 @@ if(BUILD_PYTHON) endif() # ---[ pybind11 -if(NOT pybind11_PREFER_third_party) +if(USE_SYSTEM_BIND11) find_package(pybind11 CONFIG) if(NOT pybind11_FOUND) find_package(pybind11) endif() -endif() - -if(pybind11_FOUND) - message(STATUS "System pybind11 found") + if(NOT pybind11_FOUND) + message(FATAL "Cannot find system pybind11") + endif() else() message(STATUS "Using third_party/pybind11.") set(pybind11_INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}/../third_party/pybind11/include) install(DIRECTORY ${pybind11_INCLUDE_DIRS} DESTINATION ${CMAKE_INSTALL_PREFIX} FILES_MATCHING PATTERN "*.h") - set(pybind11_PREFER_third_party ON CACHE BOOL - "Use the third_party/pybind11 submodule, instead of looking for system - installation of pybind11") endif() message(STATUS "pybind11 include dirs: " "${pybind11_INCLUDE_DIRS}") include_directories(SYSTEM ${pybind11_INCLUDE_DIRS}) diff --git a/cmake/Modules/FindBLIS.cmake b/cmake/Modules/FindBLIS.cmake index 8892f5358fdc0b..3be90c19ea180f 100644 --- a/cmake/Modules/FindBLIS.cmake +++ b/cmake/Modules/FindBLIS.cmake @@ -17,6 +17,7 @@ SET(BLIS_INCLUDE_SEARCH_PATHS /opt/blis/include $ENV{BLIS_HOME} $ENV{BLIS_HOME}/include + $ENV{BLIS_HOME}/include/blis ) SET(BLIS_LIB_SEARCH_PATHS diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 8ff39bc458b33e..85f35d05dc821f 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -91,7 +91,7 @@ function(caffe2_print_configuration_summary) get_target_property(__tmp caffe2::curand IMPORTED_LOCATION) message(STATUS " curand library : ${__tmp}") if(${USE_CUDNN}) - get_target_property(__tmp caffe2::cudnn IMPORTED_LOCATION) + get_target_property(__tmp caffe2::cudnn INTERFACE_LINK_LIBRARIES) message(STATUS " cuDNN library : ${__tmp}") endif() get_target_property(__tmp caffe2::nvrtc IMPORTED_LOCATION) diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index c154aed1f6b1ce..a5cc3602b735f8 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -272,20 +272,52 @@ else() ${LIBNVTOOLSEXT}) endif() +# cublas. CUDA_CUBLAS_LIBRARIES is actually a list, so we will make an +# interface library similar to cudart. +add_library(caffe2::cublas INTERFACE IMPORTED) +if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) + set_property( + TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES + "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a") + if(CUDA_VERSION VERSION_GREATER_EQUAL 10.1) + set_property( + TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES + "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublasLt_static.a") + # Add explicit dependency to cudart_static to fix + # libcublasLt_static.a.o): undefined reference to symbol 'cudaStreamWaitEvent' + # error adding symbols: DSO missing from command line + set_property( + TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES + "${CUDA_cudart_static_LIBRARY}" rt dl) + endif() +else() + set_property( + TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES + ${CUDA_CUBLAS_LIBRARIES}) +endif() +set_property( + TARGET caffe2::cublas PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${CUDA_INCLUDE_DIRS}) + # cudnn # static linking is handled by USE_STATIC_CUDNN environment variable if(CAFFE2_USE_CUDNN) - add_library(caffe2::cudnn UNKNOWN IMPORTED) + add_library(caffe2::cudnn INTERFACE IMPORTED) set_property( - TARGET caffe2::cudnn PROPERTY IMPORTED_LOCATION - ${CUDNN_LIBRARY_PATH}) + TARGET caffe2::cudnn PROPERTY INTERFACE_LINK_LIBRARIES + ${CUDNN_LIBRARY_PATH}) set_property( - TARGET caffe2::cudnn PROPERTY INTERFACE_INCLUDE_DIRECTORIES - ${CUDNN_INCLUDE_PATH}) + TARGET caffe2::cudnn PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${CUDNN_INCLUDE_PATH}) if(CUDNN_STATIC AND NOT WIN32) set_property( - TARGET caffe2::cudnn PROPERTY INTERFACE_LINK_LIBRARIES - "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libculibos.a" dl) + TARGET caffe2::cudnn APPEND PROPERTY INTERFACE_LINK_LIBRARIES + "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libculibos.a" dl) + # Add explicit dependency on cublas to cudnn + get_target_property(__tmp caffe2::cublas INTERFACE_LINK_LIBRARIES) + set_property( + TARGET caffe2::cudnn APPEND PROPERTY INTERFACE_LINK_LIBRARIES + "${__tmp}") # Lines below use target_link_libraries because we support cmake 3.5+. # For cmake 3.13+, target_link_options to set INTERFACE_LINK_OPTIONS would be better. # https://cmake.org/cmake/help/v3.5/command/target_link_libraries.html warns @@ -346,33 +378,6 @@ if(CAFFE2_USE_TENSORRT) ${TENSORRT_INCLUDE_DIR}) endif() -# cublas. CUDA_CUBLAS_LIBRARIES is actually a list, so we will make an -# interface library similar to cudart. -add_library(caffe2::cublas INTERFACE IMPORTED) -if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) - set_property( - TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES - "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a") - if(CUDA_VERSION VERSION_GREATER_EQUAL 10.1) - set_property( - TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES - "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublasLt_static.a") - # Add explicit dependency to cudart_static to fix - # libcublasLt_static.a.o): undefined reference to symbol 'cudaStreamWaitEvent' - # error adding symbols: DSO missing from command line - set_property( - TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES - "${CUDA_cudart_static_LIBRARY}" rt dl) - endif() -else() - set_property( - TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES - ${CUDA_CUBLAS_LIBRARIES}) -endif() -set_property( - TARGET caffe2::cublas PROPERTY INTERFACE_INCLUDE_DIRECTORIES - ${CUDA_INCLUDE_DIRS}) - # nvrtc add_library(caffe2::nvrtc UNKNOWN IMPORTED) set_property( diff --git a/docs/cpp/source/notes/inference_mode.rst b/docs/cpp/source/notes/inference_mode.rst index 2ceb2dcdb762fb..efb1b9de2d1a76 100644 --- a/docs/cpp/source/notes/inference_mode.rst +++ b/docs/cpp/source/notes/inference_mode.rst @@ -30,8 +30,6 @@ Inside an ``InferenceMode`` block, we make the following performance guarantees: - Inplace operations on inference tensors are guaranteed not to do a version bump. For more implementation details of ``InferenceMode`` please see the `RFC-0011-InferenceMode `_. -Currently this guard is only available in C++ frontend, adding python frontend support -is tracked in #56608. Migration guide from ``AutoNonVariableTypeMode`` ------------------------------------------------ diff --git a/docs/source/amp.rst b/docs/source/amp.rst index bf93d95ecbc3e0..4ebd1236100691 100644 --- a/docs/source/amp.rst +++ b/docs/source/amp.rst @@ -142,6 +142,7 @@ Ops that can autocast to ``float32`` ``exp``, ``expm1``, ``gelu``, +``grid_sample``, ``group_norm``, ``hinge_embedding_loss``, ``kl_div``, diff --git a/docs/source/autograd.rst b/docs/source/autograd.rst index 5bc588b0fa8b87..566808036701b6 100644 --- a/docs/source/autograd.rst +++ b/docs/source/autograd.rst @@ -50,6 +50,10 @@ you can use it as ``functional.jacobian(lambda x: f(x, constant, flag=flag), inp Locally disabling gradient computation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +See :ref:`locally-disable-grad-doc` for more information on the differences +between no-grad and inference mode as well as other related mechanisms that +may be confused with the two. + .. autosummary:: :toctree: generated :nosignatures: diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 9d2fcc042d6f50..b9889127aac075 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -419,6 +419,8 @@ Collective functions .. autofunction:: barrier +.. autofunction:: monitored_barrier + .. autoclass:: ReduceOp .. class:: reduce_op diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst index be2f5536e7ae36..0e8dcbef16c1da 100644 --- a/docs/source/nn.functional.rst +++ b/docs/source/nn.functional.rst @@ -89,6 +89,7 @@ Non-linear activation functions sigmoid hardsigmoid silu + mish batch_norm group_norm instance_norm diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 24bba5119cc9de..1859ae5202fd4a 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -145,6 +145,7 @@ Non-linear Activations (weighted sum, nonlinearity) nn.GELU nn.Sigmoid nn.SiLU + nn.Mish nn.Softplus nn.Softshrink nn.Softsign @@ -371,6 +372,7 @@ From the ``torch.nn.utils`` module remove_weight_norm spectral_norm remove_spectral_norm + skip_init Parametrizations implemented using the new parametrization functionality in :func:`torch.nn.utils.parameterize.register_parametrization`. diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index c15a0d0340a5f2..6d0e0e83d3d2b7 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -8,56 +8,6 @@ operations. It's not strictly necessary to understand all this, but we recommend getting familiar with it, as it will help you write more efficient, cleaner programs, and can aid you in debugging. -.. _excluding-subgraphs: - -Excluding subgraphs from backward ---------------------------------- - -Every Tensor has a flag: :attr:`requires_grad` that allows for fine grained -exclusion of subgraphs from gradient computation and can increase efficiency. - -.. _excluding-requires_grad: - -``requires_grad`` -^^^^^^^^^^^^^^^^^ - -If there's a single input to an operation that requires gradient, its output -will also require gradient. Conversely, only if all inputs don't require -gradient, the output also won't require it. Backward computation is never -performed in the subgraphs, where all Tensors didn't require gradients. - -.. code:: - - >>> x = torch.randn(5, 5) # requires_grad=False by default - >>> y = torch.randn(5, 5) # requires_grad=False by default - >>> z = torch.randn((5, 5), requires_grad=True) - >>> a = x + y - >>> a.requires_grad - False - >>> b = a + z - >>> b.requires_grad - True - -This is especially useful when you want to freeze part of your model, or you -know in advance that you're not going to use gradients w.r.t. some parameters. -For example if you want to finetune a pretrained CNN, it's enough to switch the -:attr:`requires_grad` flags in the frozen base, and no intermediate buffers will -be saved, until the computation gets to the last layer, where the affine -transform will use weights that require gradient, and the output of the network -will also require them. - -.. code:: - - model = torchvision.models.resnet18(pretrained=True) - for param in model.parameters(): - param.requires_grad = False - # Replace the last fully-connected layer - # Parameters of newly constructed modules have requires_grad=True by default - model.fc = nn.Linear(512, 100) - - # Optimize only the classifier - optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9) - .. _how-autograd-encodes-history: How autograd encodes the history @@ -86,6 +36,157 @@ flow statements, that can change the overall shape and size of the graph at every iteration. You don't have to encode all possible paths before you launch the training - what you run is what you differentiate. +.. _locally-disable-grad-doc: + +Locally disabling gradient computation +-------------------------------------- + +There are several mechanisms available from Python to locally disable gradient +computation: + +To disable gradients across entire blocks of code, there are context managers +like no-grad mode and inference mode. +For more fine-grained exclusion of subgraphs from gradient computation, +there is setting the ``requires_grad`` field of a tensor. + +Below, in addition to discussing the mechanisms above, we also describe +evaluation mode (:meth:`nn.Module.eval()`), a method that is not actually used +to disable gradient computation but, because of its name, is often mixed up with the three. + +Setting ``requires_grad`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +:attr:`requires_grad` is a flag that allows for fine-grained exclusion of +subgraphs from gradient computation. It takes effect in both the forward +and backward passes: + +During the forward pass, an operation is only recorded in the backward graph if +at least one of its input tensors require grad. +During the backward pass (``.backward()``), only leaf tensors with +``requires_grad=True`` will have gradients accumulated into their ``.grad`` +fields. + +It is important to note that even though every tensor has this flag, +*setting* it only makes sense for leaf tensors (tensors that do not have a +``grad_fn``, e.g., a ``nn.Module``'s parameters). +Non-leaf tensors (tensors that do have ``grad_fn``) are tensors that have a +backward graph associated with them. Thus their gradients will be needed +as an intermediary result to compute the gradient for a leaf tensor that +requires grad. From this definition, it is clear that all non-leaf tensors +will automatically have ``require_grad=True``. + +Setting ``requires_grad`` should be the main way you control which parts +of the model are part of the gradient computation, for example, if you need to +freeze parts of your pretrained model during model fine-tuning. + +To freeze parts of your model, simply apply ``.requires_grad_(False)`` to +the parameters that you don't want updated. And as described above, +since computations that use these parameters as inputs would not be recorded in +the forward pass, they won't have their ``.grad`` fields updated in the backward +pass because they won't be part of the backward graph in the first place, as +desired. + +Because this is such a common pattern, ``requires_grad`` can also be set at +the module level with :meth:`nn.Module.requires_grad_()`. +When applied to a module, ``.requires_grad_()`` takes effect on all +of the module's parameters (which have ``requires_grad=True`` by default). + +Grad Modes +^^^^^^^^^^ + +Apart from setting ``requires_grad`` there are also three possible modes +enableable from Python that can affect how computations in PyTorch are +processed by autograd internally: default mode (grad mode), no-grad mode, +and inference mode, all of which can be togglable via context managers and +decorators. + +Default Mode (Grad Mode) +^^^^^^^^^^^^^^^^^^^^^^^^ + +The "default mode" is actually the mode we are implicitly in when no other modes like +no-grad and inference mode are enabled. To be contrasted with +"no-grad mode" the default mode is also sometimes called "grad mode". + +The most important thing to know about the default mode is that it is the only +mode in which ``requires_grad`` takes effect. ``requires_grad`` is always overridden +to be ``False`` in both the two other modes. + +No-grad Mode +^^^^^^^^^^^^ + +Computations in no-grad mode behave as if none of the inputs require grad. +In other words, computations in no-grad mode are never recorded in the backward graph +even if there are inputs that have ``require_grad=True``. + +Enable no-grad mode when you need to perform operations that should not be +recorded by autograd, but you’d still like to use the outputs of these +computations in grad mode later. This context manager makes it convenient to +disable gradients for a block of code or function without +having to temporarily set tensors to have ``requires_grad=False``, and then +back to ``True``. + +For example, no-grad mode might be useful when writing an optimizer: when +performing the training update you’d like to update parameters +in-place without the update being recorded by autograd. +You also intend to use the updated parameters for computations in +grad mode in the next forward pass. + +The implementations in :ref:`nn-init-doc` also +rely on no-grad mode when initializing the parameters as to avoid +autograd tracking when updating the intialized parameters in-place. + +Inference Mode +^^^^^^^^^^^^^^ + +Inference mode is the extreme version of no-grad mode. Just like in no-grad +mode, computations in inference mode are not recorded in the backward graph, but +enabling inference mode will allow PyTorch to speed up your model even more. +This better runtime comes with a drawback: tensors created in inference mode +will not be able to be used in computations to be recorded by autograd after +exiting inference mode. + +Enable inference mode when you are performing computations that don’t need +to be recorded in the backward graph, AND you don’t plan on using the tensors +created in inference mode in any computation that is to be recorded by autograd later. + +It is recommended that you try out inference mode in the parts of your code +that do not require autograd tracking (e.g., data processing and model evaluation). +If it works out of the box +for your use case it’s a free performance win. If you run into errors after +enabling inference mode, check that you are not using tensors created in +inference mode in computations that are recorded by autograd after exiting inference +mode. If you cannot avoid such use in your case, you can always switch back +to no-grad mode. + +For details on inference mode please see +`Inference Mode `_. + +For implementation details of inference mode see +`RFC-0011-InferenceMode `_. + +Evaluation Mode (``nn.Module.eval()``) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Evaluation mode is not actually a mechanism to locally disable gradient computation. +It is included here anyway because it is sometimes confused to be such a mechanism. + +Functionally, ``module.eval()`` (or equivalently ``module.train()``) are completely +orthogonal to no-grad mode and inference mode. How ``model.eval()`` affects +your model depends entirely on the specific modules used in your model and +whether they define any training-mode specific behavior. + +You are responsible for calling ``model.eval()`` and ``model.train()`` if your +model relies on modules such as :class:`torch.nn.Dropout` and +:class:`torch.nn.BatchNorm2d` that may behave +differently depending on training mode, for example, to avoid updating your +BatchNorm running statistics on validation data. + +It is recommended that you always use ``model.train()`` when +training and ``model.eval()`` when evaluating your model (validation/testing) even +if you aren’t sure your model has training-mode specific behavior, because a +module you are using might be updated to behave differently in training and +eval modes. + In-place operations with autograd --------------------------------- diff --git a/docs/source/scripts/build_activation_images.py b/docs/source/scripts/build_activation_images.py index 7274d5c06c5084..3f4032ae107712 100644 --- a/docs/source/scripts/build_activation_images.py +++ b/docs/source/scripts/build_activation_images.py @@ -37,6 +37,7 @@ 'RReLU', 'SELU', 'SiLU', + 'Mish', 'CELU', 'GELU', 'Sigmoid', diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 8e03ad397913db..9de94c5112559b 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -89,6 +89,7 @@ ("aten::_amp_update_scale", datetime.date(2021, 6, 1)), ("aten::randperm", datetime.date(9999, 1, 1)), ("aten::linalg_vector_norm", datetime.date(2021, 5, 15)), + ("aten::repeat_interleave", datetime.date(2021, 5, 26)), ] def allow_listed(schema, allow_list): diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 768f3c45fb892a..adb1d557902d5b 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -1061,6 +1061,17 @@ TEST_F(FunctionalTest, OneHot) { ASSERT_TRUE(torch::allclose(y, expected)); ASSERT_EQ(y.sizes(), std::vector({3, 2, 3})); } + + { // Test #4 + auto x = torch::arange(0, 5, torch::kLong); + auto y = F::one_hot(x % 3, 3, torch::kUInt8); + auto expected = torch::tensor( + {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kUInt8); + + ASSERT_EQ(y.ndimension(), 2); + ASSERT_TRUE(torch::allclose(y, expected)); + ASSERT_EQ(y.sizes(), std::vector({5, 3})); + } } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -1749,6 +1760,15 @@ TEST_F(FunctionalTest, Softsign) { ASSERT_TRUE(torch::allclose(y, y_exp)); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST_F(FunctionalTest, Mish) { + auto x = torch::randn(100) * 10; + auto y_exp = x * x.exp().log1p().tanh(); + auto y = F::mish(x); + + ASSERT_TRUE(torch::allclose(y, y_exp)); +} + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST_F(FunctionalTest, Tanhshrink) { auto x = torch::randn(100) * 10; diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index f2c945fa800e2f..4b22a383437628 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -2958,6 +2958,16 @@ TEST_F(ModulesTest, GELU) { ASSERT_TRUE(torch::allclose(y, y_exp)); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST_F(ModulesTest, Mish) { + Mish model; + auto x = torch::randn(100) * 10; + auto y_exp = x * x.exp().log1p().tanh(); + auto y = model(x); + + ASSERT_TRUE(torch::allclose(y, y_exp)); +} + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST_F(ModulesTest, Sigmoid) { Sigmoid model; diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 5ad08610bec6c6..3cb1cb6fe2538a 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -65,6 +65,7 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp ${JIT_TEST_ROOT}/test_subgraph_utils.cpp ${JIT_TEST_ROOT}/test_utils.cpp + ${JIT_TEST_ROOT}/test_script_profile.cpp ) if(USE_CUDA) diff --git a/test/cpp/jit/script_module_v6.ptl b/test/cpp/jit/script_module_v6.ptl new file mode 100644 index 00000000000000..a5945910fbd192 Binary files /dev/null and b/test/cpp/jit/script_module_v6.ptl differ diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index fcd626d9209083..bf4b48d3e233c4 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include // Tests go in torch::jit @@ -164,5 +165,415 @@ TEST(BackendTest, TestCompilerNotSupport) { "backend_with_compiler_demo", m, compile_spec, any_dict_ty), "The node of aten::mul is not supported in this compiler. Source code:"); } + +TEST(BackendTestDebugInfo, TestCompiler) { + Module m("m"); + m.define(R"( + def forward(self, x, h): + return x + h + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + c10::Dict compile_spec(StringType::get(), AnyType::get()); + c10::Dict fake_dict(StringType::get(), AnyType::get()); + fake_dict.insert("", ""); + compile_spec.insert("forward", fake_dict); + auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); + // lowered module + auto lm = torch::jit::detail::codegen_backend_module( + "backend_with_compiler_demo", m, compile_spec, any_dict_ty); + + std::stringstream ss; + lm._save_for_mobile(ss, ExtraFilesMap(), true); + auto mlm = _load_for_mobile(ss); + std::string error_pattern = R"( + Module hierarchy:top(backend_with_compiler_demoLoweredModule).aten::add +Traceback of TorchScript (most recent call last): + File "", line 5, in FunctionName_UNKNOWN + typed_inputs: List[Any] = [x, h, ] + if self.__backend.is_available() : + _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) + ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + assert isinstance(_0, Tensor) + return _0 + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, h): + return x + h + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); +} + +TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) { + Module a("A"); + a.define(R"( + def forward(self, x, y): + return x + y + )"); + Module b("B"); + b.define(R"( + def forward(self, x): + return x + 2 + )"); + Module c("C"); + c.register_module("A0", a); + c.register_module("B0", b); + c.define(R"( + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + c10::Dict compile_spec(StringType::get(), AnyType::get()); + c10::Dict fake_dict(StringType::get(), AnyType::get()); + fake_dict.insert("", ""); + compile_spec.insert("forward", fake_dict); + auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); + // lowered module + auto lm = torch::jit::detail::codegen_backend_module( + "backend_with_compiler_demo", c, compile_spec, any_dict_ty); + + std::stringstream ss; + lm._save_for_mobile(ss, ExtraFilesMap(), true); + auto mlm = _load_for_mobile(ss); + std::string error_pattern = R"( + Module hierarchy:top(backend_with_compiler_demoLoweredModule).A0(A).aten::add +Traceback of TorchScript (most recent call last): + File "", line 5, in FunctionName_UNKNOWN + typed_inputs: List[Any] = [x, y, ] + if self.__backend.is_available() : + _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) + ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + assert isinstance(_0, Tensor) + return _0 + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + ~~~~~~~~~~~~~~~ <--- HERE + + File "", line 3, in forward + + def forward(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); +} + +TEST( + BackendTestDebugInfo, + TestExceptionStackForCompilerWithTwoLevelModuleHierarchy) { + Module a("A"); + a.define(R"( + def forward(self, x, y): + return x + y + )"); + Module b("B"); + b.register_module("A0", a); + b.define(R"( + def forward(self, x, y): + return self.A0.forward(x, y) + 2 + )"); + Module c("C"); + c.register_module("B0", b); + c.define(R"( + def forward(self, x, y): + return self.B0.forward(x, y) + 3 + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + c10::Dict compile_spec(StringType::get(), AnyType::get()); + c10::Dict fake_dict(StringType::get(), AnyType::get()); + fake_dict.insert("", ""); + compile_spec.insert("forward", fake_dict); + auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); + // lowered module + auto lm = torch::jit::detail::codegen_backend_module( + "backend_with_compiler_demo", c, compile_spec, any_dict_ty); + + std::stringstream ss; + lm._save_for_mobile(ss, ExtraFilesMap(), true); + auto mlm = _load_for_mobile(ss); + /* + * Error stack throw will look like this: + * Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A) + * Traceback of TorchScript (most recent call last): + * File "", line 5, in FunctionName_UNKNOWN + * typed_inputs: List[Any] = [x, y, ] + * if self.__backend.is_available() : + * _0, = self.__backend.execute(self.__handles["forward"], + * typed_inputs) + * ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + * assert isinstance(_0, Tensor) + * return _0 + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return self.B0.forward(x, y) + 3 + * ~~~~~~~~~~~~~~~ <--- HERE + * + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return self.A0.forward(x, y) + 2 + * ~~~~~~~~~~~~~~~ <--- HERE + * + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return x + y + * ~~~~~ <--- HERE + * + */ + std::string error_pattern = R"( + Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A).aten::add +Traceback of TorchScript (most recent call last): + File "", line 5, in FunctionName_UNKNOWN + typed_inputs: List[Any] = [x, y, ] + if self.__backend.is_available() : + _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) + ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + assert isinstance(_0, Tensor) + return _0 + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.B0.forward(x, y) + 3 + ~~~~~~~~~~~~~~~ <--- HERE + + File "", line 3, in forward + + def forward(self, x, y): + return self.A0.forward(x, y) + 2 + ~~~~~~~~~~~~~~~ <--- HERE + + File "", line 3, in forward + + def forward(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); +} + +TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) { + std::shared_ptr cu = std::make_shared(); + Module a("A"); + a.define(R"( + def forward(self, x, y): + return x + y + )"); + Module b("B"); + b.define(R"( + def forward(self, x): + return x + 2 + )"); + Module c("C"); + c.register_module("A0", a); + c.register_module("B0", b); + c.define(R"( + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + c10::Dict compile_spec(StringType::get(), AnyType::get()); + c10::Dict fake_dict(StringType::get(), AnyType::get()); + fake_dict.insert("", ""); + compile_spec.insert("forward", fake_dict); + IValue submodule = c.attr("A0"); + Module current_sm = submodule.toModule(); + auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); + // lowered module + auto lowered_submodule = torch::jit::detail::codegen_backend_module( + "backend_with_compiler_demo", current_sm, compile_spec, any_dict_ty); + + c.type()->unsafeChangeAttributeType("A0", lowered_submodule.type()); + c.setattr("A0", lowered_submodule._ivalue()); + std::unordered_map type_remap; + type_remap[a.type()] = lowered_submodule.type(); + auto type_remap_fn = [&type_remap](TypePtr in) { + auto it = type_remap.find(in); + if (it == type_remap.end()) + return in; + return it->second; + }; + for (auto& fn : c.type()->methods()) { + auto method = c.get_method(fn->name()); + auto graph = method.graph(); + graph->remapTypes(type_remap_fn); + auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); + fn->setSchema(new_schema); + } + + std::stringstream ss; + c._save_for_mobile(ss, ExtraFilesMap(), true); + auto c_loaded = _load_for_mobile(ss); + std::string error_pattern = R"( + Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).aten::add +Traceback of TorchScript (most recent call last): + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + ~~~~~~~~~~~~~~~ <--- HERE + + File "", line 5, in forward + typed_inputs: List[Any] = [x, y, ] + if self.__backend.is_available() : + _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) + ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + assert isinstance(_0, Tensor) + return _0 + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern); +} + +TEST( + BackendTestDebugInfo, + TestExceptionStackForCompilerWithSelectiveLoweredSubModule) { + std::shared_ptr cu = std::make_shared(); + Module aa("AA"); + aa.define(R"( + def forward(self, x, y): + return x + y + )"); + Module a("A"); + a.register_module("AA0", aa); + a.define(R"( + def forward(self, x, y): + return self.AA0.forward(x, y) + 3 + )"); + Module b("B"); + b.define(R"( + def forward(self, x): + return x + 2 + )"); + Module c("C"); + c.register_module("A0", a); + c.register_module("B0", b); + c.define(R"( + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + c10::Dict compile_spec(StringType::get(), AnyType::get()); + c10::Dict fake_dict(StringType::get(), AnyType::get()); + fake_dict.insert("", ""); + compile_spec.insert("forward", fake_dict); + IValue submodule = c.attr("A0"); + Module current_sm = submodule.toModule(); + auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); + // lowered module + auto lowered_submodule = torch::jit::detail::codegen_backend_module( + "backend_with_compiler_demo", current_sm, compile_spec, any_dict_ty); + + c.type()->unsafeChangeAttributeType("A0", lowered_submodule.type()); + c.setattr("A0", lowered_submodule._ivalue()); + std::unordered_map type_remap; + type_remap[a.type()] = lowered_submodule.type(); + auto type_remap_fn = [&type_remap](TypePtr in) { + auto it = type_remap.find(in); + if (it == type_remap.end()) + return in; + return it->second; + }; + for (auto& fn : c.type()->methods()) { + auto method = c.get_method(fn->name()); + auto graph = method.graph(); + graph->remapTypes(type_remap_fn); + auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); + fn->setSchema(new_schema); + } + + std::stringstream ss; + c._save_for_mobile(ss, ExtraFilesMap(), true); + auto c_loaded = _load_for_mobile(ss); + /* + * Erro stack trace will look like this: + * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) + * Traceback of TorchScript (most recent call last): + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return self.A0.forward(x, y) + self.B0.forward(x) + * ~~~~~~~~~~~~~~~ <--- HERE + * + * File "", line 5, in FunctionName_UNKNOWN + * typed_inputs: List[Any] = [x, y, ] + * if self.__backend.is_available() : + * _0, = self.__backend.execute(self.__handles["forward"], + * typed_inputs) + * ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + * assert isinstance(_0, Tensor) + * return _0 + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return self.AA0.forward(x, y) + 3 + * ~~~~~~~~~~~~~~~~ <--- HERE + * + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return x + y + * ~~~~~ <--- HERE + * + * + * */ + std::string error_pattern = R"( + Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA).aten::add +Traceback of TorchScript (most recent call last): + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + ~~~~~~~~~~~~~~~ <--- HERE + + File "", line 5, in forward + typed_inputs: List[Any] = [x, y, ] + if self.__backend.is_available() : + _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) + ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + assert isinstance(_0, Tensor) + return _0 + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.AA0.forward(x, y) + 3 + ~~~~~~~~~~~~~~~~ <--- HERE + + File "", line 3, in forward + + def forward(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 3a4543ddd37225..def0e5dbda85ba 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -1,4 +1,5 @@ #include +#include namespace torch { namespace jit { @@ -31,13 +32,22 @@ namespace jit { // (handle). namespace { -std::vector parseMethodHandle(const std::string& blob) { - std::vector result; +std::vector> parseMethodHandle( + const std::string& blob) { + std::vector> result; std::stringstream s_stream(blob); + constexpr char debug_handle_token[] = ""; while (s_stream.good()) { std::string substr; getline(s_stream, substr, ','); - result.push_back(substr); + auto debug_handle_pos = substr.find(debug_handle_token); + int64_t debug_handle{-1}; + auto instruction = substr.substr(0); + if (debug_handle_pos != std::string::npos) { + instruction = substr.substr(0, debug_handle_pos); + debug_handle = stoi(substr.substr(debug_handle_pos + 14)); + } + result.push_back(std::make_tuple(instruction, debug_handle)); } return result; } @@ -60,7 +70,8 @@ class BackendWithCompiler : public PyTorchBackendInterface { c10::IValue processed, c10::impl::GenericDict method_compile_spec) override { auto dict = processed.toGenericDict(); - auto handles = c10::Dict>(); + auto handles = + c10::Dict>>(); for (const auto& kv : dict) { auto tokens = parseMethodHandle(kv.value().toStringRef()); handles.insert(kv.key().toStringRef(), tokens); @@ -81,27 +92,33 @@ class BackendWithCompiler : public PyTorchBackendInterface { double scalar_val = 1.0; for (const auto& token : handle.toList()) { IValue val = token; - auto instruction = std::string(IValue(token).toStringRef()); + auto instruction = val.toTuple()->elements()[0].toStringRef(); + auto debug_handle = val.toTuple()->elements()[1].toInt(); double const_val = 1.0; - if (instruction.rfind("prim::Constant", 0) == 0) { - TORCH_CHECK( - instruction.size() > 15, - "Constant value is expected in ", - instruction); - auto sub = instruction.substr(15); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - const_val = stod(sub); - } else if (token == "aten::add") { - output_list.emplace_back(x.add(h, const_val)); - } else if (token == "aten::sub") { - output_list.emplace_back(x.sub(h, const_val)); - } else { - TORCH_CHECK( - false, - "Instruction, ", - instruction, - " is not supported. ", - "Contact the backend POC for details. "); + try { + if (instruction.rfind("prim::Constant", 0) == 0) { + TORCH_CHECK( + instruction.size() > 15, + "Constant value is expected in ", + instruction); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto sub = instruction.substr(15); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + const_val = stod(sub); + } else if (instruction == "aten::add") { + output_list.emplace_back(x.add(h, const_val)); + } else if (instruction == "aten::sub") { + output_list.emplace_back(x.sub(h, const_val)); + } else { + TORCH_CHECK( + false, + "Instruction, ", + instruction, + " is not supported. ", + "Contact the backend POC for details. "); + } + } catch (c10::Error& e) { + TORCH_DELEGATED_BACKEND_THROW(false, e.what(), debug_handle); } } return c10::impl::toList(output_list); diff --git a/test/cpp/jit/test_backend_compiler_preprocess.cpp b/test/cpp/jit/test_backend_compiler_preprocess.cpp index 05d0f8fab0e41b..3de692bd485617 100644 --- a/test/cpp/jit/test_backend_compiler_preprocess.cpp +++ b/test/cpp/jit/test_backend_compiler_preprocess.cpp @@ -1,5 +1,8 @@ #include #include +#include +#include +#include namespace torch { namespace jit { @@ -15,23 +18,35 @@ c10::IValue preprocess( // Key: method name. // Val: compiled blob (represented by a string). c10::Dict compiled(StringType::get(), StringType::get()); + for (const auto& method : mod.get_methods()) { - const auto graph = method.function().graph()->copy(); + auto graph = method.function().graph()->copy(); + // Must inline the graph for debug info map. + Inline(*graph); + // This is here because to test module hierarchy we will have + // getattr nodes which after inlining dont serve any purpose. + // Without removing them we will run into compilation errors. + // So eliminate deadcode just remove those getattr nodes. + EliminateDeadCode(graph); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) auto key = method.name(); + auto node_debug_handles = generate_debug_handles(graph); std::stringstream ss; for (const auto& node : graph->nodes()) { switch (node->kind()) { case prim::Constant: ss << node->kind().toDisplayString() << "#" << toIValue(node->output()).value(); + ss << "" << node_debug_handles[node]; break; // NOLINTNEXTLINE(bugprone-branch-clone) case aten::add: ss << node->kind().toQualString(); + ss << "" << node_debug_handles[node]; break; case aten::sub: ss << node->kind().toQualString(); + ss << "" << node_debug_handles[node]; break; default: TORCH_CHECK( diff --git a/test/cpp/jit/test_cs_debug_info_serialization.cpp b/test/cpp/jit/test_cs_debug_info_serialization.cpp index a9d9ade034fed7..c34f0da1b636fe 100644 --- a/test/cpp/jit/test_cs_debug_info_serialization.cpp +++ b/test/cpp/jit/test_cs_debug_info_serialization.cpp @@ -25,38 +25,57 @@ namespace jit { namespace { bool validate_debug_info( - const DebugInfoPair& pre_serialize, - const DebugInfoPair& post_serialize) { - auto sr1 = pre_serialize.first; - auto sr2 = post_serialize.first; + const DebugInfoTuple& pre_serialize, + const DebugInfoTuple& post_serialize) { + auto sr1 = std::get(pre_serialize); + auto sr2 = std::get(post_serialize); if (sr1 != sr2) { return false; } - if (!pre_serialize.second.defined()) { - return !post_serialize.second.defined(); + auto csptr1 = std::get(pre_serialize); + auto csptr2 = std::get(post_serialize); + if (!csptr1.defined()) { + return !csptr2.defined(); } - if (!post_serialize.second.defined()) { + if (!csptr2.defined()) { return false; } - auto vec1 = pre_serialize.second->vec(); - auto vec2 = post_serialize.second->vec(); + auto vec1 = csptr1->vec(); + auto vec2 = csptr2->vec(); if (vec1.size() != vec2.size()) { return false; } - for (size_t i = 0; i < vec1.size(); i++) { - auto rhs_sr = std::get<1>(vec1[i]); - auto lhs_sr = std::get<1>(vec2[i]); - auto rhs_module = std::get<2>(vec1[i]); - auto lhs_module = std::get<2>(vec2[i]); + while (csptr1) { + auto rhs_sr = csptr1->source_range(); + auto lhs_sr = csptr2->source_range(); + auto rhs_module = csptr1->module_instance(); + auto lhs_module = csptr2->module_instance(); + std::string rhs_fn_name, lhs_fn_name; + if (csptr1->function()) { + rhs_fn_name = csptr1->function()->name(); + } else { + rhs_fn_name = csptr1->function_name(); + } + if (csptr2->function()) { + lhs_fn_name = csptr2->function()->name(); + } else { + lhs_fn_name = csptr2->function_name(); + } if (!((rhs_module.has_value() == lhs_module.has_value()) && (rhs_module.has_value() && (rhs_module.value().class_type()->name().value() == lhs_module.value().class_type()->name().value()) && (rhs_module.value().instance_name() == lhs_module.value().instance_name())) && - (rhs_sr == lhs_sr))) { + (rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) { return false; } + if (csptr1->callee()) { + csptr1 = csptr1->callee().value(); + csptr2 = csptr2->callee().value(); + } else { + csptr1 = c10::intrusive_ptr(); + } } return true; } @@ -81,7 +100,7 @@ TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) { return self.A0.forward(x) + self.B0.forward(x) )JIT"); - BackendDebugHandleManager debug_handle_manager; + BackendDebugInfoRecorder debug_info_recorder; auto graph = c.get_method("forward").graph(); Inline(*graph); std::stack blocks_to_visit; @@ -100,7 +119,7 @@ TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) { source_range_tags[n->sourceRange()] = source_range_tag; source_range_map[source_range_tag] = n->sourceRange(); source_range_tag++; - debug_handle_manager.getNextDebugHandleForInlinedCallStackPtr(n); + debug_info_recorder.getNextDebugHandle(n); if (n->callstack().has_value()) { for (const auto& e : n->callstack().value()->vec()) { auto sr = std::get<1>(e); @@ -111,7 +130,7 @@ TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) { } } } - auto debug_handle_cs_ptr_map = debug_handle_manager.getCallStackPtrMap(); + auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording(); CallStackDebugInfoPickler cs_debug_info_pickler; auto cs_data = cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags); diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index ece646f6ede8a3..fe019a67512dfd 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -496,8 +496,7 @@ TEST(LiteInterpreterTest, ModuleInfoBasic) { } } - std::unordered_set expected_result({"top(M)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(M).aten::mul")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -559,8 +558,9 @@ TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) { } } - std::set expected_result({"top(B)", "top(B).A0(A)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(B).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::mul")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -594,7 +594,6 @@ TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { std::string module_info = bc.get_forward_method_debug_info(pc); if (!module_info.empty() && (module_info.find("debug_handle") == std::string::npos)) { - std::cout << "Module info:" << module_info << std::endl; module_debug_info_set.insert(module_info); } ++pc; @@ -603,9 +602,9 @@ TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { } } - std::set expected_result( - {"top(C)", "top(C).A0(A)", "top(C).B0(B)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(C).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add")); } TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) { @@ -625,6 +624,34 @@ TEST(LiteInterpreterTest, GetByteCodeVersion) { } namespace { + +void compareModelOutput( + const std::vector& actual_result_list, + const std::vector& expect_result_list) { + AT_ASSERT(actual_result_list.size() == expect_result_list.size()); + AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0])); + AT_ASSERT( + actual_result_list[1].toTensor().dim() == expect_result_list[1].dim()); + AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2])); +} + +void runAndCheckTorchScriptModel( + std::stringstream& input_model_stream, + const std::vector& input_data, + const std::vector& expect_result_list, + const int64_t expect_version) { + auto actual_version = _get_model_bytecode_version(input_model_stream); + AT_ASSERT(actual_version == expect_version); + + // Load and run the backport model, then compare the result with expect + // result + Module m_mobile = load(input_model_stream); + + auto actual_result = m_mobile.forward(input_data); + std::vector actual_result_list = actual_result.toTuple()->elements(); + compareModelOutput(actual_result_list, expect_result_list); +} + void runAndCheckBytecodeModel( std::stringstream& input_model_stream, const std::vector& input_data, @@ -635,16 +662,12 @@ void runAndCheckBytecodeModel( // Load and run the backport model, then compare the result with expect // result - mobile::Module m_mobile = _load_for_mobile(input_model_stream); + Module m_mobile = load(input_model_stream); auto actual_result = m_mobile.forward(input_data); std::vector actual_result_list = actual_result.toTuple()->elements(); - AT_ASSERT(actual_result_list.size() == expect_result_list.size()); - AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0])); - AT_ASSERT( - actual_result_list[1].toTensor().dim() == expect_result_list[1].dim()); - AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2])); + compareModelOutput(actual_result_list, expect_result_list); } void backportAllVersionCheck( @@ -659,29 +682,33 @@ void backportAllVersionCheck( constexpr int64_t minimum_to_version = 4; int64_t current_to_version = from_version - 1; - std::ostringstream oss; // Verify all candidate to_version work as expected. All backport to version // larger than minimum_to_version should success. while (current_to_version >= minimum_to_version) { - oss.clear(); + // Do not declare std::stringstream oss outside of the while loop as + // oss.clear() doesn't reset the stream content, only clears out error state + // flag in stringstream causing a problematic stream. Instead, it's cleaner + // and safer to just declare a new std::stringstream one and swap them. + std::stringstream oss; bool backPortSuccess = _backport_for_mobile(test_model_file_stream, oss, current_to_version); AT_ASSERT(backPortSuccess); // Check backport model version - std::stringstream iss(oss.str()); - auto backport_version = _get_model_bytecode_version(iss); + auto backport_version = _get_model_bytecode_version(oss); AT_ASSERT(backport_version == current_to_version); // Load and run the backport model, then compare the result with expect // result runAndCheckBytecodeModel( - iss, input_data, expect_result_list, current_to_version); + oss, input_data, expect_result_list, current_to_version); + runAndCheckTorchScriptModel( + oss, input_data, expect_result_list, current_to_version); current_to_version--; } // backport to minimum version - 1 should fail - oss.clear(); + std::stringstream oss; bool backPortSuccess = _backport_for_mobile(test_model_file_stream, oss, minimum_to_version - 1); AT_ASSERT(!backPortSuccess); @@ -790,9 +817,9 @@ TEST(LiteInterpreterTest, SequentialModuleInfo) { // def forward(self, x): // return self.A0.forward(self.B0.forward(x)) - std::set expected_result( - {"top(C)", "top(C).A0(A)", "top(C).B0(B)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(C).prim::Return")); + AT_ASSERT(module_debug_info_set.count("top(C).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -838,9 +865,9 @@ TEST(LiteInterpreterTest, HierarchyModuleInfo) { // "top(C).forward": for the add operator in top. // "top(C).B0(B).forward": for the add operator in B0. // "top(C).B0(B).forward.A0(A).forward": for the add operator in A0. - std::set expected_result( - {"top(C)", "top(C).B0(B)", "top(C).B0(B).A0(A)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(C).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(C).B0(B).A0(A).aten::add")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -898,9 +925,9 @@ TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) { // "top(B).A0(A).forward": for the add operator in A0. // "top(B).A1(A).forward": for the add operator in A1. - std::set expected_result( - {"top(B)", "top(B).A0(A)", "top(B).A1(A)"}); - AT_ASSERT(module_debug_info_set == expected_result); + AT_ASSERT(module_debug_info_set.count("top(B).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::add")); + AT_ASSERT(module_debug_info_set.count("top(B).A1(A).aten::add")); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -1286,6 +1313,57 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) { testLiteModuleCompareResultTensors(m, inputs); } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) { + Module a("A"); + a.define(R"( + def bar(self, x, y): + return x + y + )"); + Module b("B"); + b.register_module("A0", a); + b.define(R"( + def foo(self, x, y): + return self.A0.bar(x, y) + 2 + )"); + Module c("C"); + c.register_module("B0", b); + c.define(R"( + def forward(self, x, y): + return self.B0.foo(x, y) + 3 + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + std::stringstream ss; + c._save_for_mobile(ss, ExtraFilesMap(), true); + auto lite_m = _load_for_mobile(ss); + std::string error_pattern = R"( + Module hierarchy:top(C).B0(B).A0(A).aten::add +Traceback of TorchScript (most recent call last): + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.B0.foo(x, y) + 3 + ~~~~~~~~~~~ <--- HERE + + File "", line 3, in foo + + def foo(self, x, y): + return self.A0.bar(x, y) + 2 + ~~~~~~~~~~~ <--- HERE + + File "", line 3, in bar + + def bar(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern); +} + namespace { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static auto reg = diff --git a/test/cpp/jit/test_script_profile.cpp b/test/cpp/jit/test_script_profile.cpp new file mode 100644 index 00000000000000..242a15eed9bfc4 --- /dev/null +++ b/test/cpp/jit/test_script_profile.cpp @@ -0,0 +1,62 @@ +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +TEST(ScriptProfileTest, Basic) { + const std::string source_string = R"V0G0N( + def foo(a, b): + return a + b # + )V0G0N"; + auto begin = source_string.find("return"); + auto end = source_string.find(" #"); + + Graph g; + const auto graph_string = R"IR( + graph(%a : Tensor, + %b : Tensor): + %2 : int = prim::Constant[value=1]() + %3 : Tensor = aten::add(%a, %b, %2) + return (%3))IR"; + + torch::jit::parseIR(graph_string, &g); + auto source = std::make_shared(source_string, "", 0); + auto node = *g.nodes().begin(); + node->setSourceRange(SourceRange{source, begin, end}); + + ScriptProfile p; + p.enable(); + { + profiling::InstructionSpan g0(*node); + profiling::InstructionSpan g1(*node); + profiling::InstructionSpan g2(*node); + } + p.disable(); + + auto stats = p.dumpStats(); + EXPECT_EQ(stats.size(), 1); + auto it = stats.find(*source.get()); + EXPECT_NE(it, stats.end()); + auto& lines = it->second; + EXPECT_EQ(lines.size(), 1); + const auto& stat = lines.at(source->lineno_for_offset(begin)); + EXPECT_EQ(stat.count, 3); +} + +TEST(ScriptProfileTest, CallingOrder) { + ScriptProfile p; + p.enable(); + EXPECT_THROW(p.dumpStats(), c10::Error); + p.disable(); + auto dp = std::make_shared(SourceRange{}); + EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/test_utils.h b/test/cpp/jit/test_utils.h index 7edf9d84dbac51..676759dca480fe 100644 --- a/test/cpp/jit/test_utils.h +++ b/test/cpp/jit/test_utils.h @@ -5,12 +5,49 @@ #include #include -#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \ - try { \ - (void)statement; \ - FAIL(); \ - } catch (const std::exception& e) { \ - ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ +namespace { +static inline void trim(std::string& s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); + s.erase( + std::find_if( + s.rbegin(), + s.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + s.end()); + for (int64_t i = 0; i < s.size(); ++i) { + if (s[i] == '\n') { + s.erase(i, 1); + i--; + } + } + for (int64_t i = 0; i < s.size(); ++i) { + if (s[i] == ' ') { + for (int64_t j = i + 1; j < s.size(); j++) { + if (s[j] == ' ') { + s.erase(j, 1); + j--; + } else { + break; + } + } + } + } +} +} // namespace + +#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \ + try { \ + (void)statement; \ + FAIL(); \ + } catch (const std::exception& e) { \ + std::string substring_s(substring); \ + trim(substring_s); \ + auto exception_string = std::string(e.what()); \ + trim(exception_string); \ + ASSERT_NE(exception_string.find(substring_s), std::string::npos); \ } namespace torch { diff --git a/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl b/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl new file mode 100644 index 00000000000000..901724d82225b8 Binary files /dev/null and b/test/cpp/lite_interpreter_runtime/delegated_submodule_with_debug_info.ptl differ diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index ae58b99676a6f1..2ccf6ee18d3a81 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -71,6 +72,105 @@ TEST(RunTimeTest, Delegate) { auto mres = mlm.forward(inputs); AT_ASSERT(mres.toTensor().equal(3 * at::ones({}))); } + +TEST(RunTimeTest, DelegateException) { + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + /* + * Model: delegated_submodule_with_debug_info.ptl + * Model structure: + * def AA(..): + * def forward(self, x, y): + * return x + y + * + * def A(..): + * def __init__(..): + * self.AA0 = AA() + * def forward(self, x, y): + * return self.AA0.forward(x, y) + 3 + * + * def B(..): + * def forward(self, x): + * return x + 2 + * + * def C(..): + * def __init__(..): + * self.A0 = A() + * self.B0 = B() + * def forward(self, x, y): + * return self.A0.forward(x, y) + self.B0.forward(x) + * + * std::vector inputs; + * inputs.emplace_back(torch::rand({2, 4})); + * inputs.emplace_back(torch::rand({13, 9})); + * Run with inputs and expect exception + * Erro stack trace will look like this: + * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) + * Traceback of TorchScript (most recent call last): + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return self.A0.forward(x, y) + self.B0.forward(x) + * ~~~~~~~~~~~~~~~ <--- HERE + * + * File "", line 5, in FunctionName_UNKNOWN + * typed_inputs: List[Any] = [x, y, ] + * if self.__backend.is_available() : + * _0, = self.__backend.execute(self.__handles["forward"], + * typed_inputs) + * ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + * assert isinstance(_0, Tensor) + * return _0 + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return self.AA0.forward(x, y) + 3 + * ~~~~~~~~~~~~~~~~ <--- HERE + * + * File "", line 3, in FunctionName_UNKNOWN + * + * def forward(self, x, y): + * return x + y + * ~~~~~ <--- HERE + * + * + */ + testModelFile.append("delegated_submodule_with_debug_info.ptl"); + auto mlm = _load_for_mobile(testModelFile); + std::vector inputs; + inputs.emplace_back(torch::rand({2, 4})); + inputs.emplace_back(torch::rand({13, 9})); + + std::string error_pattern = R"( + Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA).aten::add +Traceback of TorchScript (most recent call last): + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.A0.forward(x, y) + self.B0.forward(x) + ~~~~~~~~~~~~~~~ <--- HERE + + File "", line 5, in forward + typed_inputs: List[Any] = [x, y, ] + if self.__backend.is_available() : + _0, = self.__backend.execute(self.__handles["forward"], typed_inputs) + ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + assert isinstance(_0, Tensor) + return _0 + File "", line 3, in FunctionName_UNKNOWN + + def forward(self, x, y): + return self.AA0.forward(x, y) + 3 + ~~~~~~~~~~~~~~~~ <--- HERE + + File "", line 3, in forward + + def forward(self, x, y): + return x + y + ~~~~~ <--- HERE + )"; + ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern); +} } // namespace mobile } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp index 5a56771990fc42..6bd31e2ef047e2 100644 --- a/test/cpp/tensorexpr/test_approx.cpp +++ b/test/cpp/tensorexpr/test_approx.cpp @@ -13,8 +13,8 @@ namespace te = torch::jit::tensorexpr; static void vectorize(te::LoopNest* ln, te::Tensor* target, int width) { auto loops = ln->getLoopStmtsFor(target); - te::For *outer, *inner, *tail; - ln->splitWithTail(loops[0], width, &outer, &inner, &tail); + te::For *inner, *tail; + ln->splitWithTail(loops[0], width, &inner, &tail); ln->vectorize(inner); } diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp index b3bc26b51da4db..87fb244e0cb9a8 100644 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -217,14 +217,13 @@ TEST(BoundsInference, _5) { Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* tail; std::vector loops = l.getLoopStmtsFor(b); - l.splitWithTail(loops[0], 16, &outer, &inner, &tail); + l.splitWithTail(loops[0], 16, &inner, &tail); + For* outer = loops[0]; { // Verify inferred bounds for the outer loop @@ -729,11 +728,13 @@ TEST(BoundsInference, GetPotentialHazardsLoopSplit) { LoopNest l({A}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *outer, *inner, *tail; + For *inner, *tail; // Splitting with tail by something offset creates a tail which also writes to // A. - l.splitWithTail(l.getLoopStmtsFor(A)[0], 5, &outer, &inner, &tail); + For* outer = l.getLoopStmtsFor(A)[0]; + // `outer` loop get transformed to the outer loop after splitting. + l.splitWithTail(outer, 5, &inner, &tail); using namespace analysis; diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 71f6967da94421..a12592939a0f6a 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -6,14 +6,14 @@ #include -#include "test/cpp/tensorexpr/test_base.h" +#include +#include +#include +#include +#include +#include #include -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" -#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" -#include "torch/csrc/jit/tensorexpr/loopnest.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" #include @@ -172,11 +172,10 @@ static void testCudaTestVectorAdd02_impl(int N, int block_size) { }, [&](const VarHandle& n) { return a_buf.load(n) + b_buf.load(n); }); LoopNest l({c}); - For* n_outer; For* n_inner; std::vector loops = l.getLoopStmtsFor(c); - l.splitWithMask(loops[0], block_size, &n_outer, &n_inner); - l.setGPUBlockIndex(n_outer, 0); + l.splitWithMask(loops[0], block_size, &n_inner); + l.setGPUBlockIndex(loops[0], 0); l.setGPUThreadIndex(n_inner, 0); l.prepareForCodegen(); Stmt* stmt = l.root_stmt(); @@ -391,11 +390,10 @@ TEST(Cuda, DynamicShapeSplit_CUDA) { Tensor* b = Compute( "b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); LoopNest l({b}); - For* outer; For* inner; std::vector loops = l.getLoopStmtsFor(b); - l.splitWithMask(loops[0], 1024, &outer, &inner); - l.setGPUBlockIndex(outer, 0); + l.splitWithMask(loops[0], 1024, &inner); + l.setGPUBlockIndex(loops[0], 0); l.setGPUThreadIndex(inner, 0); Stmt* s = l.root_stmt(); CudaCodeGen cg(s, {a, b, n}); diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 6c22b0310efa81..06113640714a24 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -1721,16 +1721,12 @@ TEST(LLVM, VectorizedGEMM) { { auto const& loops = loop.getLoopStmtsFor(CT); For* m = loops[0]; - For* mo; - For* mi; - loop.splitWithMask(m, 16, &mo, &mi); + loop.splitWithMask(m, 16); } { auto const& loops = loop.getLoopStmtsFor(CT); For* n = loops[2]; - For* no; - For* ni; - loop.splitWithMask(n, 16, &no, &ni); + loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 6522ed7d703335..c0860bc0d47877 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -36,16 +36,10 @@ TEST(LoopNest, ExprSimple01) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_outer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); - l.splitWithTail(loops[0], 2, &x_outer, &x_inner, &x_tail); - l.splitWithTail(x_outer, 2); + l.splitWithTail(loops[0], 2); + l.splitWithTail(loops[0], 2); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -395,8 +389,6 @@ TEST(LoopNest, ExprSplitAndSlice) { Tensor* tensor = Compute("f", {{100, "x"}}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -405,9 +397,9 @@ TEST(LoopNest, ExprSplitAndSlice) { // outer: [0, 4) // inner: [0, 21) // tail: [84, 100) - l.splitWithTail(loops[0], 21, &outer, &inner, &tail); + l.splitWithTail(loops[0], 21, &inner, &tail); l.sliceTail(inner, 2); - l.sliceHead(outer, 2); + l.sliceHead(loops[0], 2); // for (int x_outer = 0; x_outer < 2; x_outer++) { // for (int x_inner = 0; x_inner < 19; x_inner++) { @@ -522,15 +514,11 @@ TEST(LoopNest, ExprSplitWithTail) { }; Tensor* tensor = Compute("f", {{199, "x"}}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_outer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); - l.splitWithTail(loops[0], 17, &x_outer, &x_inner, &x_tail); - l.splitWithTail(x_outer, 7); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + l.splitWithTail(loops[0], 17); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + l.splitWithTail(loops[0], 7); Stmt* stmt = l.root_stmt(); Stmt* simplified = IRSimplifier::simplify(stmt); @@ -557,14 +545,8 @@ TEST(LoopNest, ExprSplitWithTailNone) { }; Tensor* tensor = Compute("f", {{24, "x"}, {5, "y"}}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_outer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); - l.splitWithTail(loops[0], 4, &x_outer, &x_inner, &x_tail); + l.splitWithTail(loops[0], 4); Stmt* stmt = l.root_stmt(); std::ostringstream oss; @@ -663,10 +645,8 @@ TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { LoopNest l({tensor}); std::vector loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *outer, *mid, *inner; - l.splitWithMask(loops[0], 4, &outer, &inner); - l.splitWithMask(outer, 4); + l.splitWithMask(loops[0], 4); + l.splitWithMask(loops[0], 4); Stmt* stmt1 = IRSimplifier::simplify(l.root_stmt()); @@ -691,16 +671,16 @@ TEST(LoopNest, SplitWithTailWithLoopOptions) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *outer, *inner, *tail; + For *inner, *tail; LoopNest l({tensor}); auto loops = NodeFinder::find(l.root_stmt()); ASSERT_GT(loops.size(), 0); l.setGPUBlockIndex(loops[0], LoopOptions::IDX_Y); - l.splitWithTail(loops[0], 4, &outer, &inner, &tail); - ASSERT_NE(outer, nullptr); + l.splitWithTail(loops[0], 4, &inner, &tail); ASSERT_NE(inner, nullptr); ASSERT_NE(tail, nullptr); + For* outer = loops[0]; // Outer loop carries loop axis bindings. ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); @@ -723,12 +703,13 @@ TEST(LoopNest, SplitWithMaskWithLoopOptions) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *outer, *inner; + For* inner; LoopNest l({tensor}); auto loops = NodeFinder::find(l.root_stmt()); l.setGPUBlockIndex(loops[0], LoopOptions::IDX_Y); - l.splitWithMask(loops[0], 4, &outer, &inner); + l.splitWithMask(loops[0], 4, &inner); + For* outer = loops[0]; // Outer loop carries loop axis bindings. ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); @@ -1305,13 +1286,11 @@ TEST(LoopNest, ScheduleSplitTwiceThenInline) { return a->load(j + ExprHandle(8)); }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* i_outer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* i_inner; LoopNest l({b}, {a, b}); std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); - l.splitWithMask(loops[0], 4, &i_outer, &i_inner); + l.splitWithMask(loops[0], 4, &i_inner); l.splitWithMask(i_inner, 2); ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); } @@ -3165,15 +3144,13 @@ TEST(LoopNest, NormalizeAndSplitWithTail) { LoopNest::normalize(for_stmt); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* x_inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* x_tail; - l.splitWithTail(for_stmt, 10, &x_outer, &x_inner, &x_tail); + l.splitWithTail(for_stmt, 10, &x_inner, &x_tail); - auto x_outer_result = IRSimplifier::simplify(x_outer); + auto x_outer_result = IRSimplifier::simplify(for_stmt); std::ostringstream oss_outer; oss_outer << *x_outer_result; const std::string& expected_outer_ir = diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index 296212ac2f8581..93177795051650 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -2995,20 +2995,12 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { { auto const& loops = loop.getLoopStmtsFor(CT); For* m = loops[0]; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* mo; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* mi; - loop.splitWithMask(m, 4, &mo, &mi); + loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); For* n = loops[2]; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* no; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* ni; - loop.splitWithMask(n, 16, &no, &ni); + loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 5d2c0f2a8a0abf..de28871bd0af03 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -624,22 +624,9 @@ TEST(Reductions, SplitNonReduceAxis) { std::vector out(16, -1.f); Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_outer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_tail; std::vector loops = l.getLoopStmtsFor(tensor); - l.splitWithTail(loops[0], 2, &x_outer, &x_inner, &x_tail); - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_2; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* x_tail_2; - l.splitWithTail(x_outer, 2, &x_2, &x_1, &x_tail_2); + l.splitWithTail(loops[0], 2); + l.splitWithTail(loops[0], 2); l.prepareForCodegen(); @@ -1133,8 +1120,8 @@ TEST(Reductions, ReduceOverSplitRfactor) { LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *o, *i, *t; - loop.splitWithTail(loops[1], SPLIT_FACTOR, &o, &i, &t); + For *i, *t; + loop.splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); loop.reorderAxis(loops[0], i); auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf()); @@ -1525,16 +1512,14 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { LoopNest l({e}, {c, d, e}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; // Split outer reduction axis. - l.splitWithMask(l.getLoopStmtsFor(d)[0], 4, &outer, &inner); + l.splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); // Split reduction consumer. - l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); l.cacheAccesses(d->buf(), "sum_local", inner); l.prepareForCodegen(); @@ -1576,8 +1561,6 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { LoopNest l({e}, {c, d, e}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; @@ -1586,7 +1569,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { l.reorderAxis(loops[0], loops[1]); // Split reduction consumer. - l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); l.cacheAccesses(d->buf(), "sum_local", inner); l.prepareForCodegen(); diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp index a9d7b9a4f37bc6..dcd9358b3f3909 100644 --- a/test/cpp/tensorexpr/tutorial.cpp +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -313,8 +313,6 @@ int main(int argc, char* argv[]) { // instance. std::vector loops = loopnest.getLoopStmtsFor(Y); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* j_outer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* j_inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* j_tail; @@ -322,9 +320,9 @@ int main(int argc, char* argv[]) { loopnest.splitWithTail( loops[1], // loops[0] is the outer loop, loops[1] is inner split_factor, - &j_outer, // These are handles that we would be using for &j_inner, // further transformations &j_tail); + // loops[1] will become the outer loop, j_outer, after splitWithTail. std::cout << *loopnest.root_stmt() << std::endl; // Prints: // { diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md index 0a3f940a0f37a9..9252c7fa3adc64 100644 --- a/test/cpp_api_parity/parity-tracker.md +++ b/test/cpp_api_parity/parity-tracker.md @@ -49,6 +49,7 @@ torch::nn::Hardshrink|Yes|No torch::nn::Hardtanh|Yes|No torch::nn::LeakyReLU|Yes|No torch::nn::LogSigmoid|Yes|No +torch::nn::Mish|Yes|No torch::nn::MultiheadAttention|No|No torch::nn::PReLU|Yes|No torch::nn::ReLU|Yes|No @@ -187,6 +188,7 @@ F::rrelu|Yes|No F::glu|Yes|No F::gelu|Yes|No F::silu|Yes|No +F::mish|Yes|No F::logsigmoid|Yes|No F::hardshrink|Yes|No F::tanhshrink|Yes|No diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_sharded_tensor/test_sharded_tensor.py new file mode 100644 index 00000000000000..84f72633314e77 --- /dev/null +++ b/test/distributed/_sharded_tensor/test_sharded_tensor.py @@ -0,0 +1,285 @@ +import torch +import torch.distributed as dist +from torch.distributed import _sharded_tensor +from torch.distributed._sharding_spec import ChunkShardingSpec +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_nccl, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + TEST_WITH_ASAN, +) +import unittest + +@unittest.skipIf( + TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues" +) +class TestShardedTensorChunked(MultiProcessTestCase): + + @property + def world_size(self): + return 4 + + def init_pg(self): + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + init_method=f"file://{self.file_name}", + ) + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_complete_world_size(self): + self.init_pg() + + for dim in [0, -2]: + spec = ChunkShardingSpec( + dim=dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + sharded_tensor = _sharded_tensor.empty(spec, 10, 20) + + # Validate local shard. + local_shards = sharded_tensor.local_shards() + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + if self.rank == 3: + self.assertEqual((1, 20), local_shard.size()) + else: + self.assertEqual((3, 20), local_shard.size()) + + # Validate global metadata. + sharding_metadata = sharded_tensor.sharding_metadata() + self.assertEqual(4, len(sharding_metadata)) + + for rank, shard_metadata in enumerate(sharding_metadata): + self.assertEqual([rank * 3, 0], shard_metadata.shard_offsets) + if rank == 3: + self.assertEqual([1, 20], shard_metadata.shard_lengths) + else: + self.assertEqual([3, 20], shard_metadata.shard_lengths) + self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_partial_world_size(self): + self.init_pg() + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + sharded_tensor = _sharded_tensor.empty(spec, 10, 20) + + # Validate local shard. + local_shards = sharded_tensor.local_shards() + if self.rank >= 2: + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual((5, 20), local_shard.size()) + else: + self.assertEqual(0, len(local_shards)) + + # Validate global metadata. + sharding_metadata = sharded_tensor.sharding_metadata() + self.assertEqual(2, len(sharding_metadata)) + + for shard_rank, shard_metadata in enumerate(sharding_metadata): + self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets) + self.assertEqual([5, 20], shard_metadata.shard_lengths) + self.assertEqual(f'rank:{shard_rank + 2}/cuda:{shard_rank + 2}', shard_metadata.placement) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_new_group(self): + self.init_pg() + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:1/cuda:2", + "rank:2/cuda:3", + ], + ) + + pg = dist.new_group(ranks=[1, 2, 3]) + if self.rank >= 1: + sharded_tensor = _sharded_tensor.empty(spec, 10, 20, process_group=pg) + + # Validate local shard. + local_shards = sharded_tensor.local_shards() + if self.rank >= 2: + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual((5, 20), local_shard.size()) + else: + self.assertEqual(0, len(local_shards)) + + # Validate global metadata. + sharding_metadata = sharded_tensor.sharding_metadata() + self.assertEqual(2, len(sharding_metadata)) + + for shard_rank, shard_metadata in enumerate(sharding_metadata): + self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets) + self.assertEqual([5, 20], shard_metadata.shard_lengths) + self.assertEqual(f'rank:{shard_rank + 1}/cuda:{shard_rank + 2}', shard_metadata.placement) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_multiple_local_shards(self): + self.init_pg() + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + sharded_tensor = _sharded_tensor.empty(spec, 16, 20) + + # Validate local shards. + local_shards = sharded_tensor.local_shards() + self.assertEqual(2, len(local_shards)) + for local_shard in local_shards: + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual((2, 20), local_shard.tensor.size()) + + # Validate global metadata. + sharding_metadata = sharded_tensor.sharding_metadata() + self.assertEqual(8, len(sharding_metadata)) + + for shard_idx, shard_metadata in enumerate(sharding_metadata): + self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets) + self.assertEqual([2, 20], shard_metadata.shard_lengths) + self.assertEqual(f'rank:{shard_idx % 4}/cuda:{shard_idx % 4}', shard_metadata.placement) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_sharding_columns(self): + self.init_pg() + + for dim in [1, -1]: + spec = ChunkShardingSpec( + dim=dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + + sharded_tensor = _sharded_tensor.empty(spec, 10, 32) + + # Validate local shard. + local_shards = sharded_tensor.local_shards() + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual((10, 8), local_shard.size()) + + # Validate global metadata. + sharding_metadata = sharded_tensor.sharding_metadata() + self.assertEqual(4, len(sharding_metadata)) + + for rank, shard_metadata in enumerate(sharding_metadata): + self.assertEqual([0, rank * 8], shard_metadata.shard_offsets) + self.assertEqual([10, 8], shard_metadata.shard_lengths) + self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_invalid_sharding(self): + self.init_pg() + + spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"]) + pg = dist.new_group(ranks=[2, 3]) + if self.rank < 2: + with self.assertRaisesRegex(ValueError, 'not part of process group'): + _sharded_tensor.empty(spec, 10, 20, process_group=pg) + + spec = ChunkShardingSpec(dim='H', placements=["rank:1/cuda:1"]) + with self.assertRaisesRegex(ValueError, 'needs to be an integer'): + _sharded_tensor.empty(spec, 10, 20) + + for dim in [2, 3, 4, -3, -4, -5]: + spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"]) + with self.assertRaisesRegex(ValueError, 'Invalid sharding dim'): + _sharded_tensor.empty(spec, 10, 20) + + spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"]) + with self.assertRaisesRegex(ValueError, 'Invalid rank'): + _sharded_tensor.empty(spec, 10, 20) + + spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + sharded_tensor = _sharded_tensor.empty(spec, 10, 20) + tensor = torch.empty(10, 20) + with self.assertRaisesRegex(RuntimeError, "torch function 'add' not supported for ShardedTensor!"): + torch.add(sharded_tensor, tensor) + + spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'): + _sharded_tensor.empty(spec, 10, 20, layout=torch.sparse) + + spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'): + _sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_insufficient_sharding_dims(self): + self.init_pg() + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + sharded_tensor = _sharded_tensor.empty(spec, 2, 20) + + # Validate local shard. + local_shards = sharded_tensor.local_shards() + if self.rank <= 1: + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual((1, 20), local_shard.size()) + else: + self.assertEqual(0, len(local_shards)) + + # Validate global metadata. + sharding_metadata = sharded_tensor.sharding_metadata() + self.assertEqual(2, len(sharding_metadata)) + + for shard_rank, shard_metadata in enumerate(sharding_metadata): + self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets) + self.assertEqual([1, 20], shard_metadata.shard_lengths) + self.assertEqual(f'rank:{shard_rank}/cuda:{shard_rank}', shard_metadata.placement) diff --git a/test/distributed/_sharding_spec/test_sharding_spec.py b/test/distributed/_sharding_spec/test_sharding_spec.py new file mode 100644 index 00000000000000..13591e4e23ccdd --- /dev/null +++ b/test/distributed/_sharding_spec/test_sharding_spec.py @@ -0,0 +1,218 @@ +import torch +from torch.testing._internal.common_utils import TestCase +from torch.distributed._sharding_spec import ( + ChunkShardingSpec, + DevicePlacementSpec, + EnumerableShardingSpec, + ShardMetadata, +) + +class TestShardingSpec(TestCase): + + def test_device_placement(self): + # valid devices + DevicePlacementSpec("cuda:0") + DevicePlacementSpec(0) + DevicePlacementSpec(torch.device("cuda:0")) + DevicePlacementSpec("rank:0/cuda:0") + DevicePlacementSpec("rank:0/cpu") + DevicePlacementSpec("rank:0") + + # invalid devices + with self.assertRaisesRegex(ValueError, "not a valid device"): + DevicePlacementSpec("cuda:foo") + with self.assertRaisesRegex(ValueError, "not a valid device"): + DevicePlacementSpec("foo:0") + with self.assertRaisesRegex(ValueError, "not a valid device"): + DevicePlacementSpec("rank:0/cuda:foo") + with self.assertRaisesRegex(ValueError, "not a valid device"): + DevicePlacementSpec("rank:0/cpu2") + + def test_chunked_sharding_spec(self): + # Test valid specs. + ChunkShardingSpec(0, [0, 1]) + # Named dimension. + ChunkShardingSpec("N", ["cuda:0", "cuda:1"]) + ChunkShardingSpec(0, [torch.device("cuda:0"), torch.device("cuda:1")]) + ChunkShardingSpec(-1, ["cuda:0", "cuda:1"]) + ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:0/cuda:1"]) + ChunkShardingSpec(0, ["rank:0", "rank:1"]) + ChunkShardingSpec(0, ["rank:0/cpu", "rank:1/cpu"]) + + # Test invalid specs + with self.assertRaisesRegex(ValueError, "int or str"): + ChunkShardingSpec(None, ["cuda:0", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "int or str"): + ChunkShardingSpec({}, ["cuda:0", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "not a valid device"): + ChunkShardingSpec(0, ["random:0", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "not a valid device"): + ChunkShardingSpec(0, ["cuda:foo", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "not a valid device"): + ChunkShardingSpec(0, ["rank:foo", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "not a valid device"): + ChunkShardingSpec(0, ["rank:0/foo", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "not a valid device"): + ChunkShardingSpec(0, ["rank:0/random:0", "cuda:1"]) + with self.assertRaisesRegex(ValueError, "not a valid device"): + ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"]) + + def test_generic_sharding_spec(self): + # test valid specs + + # test row-wise sharding + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_lengths=[5, 5], + placement="cuda:1", + ) + ]) + spec.check_tensor(torch.rand(10, 5)) + + # test row and column sharding + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[3, 3], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 3], + shard_lengths=[3, 3], + placement="cuda:1", + ), + ShardMetadata( + shard_offsets=[3, 0], + shard_lengths=[3, 3], + placement="cuda:2", + ), + ShardMetadata( + shard_offsets=[3, 3], + shard_lengths=[3, 3], + placement="cuda:3", + ), + ]) + spec.check_tensor(torch.rand(6, 6)) + + # test uneven shard sizes. + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[2, 4], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 4], + shard_lengths=[4, 2], + placement="cuda:1", + ), + ShardMetadata( + shard_offsets=[2, 0], + shard_lengths=[4, 4], + placement="cuda:2", + ), + ShardMetadata( + shard_offsets=[4, 4], + shard_lengths=[2, 2], + placement="cuda:3", + ), + ]) + spec.check_tensor(torch.rand(6, 6)) + + # test invalid sharding + with self.assertRaisesRegex(ValueError, 'not a valid device'): + ShardMetadata(shard_offsets=[0], shard_lengths=[1], placement="cuda:foo") + + with self.assertRaisesRegex(ValueError, 'same number of elements'): + ShardMetadata(shard_offsets=[0, 0], shard_lengths=[1], placement="cuda:0") + + with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'): + ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0") + + with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'): + ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0") + + with self.assertRaisesRegex(ValueError, 'Empty shard list provided'): + EnumerableShardingSpec([]) + + with self.assertRaisesRegex(ValueError, 'Found inconsistent ranks for shards'): + EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[1, 1], + placement="cpu" + ), + ShardMetadata( + shard_offsets=[0, 0, 0], + shard_lengths=[1, 1, 1], + placement="cpu" + ), + ]) + + with self.assertRaisesRegex(ValueError, 'Shards.*overlap'): + EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[3, 3], + placement="cpu" + ), + ShardMetadata( + shard_offsets=[2, 0], + shard_lengths=[3, 3], + placement="cpu" + ), + ]) + + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_lengths=[5, 5], + placement="cuda:1", + ) + ]) + + with self.assertRaisesRegex(ValueError, 'Rank of tensor is.*but shards rank'): + spec.check_tensor(torch.rand(10, 10, 10)) + + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_lengths=[5, 5], + placement="cuda:1", + ) + ]) + + with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'): + spec.check_tensor(torch.rand(10, 3)) + + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_lengths=[5, 5], + placement="cuda:1", + ) + ]) + + with self.assertRaisesRegex(ValueError, 'does not match tensor volume'): + spec.check_tensor(torch.rand(10, 10)) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 689cb6951fc800..b9158ca3a26191 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -53,6 +53,7 @@ DEFAULT_HOSTNAME = "localhost" +torch.backends.cuda.matmul.allow_tf32 = False def gpus_for_rank(world_size): """Multigpu tests are designed to simulate the multi nodes with multi @@ -596,6 +597,124 @@ def forward(self, x): return F.softmax(self.embedding(x), dim=1) +class AbstractProcessGroupWrapperTest(MultiProcessTestCase): + def setUp(self): + super(AbstractProcessGroupWrapperTest, self).setUp() + # For Windows platform, Python does not support fork, change it to spawn here. + if sys.platform == "win32": + self._spawn_processes() + else: + self._fork_processes() + + def _test_collective_hang(self, wrapper_pg, use_cuda=False): + # All ranks besides 1 call allreduce and wrapper_pg should detect a hang + # and report an issue with rank 1. + faulty_rank = 1 + if self.rank != faulty_rank: + tensor = torch.randn(20, 10) + if use_cuda: + tensor = tensor.to(self.rank) + + if self.rank == 0: + # Rank 0 reports faulty ranks + err = f"Ranks {faulty_rank} failed to pass monitoredBarrier" + else: + err = "Please check rank 0 logs for faulty rank" + with self.assertRaisesRegex(RuntimeError, err): + wrapper_pg.allreduce([tensor]) + + def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False): + tensor = torch.randn(20, 10) + if use_cuda: + tensor = tensor.to(self.rank) + works = [] + # Run a few successful collectives + for _ in range(10): + work = wrapper_pg.allreduce([tensor]) + works.append(work) + + for w in works: + w.wait() + + # Simulate mismatch: allreduce vs reduce. + with self.assertRaisesRegex( + RuntimeError, "Mismatch between collective operation types" + ): + if self.rank == 0: + wrapper_pg.allreduce([tensor]) + else: + wrapper_pg.reduce([tensor]) + + # Check additional mismatches + + with self.assertRaisesRegex( + RuntimeError, "Mismatch between collective operation types" + ): + if self.rank == 0: + wrapper_pg.reduce([tensor]) + else: + wrapper_pg.barrier() + + with self.assertRaisesRegex( + RuntimeError, "Mismatch between collective operation types" + ): + scatter_result = [torch.ones(4) * i for i in range(self.world_size)] + scattered_tensor = torch.empty(4) + if self.rank == 0: + wrapper_pg.scatter(scattered_tensor, scatter_result, 0) + else: + wrapper_pg.reduce_scatter(scattered_tensor, scatter_result) + + with self.assertRaisesRegex( + RuntimeError, "Mismatch between collective operation types" + ): + if self.rank == 0: + wrapper_pg.broadcast(tensor, 0) + else: + output_tensors = [ + torch.zeros_like(tensor) for _ in range(self.world_size) + ] + wrapper_pg.allgather([output_tensors], [tensor]) + + def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False): + wrapper_pg.barrier() + dim = 2 if self.rank == 0 else 10 + tensor = torch.randn(20, dim) + if use_cuda: + tensor = tensor.to(self.rank) + with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"): + wrapper_pg.allreduce([tensor]) + # Check errors are raised when dimensionality of shapes is different + tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10) + if use_cuda: + tensor = tensor.to(self.rank) + with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"): + wrapper_pg.allreduce([tensor]) + + # Check shape errors with scatter + input = [ + torch.tensor( + [self.rank] if self.rank == 0 else [self.rank, self.rank], + device=self.rank if use_cuda else "cpu", + ) + for _ in range(self.world_size) + ] + outputs = [ + torch.tensor( + [-1] if self.rank == 0 else [-1, -1], + device=self.rank if use_cuda else "cpu", + ) + for _ in range(self.world_size) + ] + root_rank = 0 + opts = c10d.ScatterOptions() + opts.rootRank = root_rank + with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"): + if self.rank == root_rank: + wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait() + else: + wrapper_pg.scatter([outputs[self.rank]], [], opts).wait() + class AbstractDistributedDataParallelTest(object): def tearDown(self): # DistributedDataParallel test doesn't seem to call FileStore destructor diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 1a0224c0245b06..d2f336f9352917 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -29,6 +29,7 @@ simple_sparse_reduce_tests, skip_if_win32, create_device, + with_dist_debug_levels, ) from torch.testing._internal.common_utils import ( TestCase, @@ -37,7 +38,14 @@ TEST_WITH_TSAN, ) import test_c10d_common -from test_c10d_common import LOOPBACK, gpus_for_rank, Task, ModuleForDdpCommHook, SparseGradientModule +from test_c10d_common import ( + LOOPBACK, + gpus_for_rank, + Task, + ModuleForDdpCommHook, + SparseGradientModule, + AbstractProcessGroupWrapperTest, +) def simple_reduce_tests(rank, world_size): @@ -194,6 +202,91 @@ class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase): def test_default_store_timeout_gloo(self): self._test_default_store_timeout("gloo") +@requires_gloo() +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) +class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest): + def setUp(self): + super(ProcessGroupGlooWrapperTest, self).setUp() + + def opts(self, threads=2, timeout=10.0): + opts = c10d.ProcessGroupGloo._Options() + opts._timeout = timeout + opts._devices = [create_device(interface=LOOPBACK)] + opts._threads = threads + return opts + + def _create_wrapper_pg(self, with_new_group=False, timeout=10.0): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="gloo", rank=self.rank, world_size=self.world_size, store=store + ) + if with_new_group: + pg = c10d.new_group(backend="gloo") + else: + _pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(timeout=timeout)) + pg = c10d._create_process_group_wrapper( + _pg, + "unused", + store, + self.rank, + self.world_size, + timeout=timeout, + ) + return pg + + def test_collective_hang(self): + pg = self._create_wrapper_pg(timeout=2.0) + self._test_collective_hang(pg) + + # NOTE: these tests are separated by debug level instead of combined into + # one due to https://github.com/pytorch/pytorch/issues/55967, they can be + # combined after that is resolved. + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collectives_op_mismatch_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collectives_op_mismatch(pg) + + @with_dist_debug_levels(levels=["OFF"]) + def test_collectives_op_mismatch(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collectives_op_mismatch(pg) + + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collective_shape_mismatch_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collective_shape_mismatch(pg) + + @with_dist_debug_levels(levels=["OFF"]) + def test_collective_shape_mismatch(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collective_shape_mismatch(pg) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collectives_op_mismatch_cuda_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collectives_op_mismatch(pg, use_cuda=True) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["OFF"]) + def test_collectives_op_mismatch_cuda(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collectives_op_mismatch(pg, use_cuda=True) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collective_shape_mismatch_cuda_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collective_shape_mismatch(pg, use_cuda=True) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["OFF"]) + def test_collective_shape_mismatch_cuda(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collective_shape_mismatch(pg, use_cuda=True) @requires_gloo() @unittest.skipIf( diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 31a856d834bd14..941196341121ce 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -30,6 +30,7 @@ from torch.utils.checkpoint import checkpoint from torch.testing._internal.common_distributed import ( MultiProcessTestCase, + requires_gloo, requires_nccl, requires_nccl_version, skip_if_lt_x_gpu, @@ -45,7 +46,7 @@ TEST_WITH_TSAN, ) import test_c10d_common -from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook +from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook, AbstractProcessGroupWrapperTest class RendezvousEnvTest(TestCase): @@ -158,6 +159,87 @@ def test_default_store_timeout_nccl(self): raise unittest.SkipTest("No GPUs available, skipping test") self._test_default_store_timeout("nccl") +@requires_gloo() +@requires_nccl() +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) +class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest): + def setUp(self): + self.num_gpus = torch.cuda.device_count() + if self.num_gpus < 2: + raise unittest.SkipTest("NCCL test requires 2+ GPUs") + super(AbstractProcessGroupWrapperTest, self).setUp() + self._spawn_processes() + # NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests + # that use NCCL_BLOCKING_WAIT will test it as expected. + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + + @property + def world_size(self) -> int: + return 2 + + def _create_wrapper_pg(self, with_new_group=False, timeout=10.0): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store, + timeout=timedelta(seconds=timeout) + ) + if with_new_group: + pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout)) + else: + _pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)) + pg = c10d._create_process_group_wrapper( + _pg, + "unused", + store, + self.rank, + self.world_size, + timeout=timeout, + ) + return pg + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_collective_hang(self): + pg = self._create_wrapper_pg(timeout=2.0) + self._test_collective_hang(pg) + + # NOTE: these tests are separated by debug level instead of combined into + # one due to https://github.com/pytorch/pytorch/issues/55967, they can be + # combined after that is resolved. + @requires_nccl() + @skip_if_lt_x_gpu(2) + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collectives_op_mismatch_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collectives_op_mismatch(pg, use_cuda=True) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @with_dist_debug_levels(levels=["OFF"]) + def test_collectives_op_mismatch(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collectives_op_mismatch(pg, use_cuda=True) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collective_shape_mismatch_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collective_shape_mismatch(pg, use_cuda=True) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @with_dist_debug_levels(levels=["OFF"]) + def test_collective_shape_mismatch(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collective_shape_mismatch(pg, use_cuda=True) + class ProcessGroupNCCLNoGPUTest(TestCase): MAIN_PROCESS_RANK = 0 @@ -1933,6 +2015,7 @@ def test_nccl_errors_blocking_nonzero_exit(self): @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) @skip_if_rocm + @unittest.skip("Frequently times out see https://github.com/pytorch/pytorch/issues/58920") def test_nccl_errors_blocking_abort(self): self._test_nccl_errors_blocking(lambda: os.abort()) diff --git a/test/jit/test_graph_rewrite_passes.py b/test/jit/test_graph_rewrite_passes.py new file mode 100644 index 00000000000000..d4da1e58e326c9 --- /dev/null +++ b/test/jit/test_graph_rewrite_passes.py @@ -0,0 +1,59 @@ +from torch.testing._internal.jit_utils import JitTestCase +import torch +import torch._C +from torch.testing import FileCheck + + +class TestGraphRewritePasses(JitTestCase): + def test_fuse_linear(self): + class FunctionalLinear(torch.nn.Module): + def __init__(self, weight, bias): + super(FunctionalLinear, self).__init__() + self.weight = weight + self.bias = bias + + def forward(self, x): + res = torch.matmul(x, self.weight.t()) + if self.bias is not None: + res.add_(self.bias) + return res + + x1 = torch.rand(3) + w1 = torch.rand(5, 3) + b1 = torch.rand(5) + for has_bias in [True, False]: + bias = b1 if has_bias else None + model = torch.jit.trace(FunctionalLinear(w1, bias), [x1]) + for node in model.graph.nodes(): + if node.kind() == "aten::matmul": + source_range_1 = node.sourceRange() + torch._C._jit_pass_fuse_linear(model.graph) + for node in model.graph.nodes(): + if node.kind() == "aten::linear": + source_range_2 = node.sourceRange() + FileCheck().check("aten::linear").run(model.graph) + check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("] + for cn in check_not: + FileCheck().check_not(cn).run(model.graph) + self.assertTrue(source_range_1 == source_range_2) + # make sure it runs + model(x1) + + # check matmuls are not fused + class Matmul(torch.nn.Module): + def __init__(self, weight): + super(Matmul, self).__init__() + self.weight = weight + + def forward(self, x): + return torch.matmul(x, self.weight) + + x = torch.rand(5, 6, 5) + w = torch.rand(5, 5, 100) + model = torch.jit.trace(Matmul(w), [x]) + torch._C._jit_pass_fuse_linear(model.graph) + # check 3d matmul is not fused + FileCheck().check("aten::matmul").run(model.graph) + FileCheck().check_not("aten::linear").run(model.graph) + # make sure it runs + model(x) diff --git a/test/jit/test_peephole.py b/test/jit/test_peephole.py index 2a39b1b077ba6b..97ab8d6d989834 100644 --- a/test/jit/test_peephole.py +++ b/test/jit/test_peephole.py @@ -2,6 +2,7 @@ from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything from torch import nn from torch.testing import FileCheck +from typing import List import unittest @@ -237,3 +238,294 @@ def convertible_isnot_op(x: bool, y: bool): op_graph = torch.jit.script(convertible_isnot_op).graph FileCheck().check_count("aten::ne", 3, exactly=True).run(op_graph) FileCheck().check_count("aten::__isnot__", 0, exactly=True).run(op_graph) + + def test_peephole_list_len(self): + def run_peephole_and_check_const_value(graph, const_string): + torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True) + self.run_pass("constant_propagation", graph) + FileCheck().check(const_string).check_next("return").run(graph) + + def gen_li(inp_len: int): + return [0 for i in range(inp_len)] + + @torch.jit.script + def foo(x: List[int], y: List[int]): + if len(x) != 4 or len(y) != 5: + raise Exception("") + + return len(x) + len(y) + + run_peephole_and_check_const_value(foo.graph, "value=9") + self.assertEqual(foo(gen_li(4), gen_li(5)), 9) + with self.assertRaises(Exception): + foo(2, 4) + + @torch.jit.script + def foo(x: List[int], y: List[int]): + if len(x) == 4 and len(y) == 5: + pass + else: + raise Exception("hi") + + return len(x) + len(y) + + run_peephole_and_check_const_value(foo.graph, "value=9") + self.assertEqual(foo(gen_li(4), gen_li(5)), 9) + with self.assertRaises(Exception): + foo(2, 4) + + @torch.jit.script + def foo(x: List[int], y: List[int], z: List[int]): + if len(x) != 4: + raise Exception("..") + else: + if len(y) != 8: + raise Exception("...") + else: + if len(z) == 3: + pass + else: + raise Exception("...") + + return len(x) + len(y) * len(z) + + run_peephole_and_check_const_value(foo.graph, "value=28") + self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28) + with self.assertRaises(Exception): + foo(1, 2, 3) + + # refinement should persist in second len(x) call + + @torch.jit.script + def foo(x: List[int], cond: bool): + if len(x) == 4: + if cond: + return len(x) + return 4 + + return 4 + + run_peephole_and_check_const_value(foo.graph, "value=4") + + def test_const_tuple_output(graph, const_inputs): + tup = graph.findNode("prim::TupleConstruct") + for i, elem in enumerate(tup.inputs()): + if i in const_inputs: + self.assertIsNotNone(elem.toIValue()) + else: + self.assertIsNone(elem.toIValue()) + + # testing combinations of x1 : {True, False} x + # {then/else branch} x assert {True/False} + + @torch.jit.script + def foo(x: List[int], b: List[int]): + if len(x) == 5: + x1 = True + else: + x1 = len(b) != 4 + assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq + return len(x), len(b) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + torch._C._jit_pass_constant_propagation(foo.graph) + # we can only infer len(b) == 4 here + test_const_tuple_output(foo.graph, [1]) + + @torch.jit.script + def foo(x: List[int], b: List[int]): + if len(x) == 5: + x1 = False + else: + x1 = len(b) != 4 + assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq + return len(x), len(b) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + torch._C._jit_pass_constant_propagation(foo.graph) + # cant infer anything + test_const_tuple_output(foo.graph, []) + + @torch.jit.script + def foo(x: List[int], b: List[int]): + if len(x) == 5: + x1 = True + else: + x1 = len(b) == 4 + assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq + return len(x), len(b) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + torch._C._jit_pass_constant_propagation(foo.graph) + # we cant infer anything, only len(b) != 4 + test_const_tuple_output(foo.graph, []) + + @torch.jit.script + def foo(x: List[int], b: List[int]): + if len(x) == 5: + x1 = True + else: + x1 = len(b) != 4 + assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq + return len(x), len(b) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + torch._C._jit_pass_constant_propagation(foo.graph) + # can infer len(b) == 4 + test_const_tuple_output(foo.graph, [1]) + + # swap branches + @torch.jit.script + def foo(x: List[int], b: List[int]): + if len(x) != 5: + x1 = len(b) != 4 + else: + x1 = True + assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq + return len(x), len(b) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + torch._C._jit_pass_constant_propagation(foo.graph) + # can infer len(b) == 4 + test_const_tuple_output(foo.graph, [1]) + + # use __not__ + @torch.jit.script + def foo(x: List[int], b: List[int]): + if len(x) != 5: + x1 = len(b) != 4 + else: + x1 = True + assert not x1 + return len(x), len(b) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + torch._C._jit_pass_constant_propagation(foo.graph) + # can infer len(b) == 4 + test_const_tuple_output(foo.graph, [1]) + + # Test unsuccessful optimizations + + @torch.jit.script + def foo(x: List[int]): + assert len(x) == 4 + x.append(3) + return len(x) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + self.run_pass("constant_propagation", foo.graph) + FileCheck().check_count("aten::len", 2).run(foo.graph) + + @torch.jit.script + def foo(x: List[int], y: List[int]): + assert len(x) == 4 or len(y) == 5 + return len(x) + len(y) + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + self.run_pass("constant_propagation", foo.graph) + FileCheck().check_count("aten::len", 4).run(foo.graph) + + def test_integer_refinement(self): + def run_peephole_and_check_const_value(graph, const_string): + self.run_pass("refine_integer_values", graph) + self.run_pass("constant_propagation", graph) + self.run_pass("dce", graph) + FileCheck().check(const_string).check_next("return").run(graph) + + @torch.jit.script + def foo(x: int, y: int): + if x != 4 or y != 5: + raise Exception("") + + return x + y + + graph = foo.graph + self.run_pass("refine_integer_values", graph) + self.run_pass("constant_propagation", graph) + self.run_pass("dce", graph) + + run_peephole_and_check_const_value(foo.graph, "value=9") + self.assertEqual(foo(4, 5), 9) + with self.assertRaises(Exception): + foo(2, 4) + + @torch.jit.script + def foo(x: int, y: int): + if x == 4 and y == 5: + pass + else: + raise Exception("hi") + + return x + y + + run_peephole_and_check_const_value(foo.graph, "value=9") + self.assertEqual(foo(4, 5), 9) + with self.assertRaises(Exception): + foo(2, 4) + + @torch.jit.script + def foo(x: int, y: int, z: int): + if x != 4: + raise Exception("..") + else: + if y != 8: + raise Exception("...") + else: + if z == 3: + pass + else: + raise Exception("...") + + return x + y * z + + run_peephole_and_check_const_value(foo.graph, "value=28") + self.assertEqual(foo(4, 8, 3), 28) + with self.assertRaises(Exception): + foo(1, 2, 3) + + # refinement should persist in second len(x) call + + @torch.jit.script + def foo(x: int, cond: bool): + if x == 4: + if cond: + return x + return 4 + + return 4 + + run_peephole_and_check_const_value(foo.graph, "value=4") + + @torch.jit.script + def foo(x: int, y: int): + assert x == 4 or y == 5 + return x + y + + torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) + self.run_pass("constant_propagation", foo.graph) + FileCheck().check("aten::add").run(foo.graph) + + def test_optimize_out_comparison_same_value(self): + @torch.jit.script + def foo(x: int): + return x == x, x != x + + self.run_pass("peephole", foo.graph) + FileCheck().check_not("aten::eq").check_not("aten::neq").run(foo.graph) + self.assertEqual(foo(1), (True, False)) + + def test_refine_integer_values(self): + @torch.jit.script + def foo(x: int): + y = 1 + if x == 1: + return y + else: + return x + + self.run_pass("refine_integer_values", foo.graph) + self.run_pass("constant_propagation", foo.graph) + self.run_pass("dce", foo.graph) + FileCheck().check("graph").check_next("return").run(foo.graph) + self.assertEqual(foo(2), 2) + self.assertEqual(foo(1), 1) diff --git a/test/jit/test_script_profile.py b/test/jit/test_script_profile.py new file mode 100644 index 00000000000000..d7d1b2f092cfe8 --- /dev/null +++ b/test/jit/test_script_profile.py @@ -0,0 +1,109 @@ +import os +import sys + +import torch +from torch import nn + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +class Sequence(nn.Module): + def __init__(self): + super(Sequence, self).__init__() + self.lstm1 = nn.LSTMCell(1, 51) + self.lstm2 = nn.LSTMCell(51, 51) + self.linear = nn.Linear(51, 1) + + def forward(self, input): + outputs = [] + h_t = torch.zeros(input.size(0), 51) + c_t = torch.zeros(input.size(0), 51) + h_t2 = torch.zeros(input.size(0), 51) + c_t2 = torch.zeros(input.size(0), 51) + + for input_t in input.split(1, dim=1): + h_t, c_t = self.lstm1(input_t, (h_t, c_t)) + h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2)) + output = self.linear(h_t2) + outputs += [output] + outputs = torch.cat(outputs, dim=1) + return outputs + +class TestScriptProfile(JitTestCase): + + def test_basic(self): + seq = torch.jit.script(Sequence()) + p = torch.jit._ScriptProfile() + p.enable() + seq(torch.rand((10, 100))) + p.disable() + self.assertNotEqual(p.dump_string(), "") + + def test_script(self): + seq = Sequence() + + @torch.jit.script + def fn(): + p = torch.jit._ScriptProfile() + p.enable() + _ = seq(torch.rand((10, 100))) + p.disable() + return p + + self.assertNotEqual(fn().dump_string(), "") + + def test_multi(self): + seq = torch.jit.script(Sequence()) + profiles = [torch.jit._ScriptProfile() for _ in range(5)] + for p in profiles: + p.enable() + + last = None + while len(profiles) > 0: + seq(torch.rand((10, 10))) + p = profiles.pop() + p.disable() + stats = p.dump_string() + self.assertNotEqual(stats, "") + if last: + self.assertNotEqual(stats, last) + last = stats + + def test_section(self): + seq = Sequence() + + @torch.jit.script + def fn(): + p = torch.jit._ScriptProfile() + p.enable() + _ = seq(torch.rand((10, 100))) + p.disable() + stats0 = p.dump_string() + + _ = seq(torch.rand((10, 10))) + stats1 = p.dump_string() + + p.enable() + _ = seq(torch.rand((10, 10))) + p.disable() + stats2 = p.dump_string() + + p.enable() + return stats0, stats1, stats2 + + s0, s1, s2 = fn() + self.assertEqual(s0, s1) + self.assertNotEqual(s1, s2) + + def test_empty(self): + p = torch.jit._ScriptProfile() + p.enable() + p.disable() + self.assertEqual(p.dump_string(), "") diff --git a/test/jit/test_slice.py b/test/jit/test_slice.py index 5b2503f37fa43b..bbffcf7f6cd62c 100644 --- a/test/jit/test_slice.py +++ b/test/jit/test_slice.py @@ -136,3 +136,34 @@ def tuple_slice(a): self.assertTrue(num_outputs == {2}) self.run_pass('lower_all_tuples', tuple_graph) self.assertTrue('Tuple' not in str(tuple_graph)) + + def test_module_list_slicing(self): + class Bar(torch.nn.Module): + def __init__(self, identifier: str): + super().__init__() + self.identifier = identifier + + def forward(self): + return 0 + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")] + self.test = torch.nn.ModuleList(module_list) + + def forward(self): + return self.test[::-2], self.test[1:4:] + + scripted_foo = torch.jit.script(Foo()) + result1, result2 = scripted_foo() + + self.assertEqual(len(result1), 3) + self.assertEqual(result1[0].identifier, "E") + self.assertEqual(result1[1].identifier, "C") + self.assertEqual(result1[2].identifier, "A") + + self.assertEqual(len(result2), 3) + self.assertEqual(result2[0].identifier, "B") + self.assertEqual(result2[1].identifier, "C") + self.assertEqual(result2[2].identifier, "D") diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py new file mode 100644 index 00000000000000..bc41228395070e --- /dev/null +++ b/test/jit/test_symbolic_shape_analysis.py @@ -0,0 +1,95 @@ +import torch +from torch.testing._internal.jit_utils import JitTestCase + +from torch.testing import FileCheck +from typing import List + + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +# XXX: still in prototype +class TestSymbolicShapeAnalysis(JitTestCase): + def test_shape_analysis(self): + @torch.jit.script + def broadcast(a: List[int], b: List[int]): + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes : List[int] = [] + + for i in range(ndim): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if (dimA >= 0) else 1 + sizeB = b[dimB] if (dimB >= 0) else 1 + + if sizeA != sizeB and sizeA != 1 and sizeB != 1: + raise Exception("The size of tensor a {} must match the size of tensor b (" + "{}) at non-singleton dimension {}".format(sizeA, sizeB, i)) + + expandedSizes.append(sizeB if sizeA == 1 else sizeA) + + return expandedSizes + + @torch.jit.script + def foo(x, y): + return x * y + + torch._C._jit_register_operator_shape_function(foo.graph.findNode("aten::mul"), broadcast.graph) + inputs = list(foo.graph.inputs()) + + def prop_shapes_on_graph(inp0, inp1): + inputs[0].setType(inputs[0].type().with_sizes(inp0)) + inputs[1].setType(inputs[1].type().with_sizes(inp1)) + torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) + + prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5]) + FileCheck().check("1, 7, 6, 5").run(foo.graph) + + # None implicitly creates a new symbolic symbol + prop_shapes_on_graph([None, None], [None, None, None]) + output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes() + inp0_shape = inputs[0].type().symbolic_sizes() + inp1_shape = inputs[1].type().symbolic_sizes() + + # output shape dim 0 should be taken from the second inp dim0 + # other two dims we cannot infer and are given a new symbolic shape + self.assertEqual(output_shape[0], inp1_shape[0]) + self.assertFalse(output_shape[1] in inp0_shape + inp1_shape) + self.assertFalse(output_shape[2] in inp0_shape + inp1_shape) + + # XXX: symbolic shapes are represented with an increasing counter of unique + # values, use `_new_symbolic_shape_symbol` api instead of specifying negative + # dimensions directly so there is no chance of collision between manual number + # and current counter value. + sym1 = torch._C._new_symbolic_shape_symbol() + sym2 = torch._C._new_symbolic_shape_symbol() + sym3 = torch._C._new_symbolic_shape_symbol() + prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3]) + output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes() + self.assertEqual(output_shape[0], sym1) + self.assertEqual(output_shape[1], sym2) + self.assertEqual(output_shape[2], sym3) + + def test_sharing_of_list_len(self): + # testing generic sharing of logic, a la _convolution and conv2s + @torch.jit.script + def adaptive_avg_pool2d(self, out: List[int]): + assert len(out) == 2 + out2 : List[int] = [] + for elem in out: + out2.append(elem) + return out2 + + @torch.jit.script + def foo(x, out: List[int]): + return torch.nn.functional.adaptive_avg_pool2d(x, out) + + self.run_pass("inline", foo.graph) + torch._C._jit_register_operator_shape_function(foo.graph.findNode("aten::adaptive_avg_pool2d"), adaptive_avg_pool2d.graph) + torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) + FileCheck().check("Tensor(*, *)").check_same("adaptive_avg_pool2d").run(foo.graph) diff --git a/test/mobile/test_bytecode.py b/test/mobile/test_bytecode.py index 17cfe92bb4e63a..b6f1683c94b87f 100644 --- a/test/mobile/test_bytecode.py +++ b/test/mobile/test_bytecode.py @@ -8,12 +8,13 @@ from torch.jit.mobile import ( _load_for_lite_interpreter, _get_model_bytecode_version, + _get_model_ops_and_info, _backport_for_mobile_to_buffer, _backport_for_mobile) from torch.testing._internal.common_utils import TestCase, run_tests from pathlib import Path -pytorch_test_dri = Path(__file__).resolve().parents[1] +pytorch_test_dir = Path(__file__).resolve().parents[1] # script_module_v4.ptl and script_module_v5.ptl source code # class TestModule(torch.nn.Module): @@ -97,6 +98,38 @@ ((('name', ''), ('type', 'Tensor'), ('default_value', None)),))))) ''' +SCRIPT_MODULE_V6_BYTECODE_PKL = ''' +(6, + ('__torch__.*.TestModule.forward', + (('instructions', + (('STOREN', 1, 2), + ('DROPR', 1, 0), + ('LOADC', 0, 0), + ('LOADC', 1, 0), + ('MOVE', 2, 0), + ('OP', 0, 0), + ('OP', 1, 0), + ('RET', 0, 0))), + ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))), + ('constants', + (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),), + 0, + (2, 4), + (4, 1), + False, + collections.OrderedDict()), + 1)), + ('types', ()), + ('register_size', 2)), + (('arguments', + ((('name', 'self'), + ('type', '__torch__.*.TestModule'), + ('default_value', None)), + (('name', 'y'), ('type', 'int'), ('default_value', None)))), + ('returns', + ((('name', ''), ('type', 'Tensor'), ('default_value', None)),))))) + ''' + SCRIPT_MODULE_BYTECODE_PKL = { 4: { "bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL, @@ -113,7 +146,7 @@ def check_model_version(model_path, expect_version): actual_version = _get_model_bytecode_version(model_path) assert(actual_version == expect_version) for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items(): - model_path = pytorch_test_dri / "cpp" / "jit" / model_info["model_name"] + model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"] check_model_version(model_path, version) def test_bytecode_values_for_all_backport_functions(self): @@ -130,7 +163,7 @@ def test_bytecode_values_for_all_backport_functions(self): while current_from_version > MINIMUM_TO_VERSION: # Load model v5 and run forward method model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"] - input_model_path = pytorch_test_dri / "cpp" / "jit" / model_name + input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name # A temporary model file will be export to this path, and run through bytecode.pkl # content check. @@ -205,7 +238,7 @@ def forward(self, y: int): # Check just the test_backport_bytecode_from_file_to_file mechanism but not the function implementations def test_backport_bytecode_from_file_to_file(self): maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys()) - script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[ + script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[ maximum_checked_in_model_version]["model_name"] if (maximum_checked_in_model_version > MINIMUM_TO_VERSION): @@ -241,7 +274,7 @@ def test_backport_bytecode_from_file_to_file(self): # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations def test_backport_bytecode_from_file_to_buffer(self): maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys()) - script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[ + script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[ maximum_checked_in_model_version]["model_name"] if (maximum_checked_in_model_version > MINIMUM_TO_VERSION): @@ -264,5 +297,12 @@ def test_backport_bytecode_from_file_to_buffer(self): torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result) + def test_get_model_ops_and_info(self): + # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists + script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl" + ops_v6 = _get_model_ops_and_info(script_module_v6) + assert(ops_v6["aten::add.int"].num_schema_args == 2) + assert(ops_v6["aten::add.Scalar"].num_schema_args == 2) + if __name__ == '__main__': run_tests() diff --git a/test/onnx/expect/TestOperators.test_arange_dynamic.expect b/test/onnx/expect/TestOperators.test_arange_dynamic.expect index 98c35e7fab802a..6de30ee46c4992 100644 --- a/test/onnx/expect/TestOperators.test_arange_dynamic.expect +++ b/test/onnx/expect/TestOperators.test_arange_dynamic.expect @@ -3,144 +3,22 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - output: "2" - name: "Constant_1" + name: "Constant_0" op_type: "Constant" attribute { name: "value" t { - data_type: 7 - raw_data: "\000\000\000\000\000\000\000\000" + dims: 10 + data_type: 1 + raw_data: "\000\000\240@\000\000\260@\000\000\300@\000\000\320@\000\000\340@\000\000\360@\000\000\000A\000\000\010A\000\000\020A\000\000\030A" } type: TENSOR } } - node { - input: "1" - input: "2" - output: "3" - name: "Gather_2" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "0" - output: "4" - name: "Shape_3" - op_type: "Shape" - } - node { - output: "5" - name: "Constant_4" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\000\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "4" - input: "5" - output: "6" - name: "Gather_5" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - output: "7" - name: "Constant_6" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\005\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "6" - input: "7" - output: "8" - name: "Add_7" - op_type: "Add" - } - node { - input: "3" - output: "10" - name: "Cast_8" - op_type: "Cast" - attribute { - name: "to" - i: 1 - type: INT - } - } - node { - input: "8" - output: "11" - name: "Cast_9" - op_type: "Cast" - attribute { - name: "to" - i: 1 - type: INT - } - } - node { - input: "10" - input: "11" - input: "14" - output: "13" - name: "Range_10" - op_type: "Range" - } name: "torch-jit-export" - initializer { - data_type: 1 - name: "14" - raw_data: "\000\000\000?" - } - input { - name: "0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 5 - } - dim { - dim_value: 3 - } - dim { - dim_value: 2 - } - } - } - } - } output { - name: "13" + name: "1" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_empty_like.expect b/test/onnx/expect/TestOperators.test_empty_like.expect index e2560305767a26..b92cf4a92282d6 100644 --- a/test/onnx/expect/TestOperators.test_empty_like.expect +++ b/test/onnx/expect/TestOperators.test_empty_like.expect @@ -3,45 +3,23 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - input: "1" - output: "2" - name: "ConstantOfShape_1" - op_type: "ConstantOfShape" + name: "Constant_0" + op_type: "Constant" attribute { name: "value" t { - dims: 1 + dims: 5 + dims: 8 data_type: 1 - raw_data: "\000\000\000\000" + raw_data: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000" } type: TENSOR } } name: "torch-jit-export" - input { - name: "0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 5 - } - dim { - dim_value: 8 - } - } - } - } - } output { - name: "2" + name: "1" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect index 0899d05e2a8a79..f2c1cd9ac59c32 100644 --- a/test/onnx/expect/TestOperators.test_full.expect +++ b/test/onnx/expect/TestOperators.test_full.expect @@ -3,135 +3,23 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - output: "2" - name: "Constant_1" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\000\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "1" - input: "2" - output: "3" - name: "Gather_2" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "0" - output: "4" - name: "Shape_3" - op_type: "Shape" - } - node { - output: "5" - name: "Constant_4" + name: "Constant_0" op_type: "Constant" attribute { name: "value" t { - data_type: 7 - raw_data: "\001\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "4" - input: "5" - output: "6" - name: "Gather_5" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "3" - output: "7" - name: "Unsqueeze_6" - op_type: "Unsqueeze" - attribute { - name: "axes" - ints: 0 - type: INTS - } - } - node { - input: "6" - output: "8" - name: "Unsqueeze_7" - op_type: "Unsqueeze" - attribute { - name: "axes" - ints: 0 - type: INTS - } - } - node { - input: "7" - input: "8" - output: "9" - name: "Concat_8" - op_type: "Concat" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "9" - output: "10" - name: "ConstantOfShape_9" - op_type: "ConstantOfShape" - attribute { - name: "value" - t { - dims: 1 + dims: 3 + dims: 4 data_type: 1 - raw_data: "\000\000\000@" + raw_data: "\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@" } type: TENSOR } } name: "torch-jit-export" - input { - name: "0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 3 - } - dim { - dim_value: 4 - } - } - } - } - } output { - name: "10" + name: "1" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_full_like.expect b/test/onnx/expect/TestOperators.test_full_like.expect index 87d139dab13287..f2c1cd9ac59c32 100644 --- a/test/onnx/expect/TestOperators.test_full_like.expect +++ b/test/onnx/expect/TestOperators.test_full_like.expect @@ -3,45 +3,23 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - input: "1" - output: "2" - name: "ConstantOfShape_1" - op_type: "ConstantOfShape" + name: "Constant_0" + op_type: "Constant" attribute { name: "value" t { - dims: 1 + dims: 3 + dims: 4 data_type: 1 - raw_data: "\000\000\000@" + raw_data: "\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@" } type: TENSOR } } name: "torch-jit-export" - input { - name: "0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 3 - } - dim { - dim_value: 4 - } - } - } - } - } output { - name: "2" + name: "1" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_ones_like.expect b/test/onnx/expect/TestOperators.test_ones_like.expect index ad529a21feec0a..116576c08e34c1 100644 --- a/test/onnx/expect/TestOperators.test_ones_like.expect +++ b/test/onnx/expect/TestOperators.test_ones_like.expect @@ -3,45 +3,23 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - input: "1" - output: "2" - name: "ConstantOfShape_1" - op_type: "ConstantOfShape" + name: "Constant_0" + op_type: "Constant" attribute { name: "value" t { - dims: 1 + dims: 6 + dims: 10 data_type: 1 - raw_data: "\000\000\200?" + raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" } type: TENSOR } } name: "torch-jit-export" - input { - name: "0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 6 - } - dim { - dim_value: 10 - } - } - } - } - } output { - name: "2" + name: "1" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_slice_dynamic.expect b/test/onnx/expect/TestOperators.test_slice_dynamic.expect index 696a1eef382ed6..3335a233d87550 100644 --- a/test/onnx/expect/TestOperators.test_slice_dynamic.expect +++ b/test/onnx/expect/TestOperators.test_slice_dynamic.expect @@ -3,18 +3,13 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - output: "2" - name: "Constant_1" + name: "Constant_0" op_type: "Constant" attribute { name: "value" t { + dims: 1 data_type: 7 raw_data: "\000\000\000\000\000\000\000\000" } @@ -22,87 +17,64 @@ graph { } } node { - input: "1" - input: "2" - output: "3" - name: "Gather_2" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "0" - output: "4" - name: "Shape_3" - op_type: "Shape" - } - node { - output: "5" - name: "Constant_4" + output: "2" + name: "Constant_1" op_type: "Constant" attribute { name: "value" t { + dims: 1 data_type: 7 - raw_data: "\001\000\000\000\000\000\000\000" + raw_data: "\003\000\000\000\000\000\000\000" } type: TENSOR } } node { - input: "4" - input: "5" - output: "6" - name: "Gather_5" - op_type: "Gather" + output: "3" + name: "Constant_2" + op_type: "Constant" attribute { - name: "axis" - i: 0 - type: INT + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\377\377\377\377\377\377\377\177" + } + type: TENSOR } } node { - output: "7" - name: "Constant_6" + output: "4" + name: "Constant_3" op_type: "Constant" attribute { name: "value" t { + dims: 1 data_type: 7 - raw_data: "\003\000\000\000\000\000\000\000" + raw_data: "\001\000\000\000\000\000\000\000" } type: TENSOR } } node { - input: "6" - input: "7" - output: "8" - name: "Sub_7" - op_type: "Sub" - } - node { + input: "0" + input: "2" input: "3" - output: "11" - name: "Unsqueeze_8" - op_type: "Unsqueeze" - attribute { - name: "axes" - ints: 0 - type: INTS - } + input: "1" + input: "4" + output: "5" + name: "Slice_4" + op_type: "Slice" } node { - output: "14" - name: "Constant_9" + output: "6" + name: "Constant_5" op_type: "Constant" attribute { name: "value" t { - dims: 1 data_type: 7 raw_data: "\001\000\000\000\000\000\000\000" } @@ -110,20 +82,10 @@ graph { } } node { - input: "0" - input: "11" - input: "17" - input: "18" - input: "14" - output: "15" - name: "Slice_10" - op_type: "Slice" - } - node { - input: "15" - input: "8" - output: "16" - name: "Gather_11" + input: "5" + input: "6" + output: "7" + name: "Gather_6" op_type: "Gather" attribute { name: "axis" @@ -132,18 +94,6 @@ graph { } } name: "torch-jit-export" - initializer { - dims: 1 - data_type: 7 - name: "17" - raw_data: "\377\377\377\377\377\377\377\177" - } - initializer { - dims: 1 - data_type: 7 - name: "18" - raw_data: "\000\000\000\000\000\000\000\000" - } input { name: "0" type { @@ -161,7 +111,7 @@ graph { } } output { - name: "16" + name: "7" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_view_flatten.expect b/test/onnx/expect/TestOperators.test_view_flatten.expect index 5ae9c0576c7a93..39b07c22095940 100644 --- a/test/onnx/expect/TestOperators.test_view_flatten.expect +++ b/test/onnx/expect/TestOperators.test_view_flatten.expect @@ -4,150 +4,18 @@ producer_version: "CURRENT_VERSION" graph { node { input: "0" - output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - output: "2" - name: "Constant_1" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\000\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "1" - input: "2" - output: "3" - name: "Gather_2" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "0" - output: "4" - name: "Shape_3" - op_type: "Shape" - } - node { - output: "5" - name: "Constant_4" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\000\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "4" - input: "5" - output: "6" - name: "Gather_5" - op_type: "Gather" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - output: "7" - name: "Constant_6" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\030\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "7" - input: "6" - output: "8" - name: "Div_7" - op_type: "Div" - } - node { - input: "8" - output: "9" - name: "Cast_8" - op_type: "Cast" - attribute { - name: "to" - i: 7 - type: INT - } - } - node { input: "9" - output: "10" - name: "Cast_9" - op_type: "Cast" - attribute { - name: "to" - i: 7 - type: INT - } - } - node { - input: "3" - output: "11" - name: "Unsqueeze_10" - op_type: "Unsqueeze" - attribute { - name: "axes" - ints: 0 - type: INTS - } - } - node { - input: "10" - output: "12" - name: "Unsqueeze_11" - op_type: "Unsqueeze" - attribute { - name: "axes" - ints: 0 - type: INTS - } - } - node { - input: "11" - input: "12" - output: "13" - name: "Concat_12" - op_type: "Concat" - attribute { - name: "axis" - i: 0 - type: INT - } - } - node { - input: "0" - input: "13" - output: "14" - name: "Reshape_13" + output: "6" + name: "Reshape_0" op_type: "Reshape" } name: "torch-jit-export" + initializer { + dims: 2 + data_type: 7 + name: "9" + raw_data: "\001\000\000\000\000\000\000\000\030\000\000\000\000\000\000\000" + } input { name: "0" type { @@ -171,7 +39,7 @@ graph { } } output { - name: "14" + name: "6" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_zeros_like.expect b/test/onnx/expect/TestOperators.test_zeros_like.expect index e2560305767a26..b92cf4a92282d6 100644 --- a/test/onnx/expect/TestOperators.test_zeros_like.expect +++ b/test/onnx/expect/TestOperators.test_zeros_like.expect @@ -3,45 +3,23 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Shape_0" - op_type: "Shape" - } - node { - input: "1" - output: "2" - name: "ConstantOfShape_1" - op_type: "ConstantOfShape" + name: "Constant_0" + op_type: "Constant" attribute { name: "value" t { - dims: 1 + dims: 5 + dims: 8 data_type: 1 - raw_data: "\000\000\000\000" + raw_data: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000" } type: TENSOR } } name: "torch-jit-export" - input { - name: "0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 5 - } - dim { - dim_value: 8 - } - } - } - } - } output { - name: "2" + name: "1" type { tensor_type { elem_type: 1 diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 3523ca997a67b5..4a12d92141bdca 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -40,13 +40,16 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field) -def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None): +def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None, + input_names=None, dynamic_axes=None): for opset_version in opset_versions: f = io.BytesIO() torch.onnx.export(module, x, f, opset_version=opset_version, training=training, - example_outputs=example_outputs) + example_outputs=example_outputs, + input_names=input_names, + dynamic_axes=dynamic_axes) model = onnx.load(io.BytesIO(f.getvalue())) check_onnx_opset_operator(model, ops[opset_version], opset_version) @@ -186,6 +189,9 @@ class DynamicSliceModel(torch.jit.ScriptModule): def forward(self, x): return x[1:x.size(0)] + module = DynamicSliceModel() + x = torch.rand(1, 2) + example_output = module(x) ops_10 = [{"op_name" : "Shape"}, {"op_name" : "Constant"}, {"op_name" : "Gather", @@ -196,9 +202,16 @@ def forward(self, x): {"op_name" : "Slice", "attributes" : []}] ops = {10 : ops_10} - module = DynamicSliceModel() - x = torch.rand(1, 2) - example_output = module(x) + check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output, + input_names=['x'], dynamic_axes={"x": [0, 1]}) + + ops_10 = [{"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Slice", + "attributes" : []}] + ops = {10 : ops_10} check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output) def test_flip(self): @@ -309,6 +322,31 @@ def forward(self, x): "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] + ops = {9 : ops_9, 10 : ops_10} + x = torch.randn(1, 2, 3, 4, requires_grad=True) + check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10], + input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}) + + ops_9 = [{"op_name" : "Constant"}, + {"op_name" : "Shape"}, + {"op_name" : "Slice"}, + {"op_name" : "Cast"}, + {"op_name" : "Div"}, + {"op_name" : "Concat"}, + {"op_name" : "Upsample", + "attributes" : + [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] + ops_10 = [{"op_name" : "Constant"}, + {"op_name" : "Shape"}, + {"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Slice"}, + {"op_name" : "Cast"}, + {"op_name" : "Div"}, + {"op_name" : "Concat"}, + {"op_name" : "Resize"}] + ops = {9 : ops_9, 10 : ops_10} x = torch.randn(1, 2, 3, 4, requires_grad=True) check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10]) diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 5cdac7f6a79c2a..39fea51c829189 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -169,7 +169,9 @@ def run_debug_test(self, model, train, batch_size, state_dict=None, def run_actual_test(self, model, train, batch_size, state_dict=None, input=None, use_gpu=True, rtol=0.001, atol=1e-7, example_outputs=None, do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX): + operator_export_type=torch.onnx.OperatorExportTypes.ONNX, + input_names=None, dynamic_axes=None, + remained_onnx_input_idx=None): """ This is what the user facing version will look like """ @@ -193,12 +195,17 @@ def run_actual_test(self, model, train, batch_size, state_dict=None, do_constant_folding=do_constant_folding, opset_version=self.opset_version, keep_initializers_as_inputs=True, - operator_export_type=operator_export_type) + operator_export_type=operator_export_type, + input_names=input_names, + dynamic_axes=dynamic_axes, + remained_onnx_input_idx=remained_onnx_input_idx) def run_model_test(self, model, train, batch_size, state_dict=None, input=None, use_gpu=True, rtol=0.001, atol=1e-7, example_outputs=None, do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX): + operator_export_type=torch.onnx.OperatorExportTypes.ONNX, + input_names=None, dynamic_axes=None, + remained_onnx_input_idx=None): use_gpu_ = torch.cuda.is_available() and use_gpu # NOTE: do_constant_folding is turned on only when model has # parameters embedded (which are needed for constant folding), @@ -209,7 +216,10 @@ def run_model_test(self, model, train, batch_size, state_dict=None, use_gpu=use_gpu_, rtol=rtol, atol=atol, example_outputs=example_outputs, do_constant_folding=do_constant_folding, - operator_export_type=operator_export_type) + operator_export_type=operator_export_type, + input_names=input_names, + dynamic_axes=dynamic_axes, + remained_onnx_input_idx=remained_onnx_input_idx) else: self.run_debug_test(model, train, batch_size, state_dict, input, use_gpu=use_gpu_, example_outputs=example_outputs, @@ -459,11 +469,11 @@ def test_rnn_init_predict_split(self): do_constant_folding=False)[0]) prepared = c2.prepare(mp, device="CPU") if self.embed_params: - assert len(prepared.init_net.op) == 879 - assert len(prepared.predict_net.op) == 133 + assert len(prepared.init_net.op) == 950 + assert len(prepared.predict_net.op) == 101 else: - assert len(prepared.init_net.op) == 12 - assert len(prepared.predict_net.op) == 1000 + assert len(prepared.init_net.op) == 83 + assert len(prepared.predict_net.op) == 968 def test_alexnet(self): state_dict = model_zoo.load_url(model_urls["alexnet"], progress=False) @@ -1177,7 +1187,9 @@ def forward(self, x, y): x = torch.randn(1, 2, requires_grad=True) y = torch.randn(2, 4, requires_grad=True) - self.run_model_test(MyModel(), train=False, input=(x, y), batch_size=BATCH_SIZE, use_gpu=False) + self.run_model_test(MyModel(), train=False, input=(x, y), batch_size=BATCH_SIZE, use_gpu=False, + input_names=["x", "y"], dynamic_axes={"x": [0, 1], "y": [0, 1]}) + self.run_model_test(MyModel(), train=False, input=(x, y), batch_size=BATCH_SIZE, use_gpu=False, remained_onnx_input_idx=[0]) def test_mean(self): shape = (3, 4, 5) @@ -1228,7 +1240,7 @@ class MyModule(torch.nn.Module): def forward(self, x): return (torch.randn(1, 2, 3, 4) + x).shape self.run_model_test(MyModule(), train=False, input=(x), - batch_size=BATCH_SIZE, use_gpu=False) + batch_size=BATCH_SIZE, use_gpu=False, remained_onnx_input_idx=[]) def test_rand(self): x = torch.randn(1, 2, 3, 4) @@ -1237,7 +1249,7 @@ class MyModule(torch.nn.Module): def forward(self, x): return (torch.rand(1, 2, 3, 4) + x).shape self.run_model_test(MyModule(), train=False, input=(x), - batch_size=BATCH_SIZE, use_gpu=False) + batch_size=BATCH_SIZE, use_gpu=False, remained_onnx_input_idx=[]) def test_convtranspose(self): model = nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False, padding=1, output_padding=2) @@ -1418,7 +1430,9 @@ def forward(self, x): return x.zero_() x = torch.randn(2, 3, 4) - self.run_model_test(Zero_(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False) + self.run_model_test(Zero_(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False, + input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) + self.run_model_test(Zero_(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_inplace_fill(self): @@ -1427,7 +1441,9 @@ def forward(self, x): return x.fill_(3) x = torch.randn(2, 3, 4) - self.run_model_test(Fill_(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False) + self.run_model_test(Fill_(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False, + input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) + self.run_model_test(Fill_(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False, remained_onnx_input_idx=[]) # ConstantFill is a deprecated experimental op (used in opsets < 9). # Shape inference does not cover this op. @@ -1454,7 +1470,10 @@ def forward(self, x): return torch.zeros(x.size()) + torch.ones(x.size()) x = torch.randn(2, 3, 4) - self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False) + self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, + use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) + self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, + use_gpu=False, remained_onnx_input_idx=[]) def test_tensor_factories_script(self): class TensorFactory(torch.jit.ScriptModule): @@ -1464,7 +1483,11 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, - use_gpu=False, example_outputs=(torch.ones(x.size()),)) + use_gpu=False, example_outputs=(torch.ones(x.size()),), + input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) + self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, + use_gpu=False, example_outputs=(torch.ones(x.size()),), + remained_onnx_input_idx=[]) def test_tensor_like_factories_script(self): class TensorFactory(torch.jit.ScriptModule): @@ -1476,7 +1499,12 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, - use_gpu=False, example_outputs=(torch.ones(x.size()),)) + use_gpu=False, example_outputs=(torch.ones(x.size()),), + input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) + remained_onnx_input_idx = None if self.opset_version < 9 else [] + self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE, + use_gpu=False, example_outputs=(torch.ones(x.size()),), + remained_onnx_input_idx=remained_onnx_input_idx) def test_full(self): class FullModel(torch.nn.Module): @@ -2101,7 +2129,9 @@ def forward(self, input): return torch.arange(input.size(0)), torch.arange(input.size(-1)) x = torch.randn(5, 3, 2) - self.run_model_test(SizeModel(), train=False, input=(x,), batch_size=BATCH_SIZE) + self.run_model_test(SizeModel(), train=False, input=(x,), batch_size=BATCH_SIZE, + input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) + self.run_model_test(SizeModel(), train=False, input=(x,), batch_size=BATCH_SIZE, remained_onnx_input_idx=[]) def test_log2(self): class Log2Model(torch.nn.Module): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 94546c3033dee7..86f7d774b03c5c 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -82,12 +82,15 @@ def run_ort(ort_sess, input): return inline_flatten_list(ort_outs, []) -def ort_compare_with_pytorch(ort_outs, output, rtol, atol): +def ort_compare_with_pytorch(ort_outs, output, rtol, atol, check_dtypes=False): output, _ = torch.jit._flatten(output) outputs = [to_numpy(outp) for outp in output] # compare onnxruntime and PyTorch results assert len(outputs) == len(ort_outs), "number of outputs differ" + if check_dtypes: + assert all(l.dtype == r.dtype for l, r in zip(outputs, ort_outs)), \ + "dtypes of outputs differ" # compare onnxruntime and PyTorch results [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)] @@ -99,7 +102,7 @@ def run_model_test(self, model, batch_size=2, state_dict=None, dynamic_axes=None, test_with_inputs=None, input_names=None, output_names=None, fixed_batch_size=False, dict_check=True, - training=None): + training=None, check_dtypes=False, remained_onnx_input_idx=None): model.eval() if input is None: input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) @@ -133,8 +136,13 @@ def run_model_test(self, model, batch_size=2, state_dict=None, output_names=output_names, fixed_batch_size=fixed_batch_size, training=training, onnx_shape_inference=self.onnx_shape_inference) # compute onnxruntime output prediction + if remained_onnx_input_idx is not None: + input_onnx = [] + for idx in remained_onnx_input_idx: + input_onnx.append(input[idx]) + input = input_onnx ort_outs = run_ort(ort_sess, input) - ort_compare_with_pytorch(ort_outs, output, rtol, atol) + ort_compare_with_pytorch(ort_outs, output, rtol, atol, check_dtypes) # if additional test inputs are provided run the onnx @@ -147,8 +155,13 @@ def run_model_test(self, model, batch_size=2, state_dict=None, output = model(*test_input_copy) if isinstance(output, torch.Tensor): output = (output,) + if remained_onnx_input_idx is not None: + test_input_onnx = [] + for idx in remained_onnx_input_idx: + test_input_onnx.append(input[idx]) + test_input = test_input_onnx ort_outs = run_ort(ort_sess, test_input) - ort_compare_with_pytorch(ort_outs, output, rtol, atol) + ort_compare_with_pytorch(ort_outs, output, rtol, atol, check_dtypes) def _init_test_generalized_rcnn_transform(): min_size = 100 @@ -238,23 +251,35 @@ def setUp(self): os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0" self.is_script_test_enabled = True + # The exported ONNX model may have less inputs than the pytorch model because of const folding. + # This mostly happens in unit test, where we widely use torch.size or torch.shape. + # So the output is only dependent on the input shape, not value. + # remained_onnx_input_idx is used to indicate which pytorch model input idx is remained in ONNX model. def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None, input_names=None, output_names=None, fixed_batch_size=False, dict_check=True, - training=None): - def _run_test(m): + training=None, check_dtypes=False, remained_onnx_input_idx=None): + def _run_test(m, remained_onnx_input_idx): return run_model_test(self, m, batch_size=batch_size, input=input, use_gpu=use_gpu, rtol=rtol, atol=atol, do_constant_folding=do_constant_folding, dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs, input_names=input_names, output_names=output_names, fixed_batch_size=fixed_batch_size, dict_check=dict_check, - training=training) + training=training, check_dtypes=check_dtypes, remained_onnx_input_idx=remained_onnx_input_idx) + + if isinstance(remained_onnx_input_idx, dict): + scripting_remained_onnx_input_idx = remained_onnx_input_idx['scripting'] + tracing_remained_onnx_input_idx = remained_onnx_input_idx['tracing'] + else: + scripting_remained_onnx_input_idx = remained_onnx_input_idx + tracing_remained_onnx_input_idx = remained_onnx_input_idx + if self.is_script_test_enabled: script_model = torch.jit.script(model) - _run_test(script_model) + _run_test(script_model, scripting_remained_onnx_input_idx) - _run_test(model) + _run_test(model, tracing_remained_onnx_input_idx) def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7, example_outputs=None, do_constant_folding=True, @@ -695,7 +720,8 @@ def forward(self, x): return x.new_zeros(x.data.size()) x = torch.randn(3, 4) - self.run_test(Data(), x) + self.run_test(Data(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(Data(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # Need type inference @@ -992,7 +1018,8 @@ def forward(self, x): return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64) x = torch.randn(3, 4) - self.run_test(MyModel(), x) + self.run_test(MyModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(MyModel(), x, remained_onnx_input_idx=[]) def test_scalar_tensor(self): class test(torch.nn.Module): @@ -1014,7 +1041,8 @@ def forward(self, input): return torch.tensor(input.shape[1]) x = torch.randn(3, 4) - self.run_test(ScalarInputModel(), x) + self.run_test(ScalarInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(ScalarInputModel(), x, remained_onnx_input_idx=[]) class TensorInputModel(torch.jit.ScriptModule): @torch.jit.script_method @@ -1022,7 +1050,8 @@ def forward(self, input): return torch.tensor([input.shape[0], input.shape[1]]) x = torch.randn(3, 4) - self.run_test(TensorInputModel(), x) + self.run_test(TensorInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(TensorInputModel(), x, remained_onnx_input_idx=[]) class FloatInputModel(torch.jit.ScriptModule): @torch.jit.script_method @@ -1038,7 +1067,8 @@ def forward(self, input): return torch.tensor(input.shape[1], dtype=torch.long) x = torch.randn(3, 4) - self.run_test(InputWithDtypeModel(), x) + self.run_test(InputWithDtypeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(InputWithDtypeModel(), x, remained_onnx_input_idx=[]) class MixedInputModel(torch.jit.ScriptModule): @torch.jit.script_method @@ -1438,7 +1468,8 @@ def forward(self, x): return x.new_zeros(x.shape) x = torch.randn(2, 3, 4) - self.run_test(FloatingPoint(), x) + self.run_test(FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[]) class FloatingPoint(torch.jit.ScriptModule): @torch.jit.script_method @@ -1468,7 +1499,8 @@ def forward(self, x): return x x = torch.randn(2, 3, 4) - self.run_test(FloatingPoint(), x) + self.run_test(FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[]) class FloatingPoint(torch.jit.ScriptModule): @torch.jit.script_method @@ -1541,7 +1573,7 @@ def forward(self, x): return x.shape[0] x = torch.randn(2, 3, 4) - self.run_test(ArithmeticModule(), x) + self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[]) def test_arithmetic_prim_float(self): class ArithmeticModule(torch.nn.Module): @@ -1563,7 +1595,7 @@ def forward(self, x): return x.shape[1] / 2 x = torch.randn(2, 3, 4) - self.run_test(ArithmeticModule(), x) + self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[]) def test_arithmetic_prim_bool(self): class ArithmeticModule(torch.nn.Module): @@ -1648,7 +1680,8 @@ def forward(self, x): return x.new_zeros(x.size(2) // x.size(1)) x = torch.randn(2, 3, 4) - self.run_test(FloordivModule(), (x,)) + self.run_test(FloordivModule(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(FloordivModule(), (x,), remained_onnx_input_idx=[]) def test_div(self): class DivModule(torch.nn.Module): @@ -1832,7 +1865,8 @@ def forward(self, x): return x.new_zeros(x.shape[1:x.size(2)]) x = torch.rand(1, 2, 3, 4) - self.run_test(DynamicSliceModel(), x) + self.run_test(DynamicSliceModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}) + self.run_test(DynamicSliceModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(10) @disableScriptTest() # scripting tuple/list append @@ -1897,7 +1931,9 @@ def forward(self, start, end): x = torch.randn(2, 3, 4) y = torch.tensor(8) - self.run_test(ArangeStartOutModel(), (x, y)) + self.run_test(ArangeStartOutModel(), (x, y), + input_names=["x", "y"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1]) @skipIfUnsupportedMinOpsetVersion(11) def test_arange(self): @@ -1907,7 +1943,9 @@ def forward(self, start, end): x = torch.randn(2, 3, 4) y = torch.tensor(8.5, dtype=torch.float) - self.run_test(ArangeModel(), (x, y)) + self.run_test(ArangeModel(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2]}) + self.run_test(ArangeModel(), (x, y), remained_onnx_input_idx=[1]) @skipIfUnsupportedMinOpsetVersion(11) def test_arange_out(self): @@ -1928,7 +1966,9 @@ def forward(self, start, end): x = torch.randn(2, 3, 4) y = torch.tensor(8.5, dtype=torch.float) - self.run_test(ArangeStartOutModel(), (x, y)) + self.run_test(ArangeStartOutModel(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2]}) + self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1]) @skipIfUnsupportedMinOpsetVersion(11) def test_arange_no_type(self): @@ -1947,7 +1987,8 @@ def forward(self, input): return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape) x = torch.randn(5, 3, 2) - self.run_test(SizeModel(), x) + self.run_test(SizeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(SizeModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() # x.stride() not scriptable @@ -2507,7 +2548,8 @@ def forward(self, x, y): x = torch.randn(1, 2, 4, 4, requires_grad=True) y = torch.randn(16, 16, requires_grad=True) - self.run_test(MyModel(), (x, y)) + self.run_test(MyModel(), (x, y), input_names=["x", "y"], dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1]}) + self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0]) # scripting will throw the OnnxRuntimeError @disableScriptTest() @@ -2555,7 +2597,8 @@ def forward(self, x): return x.new_zeros((a, b)) x = torch.randn(2, 3) - self.run_test(ListUnpack(), x) + self.run_test(ListUnpack(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(ListUnpack(), x, remained_onnx_input_idx=[]) class ListUnpackSlice(torch.jit.ScriptModule): @torch.jit.script_method @@ -2564,7 +2607,8 @@ def forward(self, x): return x.new_zeros((a, b)) x = torch.randn(2, 3, 4, 5) - self.run_test(ListUnpackSlice(), x) + self.run_test(ListUnpackSlice(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}) + self.run_test(ListUnpackSlice(), x, remained_onnx_input_idx=[]) def test_pow(self): class PowModule(torch.nn.Module): @@ -3341,15 +3385,17 @@ def forward(self, src, index): @skipIfUnsupportedMinOpsetVersion(9) def test_one_hot(self): class OneHot(torch.nn.Module): - def __init__(self, num_classes): + def __init__(self, num_classes, dtype): super().__init__() self.num_classes = num_classes + self.dtype = dtype def forward(self, x): - return torch.nn.functional.one_hot(x, self.num_classes) + return torch.nn.functional.one_hot(x, self.num_classes, self.dtype) x = torch.arange(10) - self.run_test(OneHot(15), (x)) + self.run_test(OneHot(15, torch.long), (x), check_dtypes=True) + self.run_test(OneHot(15, torch.uint8), (x), check_dtypes=True) @skipIfUnsupportedMinOpsetVersion(9) def test_gather(self): @@ -3828,13 +3874,15 @@ def forward(self, a): x = torch.randn(3, 4, requires_grad=True) outputs = ArangeScript()(x) - self.run_test(ArangeScript(), x) + self.run_test(ArangeScript(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(ArangeScript(), x, remained_onnx_input_idx=[]) class ArangeModel(torch.nn.Module): def forward(self, a): return torch.arange(a.size(0)) - self.run_test(ArangeModel(), x) + self.run_test(ArangeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(ArangeModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_arange_start_end(self): @@ -3907,7 +3955,9 @@ def forward(self, input): return torch._dim_arange(input, 1) x = torch.ones(5, 6) - self.run_test(DimArange(), x) + self.run_test(DimArange(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}) + remained_onnx_input_idx = None if self.opset_version < 11 else [] + self.run_test(DimArange(), x, remained_onnx_input_idx=remained_onnx_input_idx) def _test_compare_ops(self, model, num_inputs): x_float = torch.randn(1, 2, 3, 4, requires_grad=True) @@ -4213,7 +4263,9 @@ def forward(self, input, other): x = torch.randn(2, 3, 4) shape = torch.randn(6, 4) - self.run_test(ViewModel(), (x, shape)) + self.run_test(ViewModel(), (x, shape), + input_names=["x", "shape"], dynamic_axes={"x": [0, 1, 2], "shape": [0, 1]}) + self.run_test(ViewModel(), (x, shape), remained_onnx_input_idx=[0]) def test_view_dynamic_zero_dim(self): class ViewModel(torch.nn.Module): @@ -4423,7 +4475,7 @@ def forward(self, input): return torch.ones(len(input.shape)) x = torch.randn(4, 5) - self.run_test(LenListModel(), x) + self.run_test(LenListModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) def test_unbind_dynamic(self): @@ -4516,7 +4568,9 @@ def forward(self, x, y, t): x = torch.randn(2, 3) y = torch.randn(2, 4) t = torch.randn(2, 7) - self.run_test(SplitModule(), (x, y, t)) + self.run_test(SplitModule(), (x, y, t), input_names=["x", "y", "t"], + dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]}) + self.run_test(SplitModule(), (x, y, t), remained_onnx_input_idx=[2]) @skipIfUnsupportedMinOpsetVersion(11) def test_split_dynamic(self): @@ -4898,7 +4952,8 @@ def forward(self, x): return torch.zeros(x.size()) + torch.ones(x.size()) x = torch.randn(2, 3, 4) - self.run_test(TensorFactory(), x) + self.run_test(TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(TensorFactory(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_factories_script(self): @@ -4908,7 +4963,8 @@ def forward(self, x): return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float) x = torch.randn(2, 3, 4) - self.run_test(TensorFactory(), x) + self.run_test(TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(TensorFactory(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_like_factories_script(self): @@ -4920,7 +4976,8 @@ def forward(self, x): return zeros + ones x = torch.randn(2, 3, 4) - self.run_test(TensorFactory(), x) + self.run_test(TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(TensorFactory(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_eye(self): @@ -4942,7 +4999,8 @@ def forward(self, x): return x.zero_(), x x = torch.randn(2, 3, 4) - self.run_test(Zero_(), x) + self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(Zero_(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_zeros(self): @@ -4951,7 +5009,8 @@ def forward(self, x): return x.new_zeros(x.shape[1:2]), x.new_zeros(x.shape[2:], dtype=torch.long) x = torch.randn(2, 3, 4) - self.run_test(Zero_(), x) + self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(Zero_(), x, remained_onnx_input_idx=[]) @skipIfONNXShapeInference(True) @skipIfUnsupportedMinOpsetVersion(9) @@ -4976,7 +5035,9 @@ def forward(self, x, y): x = torch.randn(2, 3, 4, 5) y = torch.randn(1, 2, 3, 4) - self.run_test(Slice(), (x, y)) + self.run_test(Slice(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]}) + self.run_test(Slice(), (x, y), remained_onnx_input_idx=[]) class Size(torch.nn.Module): def forward(self, x, y): @@ -4984,7 +5045,9 @@ def forward(self, x, y): x = torch.randn(2, 3, 4) y = torch.randn(1, 2, 3) - self.run_test(Size(), (x, y)) + self.run_test(Size(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}) + self.run_test(Size(), (x, y), remained_onnx_input_idx=[]) class Array(torch.nn.Module): def forward(self, x, y): @@ -4994,7 +5057,9 @@ def forward(self, x, y): x = torch.randn(2, 3, 4) y = torch.randn(1, 2, 3) - self.run_test(Array(), (x, y)) + self.run_test(Array(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}) + self.run_test(Array(), (x, y), remained_onnx_input_idx=[]) class List(torch.nn.Module): def forward(self, x, y): @@ -5004,7 +5069,9 @@ def forward(self, x, y): x = torch.randn(2, 3, 4) y = torch.randn(1, 2, 3) - self.run_test(List(), (x, y)) + self.run_test(List(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]}) + self.run_test(List(), (x, y), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_empty(self): @@ -5013,7 +5080,8 @@ def forward(self, x): return x.new_empty(x.shape[0]).fill_(0), x.new_empty(x.shape[0], dtype=torch.long) * 0 x = torch.randn(2, 3, 4) - self.run_test(Emtpy(), x) + self.run_test(Emtpy(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(Emtpy(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_full(self): @@ -5022,7 +5090,8 @@ def forward(self, x): return x.new_full(x.shape[1:2], 5), x.new_full(x.shape[0:1], 1.3, dtype=torch.long) x = torch.randn(2, 3, 4) - self.run_test(Full(), x) + self.run_test(Full(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(Full(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_inplace_list(self): @@ -5033,7 +5102,9 @@ def forward(self, x, y): x = torch.randn(2, 3) y = torch.randn(2, 3) - self.run_test(Arithmetic(), (x, y)) + self.run_test(Arithmetic(), (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1], "y": [0, 1]}) + self.run_test(Arithmetic(), (x, y), remained_onnx_input_idx=[0]) @skipIfUnsupportedMinOpsetVersion(9) def test_inplace_fill(self): @@ -5042,7 +5113,8 @@ def forward(self, x): return x.fill_(3), x x = torch.randn(2, 3, 4) - self.run_test(Fill_(), x) + self.run_test(Fill_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(Fill_(), x, remained_onnx_input_idx=[]) def test_inplace_arithmetic(self): class Arithmetic(torch.jit.ScriptModule): @@ -5487,6 +5559,18 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(SiLUModel(), (x)) + def test_mish(self): + class MishModel(torch.nn.Module): + def __init__(self): + super(MishModel, self).__init__() + self.mish = torch.nn.Mish() + + def forward(self, x): + return self.mish(x) + + x = torch.randn(2, 3, 4) + self.run_test(MishModel(), (x)) + def test_remainder(self): class RemainderModel(torch.nn.Module): def forward(self, input, other): @@ -6450,6 +6534,35 @@ def forward(self, input, target): target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedMinOpsetVersion(12) + def test_nllloss_dynamic_ignore_index(self): + import torch.nn.functional as F + + def linear_combination(x, y, epsilon): + return epsilon * x + (1 - epsilon) * y + + def reduce_loss(loss, reduction='mean'): + return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss + + class LabelSmoothingCrossEntropy(torch.nn.Module): + def __init__(self, epsilon: float = 0.1, reduction='mean'): + super().__init__() + self.epsilon = epsilon + self.reduction = reduction + + def forward(self, preds, target, start_position): + n = preds.size()[-1] + log_preds = F.log_softmax(preds, dim=-1) + ignore_index = start_position.size(1) + nll = F.nll_loss(log_preds, target, reduction=self.reduction, ignore_index=ignore_index) + return nll + start_position.float() + + N = 5 + preds = torch.randn(N, 16) + target = torch.randint(5, (N,)) + start_position = torch.randint(10, (N, N)) + self.run_test(LabelSmoothingCrossEntropy(), (preds, target, start_position)) + @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_mean_ignore_index_weights(self): class NLLModel(torch.nn.Module): @@ -7863,7 +7976,8 @@ def forward(self, x): return self.module(x) x = torch.randn(3, 256) - self.run_test(Module(), (x, )) + self.run_test(Module(), (x, ), input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(Module(), (x, ), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_set_attr_modules_2(self): @@ -7898,7 +8012,8 @@ def forward(self, x): return self.module(x) x = torch.randn(3, 256) - self.run_test(Module(), (x, )) + self.run_test(Module(), (x, ), input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(Module(), (x, ), remained_onnx_input_idx=[]) def test_set_attr(self): class MyModule(torch.nn.Module): @@ -8097,7 +8212,10 @@ def forward(self, input_data, prev_state): model = Example(10) random_data = torch.rand((1, 5, 30, 30)) empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) - self.run_test(model, (random_data, empty_tensor)) + self.run_test(model, (random_data, empty_tensor), + input_names=["random_data", "empty_tensor"], + dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}) + self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_if_2(self): @@ -8139,6 +8257,7 @@ def forward(self, input_data, prev_state): input_names=["data", "state"], dynamic_axes={"state": [0, 1, 2, 3, 4]}, test_with_inputs=[(random_data, random_state)]) + self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_if_3(self): @@ -8172,7 +8291,10 @@ def forward(self, input_data, prev_state): model = Example(4) random_data = torch.rand((1, 5, 4, 4)) empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) - self.run_test(model, (random_data, empty_tensor)) + self.run_test(model, (random_data, empty_tensor), + input_names=["random_data", "empty_tensor"], + dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}) + self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_if_4(self): @@ -8206,7 +8328,10 @@ def forward(self, input_data, prev_state): model = Example(4) random_data = torch.rand((1, 5, 4, 4)) empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) - self.run_test(model, (random_data, empty_tensor)) + self.run_test(model, (random_data, empty_tensor), + input_names=["random_data", "empty_tensor"], + dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}) + self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) @@ -8241,7 +8366,10 @@ def forward(self, input_data, prev_state): model = Example(4) random_data = torch.rand((1, 5, 4, 4)) empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0) - self.run_test(model, (random_data, empty_tensor)) + self.run_test(model, (random_data, empty_tensor), + input_names=["random_data", "empty_tensor"], + dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}) + self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) def test_list_append_in_block(self): @@ -8359,7 +8487,9 @@ def forward(self, input_data): model = Example(10) random_data = torch.rand((1, 5, 30, 30)) - self.run_test(model, (random_data)) + self.run_test(model, (random_data), input_names=["random_data"], + dynamic_axes={"random_data": [0, 1, 2, 3]}) + self.run_test(model, (random_data), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) def test_input_mask_model(self): @@ -8428,7 +8558,9 @@ def forward(self, x, y): model.eval() x = torch.ones(2, 3, 4, 5) y = torch.ones(3, 4, 5, 2) - self.run_test(model, (x, y)) + self.run_test(model, (x, y), input_names=["x", "y"], + dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]}) + self.run_test(model, (x, y), remained_onnx_input_idx=[1]) class ViewModel(torch.nn.Module): def forward(self, x): @@ -8475,7 +8607,8 @@ def forward(self, signal): M, C, K, N = 1, 2, 3, 4 x = torch.randint(5, (M, C, K, N)) y = torch.randint(5, (M, C + 1, K + 1, N + 1)) - self.run_test(model, x) + self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}) + self.run_test(model, x, remained_onnx_input_idx=[]) self.run_test(model, x, input_names=["x"], dynamic_axes={"x" : [0, 1, 2, 3]}, test_with_inputs=[(x,), (y,)]) @@ -8524,13 +8657,15 @@ class ArangeModel(torch.nn.Module): def forward(self, start): return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64) x = torch.randn(2, 3, 4) - self.run_test(ArangeModel(), (x,)) + self.run_test(ArangeModel(), (x,), input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(ArangeModel(), (x,), remained_onnx_input_idx=[]) class ArangeModel2(torch.nn.Module): def forward(self, start): return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double) x = torch.randn(2, 3, 4) - self.run_test(ArangeModel2(), (x,)) + self.run_test(ArangeModel2(), (x,), input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(ArangeModel2(), (x,), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_symbolic_shape_inference_nonzero(self): @@ -8540,9 +8675,11 @@ def forward(self, x): return torch.nonzero(ones) x = torch.randn(2) - self.run_test(OneLikeModel(), x) + self.run_test(OneLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0]}) + self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[]) x = torch.randn(2, 3, 4) - self.run_test(OneLikeModel(), x) + self.run_test(OneLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[]) class ZeroLikeModel(torch.nn.Module): def forward(self, x): @@ -8550,9 +8687,11 @@ def forward(self, x): return torch.nonzero(zeros) x = torch.randn(2) - self.run_test(ZeroLikeModel(), x) + self.run_test(ZeroLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0]}) + self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[]) x = torch.randn(2, 3, 4) - self.run_test(ZeroLikeModel(), x) + self.run_test(ZeroLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_symbolic_shape_inference_expand_1(self): @@ -8573,7 +8712,8 @@ def forward(self, x): causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] return causal_mask.transpose(0, 1) x = torch.randn(3, 16) - self.run_test(M(), (x,)) + self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}) + self.run_test(M(), (x,), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(10) @disableScriptTest() # Test code not scriptable @@ -8586,7 +8726,9 @@ def forward(self, x, position_bias): return position_bias.transpose(0, 1) x = torch.randn(3, 16) position_bias = torch.randn(1, 3, 20, 8) - self.run_test(M(), (x, position_bias)) + self.run_test(M(), (x, position_bias), input_names=["x", "position_bias"], + dynamic_axes={"x": [0, 1], "position_bias": [0, 1, 2, 3]}) + self.run_test(M(), (x, position_bias), remained_onnx_input_idx=[1]) def test_symbolic_shape_inference_slice_2(self): class M(torch.nn.Module): diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index bd513c8e2cb9cb..71f52b306b8c46 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -3,7 +3,7 @@ import torch import torch.onnx from torch.onnx import utils, OperatorExportTypes, TrainingMode -from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type +from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type, _set_onnx_shape_inference import torch.utils.cpp_extension from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion import caffe2.python.onnx.backend as backend @@ -33,14 +33,23 @@ def _model_to_graph(self, model, input, do_constant_folding=True, example_outputs=None, training=TrainingMode.EVAL, - operator_export_type=OperatorExportTypes.ONNX): - - return utils._model_to_graph(model, input, - do_constant_folding=do_constant_folding, - _disable_torch_constant_prop=True, - operator_export_type=operator_export_type, - training=training, - example_outputs=example_outputs) + operator_export_type=OperatorExportTypes.ONNX, + input_names=None, + dynamic_axes=None): + + # Need disable onnx_shape_inference for this test because it puts const node to initializers. + _set_onnx_shape_inference(False) + utils._validate_dynamic_axes(dynamic_axes, model, None, None) + graph, params_dict, torch_out = utils._model_to_graph(model, input, + do_constant_folding=do_constant_folding, + _disable_torch_constant_prop=True, + operator_export_type=operator_export_type, + training=training, + example_outputs=example_outputs, + input_names=input_names, + dynamic_axes=dynamic_axes) + _set_onnx_shape_inference(True) + return graph, params_dict, torch_out def test_is_in_onnx_export(self): test_self = self @@ -83,7 +92,8 @@ def forward(self, x, y, t): x = torch.randn(2, 3) y = torch.randn(2, 4) t = torch.randn(2, 7) - graph, _, _ = utils._model_to_graph(SplitModule(), (x, y, t)) + graph, _, _ = self._model_to_graph(SplitModule(), (x, y, t), input_names=['x', 'y', 't'], + dynamic_axes={'x': [0, 1], 'y': [0, 1], 't': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::SplitToSequence" @@ -117,7 +127,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(3, 2) - graph, _, __ = self._model_to_graph(TransposeModule(), (x, )) + graph, _, __ = self._model_to_graph(TransposeModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Transpose" @@ -135,7 +146,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(2, 3) - graph, _, __ = self._model_to_graph(ReduceModule(), (x, )) + graph, _, __ = self._model_to_graph(ReduceModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::ReduceL2" @@ -151,7 +163,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(2, 3) - graph, _, __ = self._model_to_graph(NormModule(), (x, )) + graph, _, __ = self._model_to_graph(NormModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::ReduceL1" @@ -167,7 +180,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(1, 3) - graph, _, __ = self._model_to_graph(NarrowModule(), (x, )) + graph, _, __ = self._model_to_graph(NarrowModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Slice" @@ -185,7 +199,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(1, 3) - graph, _, __ = self._model_to_graph(SliceIndexExceedsDimModule(), (x, )) + graph, _, __ = self._model_to_graph(SliceIndexExceedsDimModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Slice" @@ -205,7 +220,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(1, 3) - graph, _, __ = self._model_to_graph(SliceNegativeIndexModule(), (x, )) + graph, _, __ = self._model_to_graph(SliceNegativeIndexModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Slice" @@ -225,7 +241,8 @@ def forward(self, x): x = torch.ones(1, 3) model = GatherModule() model(x) - graph, _, __ = self._model_to_graph(GatherModule(), (x, )) + graph, _, __ = self._model_to_graph(GatherModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Gather" @@ -240,7 +257,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(1, 2, 3) - graph, _, __ = self._model_to_graph(UnsqueezeModule(), (x, )) + graph, _, __ = self._model_to_graph(UnsqueezeModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1, 2]}) for node in graph.nodes(): assert node.kind() != "onnx::Unsqueeze" @@ -261,7 +279,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.randn(2, 3, 4, 5, 8, 7) - graph, _, __ = self._model_to_graph(PReluModel(), x) + graph, _, __ = self._model_to_graph(PReluModel(), x, input_names=['x'], + dynamic_axes={'x': [0, 1, 2, 3, 4, 5]}) for node in graph.nodes(): assert node.kind() != "onnx::Unsqueeze" @@ -278,8 +297,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(2, 3) - graph, _, __ = self._model_to_graph(SqueezeModule(), (x, )) - print(graph) + graph, _, __ = self._model_to_graph(SqueezeModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Squeeze" assert node.kind() != "onnx::Cast" @@ -295,7 +314,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(2, 3) - graph, _, __ = self._model_to_graph(SqueezeAxesModule(), (x, )) + graph, _, __ = self._model_to_graph(SqueezeAxesModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Squeeze" @@ -330,7 +350,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.ones(2, 3) - graph, _, __ = self._model_to_graph(ConcatModule(), (x, )) + graph, _, __ = self._model_to_graph(ConcatModule(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Concat" @@ -351,7 +372,8 @@ def forward(self, input, initial_state): _set_operator_export_type(OperatorExportTypes.ONNX) input = torch.randn(5, 3, 7) h0 = torch.randn(1, 3, 3) - graph, _, __ = self._model_to_graph(GruNet(), (input, h0)) + graph, _, __ = self._model_to_graph(GruNet(), (input, h0), input_names=['input', 'h0'], + dynamic_axes={'input': [0, 1, 2], 'h0': [0, 1, 2]}) for node in graph.nodes(): assert node.kind() != "onnx::Slice" @@ -376,7 +398,8 @@ def forward(self, A): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) A = torch.randn(2, 3) - graph, _, __ = self._model_to_graph(MatMulNet(), (A, )) + graph, _, __ = self._model_to_graph(MatMulNet(), (A, ), + input_names=['A'], dynamic_axes={'A': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Transpose" @@ -395,7 +418,8 @@ def forward(self, x): _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) x = torch.randn(4, 5) - graph, _, __ = self._model_to_graph(ReshapeModule(), (x, )) + graph, _, __ = self._model_to_graph(ReshapeModule(), (x, ), + input_names=['x'], dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Reshape" @@ -414,7 +438,8 @@ def forward(self, x): x = torch.randn(2, 5) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, _, __ = self._model_to_graph(Module(), (x, )) + graph, _, __ = self._model_to_graph(Module(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Div" @@ -433,7 +458,8 @@ def forward(self, x): x = torch.randn(2, 5) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, _, __ = self._model_to_graph(Module(), (x, )) + graph, _, __ = self._model_to_graph(Module(), (x, ), input_names=['x'], + dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Mul" @@ -452,9 +478,10 @@ def forward(self, x): x = torch.randn(2, 5) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, params_dict, __ = utils._model_to_graph( + graph, params_dict, __ = self._model_to_graph( Module(), (x, ), do_constant_folding=True, - operator_export_type=OperatorExportTypes.ONNX) + operator_export_type=OperatorExportTypes.ONNX, + input_names=['x'], dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): self.assertTrue(node.kind() != "onnx::Add") self.assertEqual(len(list(graph.nodes())), 1) @@ -477,9 +504,9 @@ def forward(self, x): x = torch.randn(2, 5) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, params_dict, __ = utils._model_to_graph( + graph, params_dict, __ = self._model_to_graph( Module(), (x, ), do_constant_folding=True, - operator_export_type=OperatorExportTypes.ONNX) + operator_export_type=OperatorExportTypes.ONNX, input_names=['x'], dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Sub" self.assertEqual(len(list(graph.nodes())), 1) @@ -502,7 +529,7 @@ def forward(self, x): x = torch.randn(2, 5) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, _, __ = self._model_to_graph(Module(), (x, )) + graph, _, __ = self._model_to_graph(Module(), (x, ), input_names=['x'], dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Sqrt" assert len(list(graph.nodes())) == 1 @@ -520,7 +547,7 @@ def forward(self, x): x = torch.randn(2, 5) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, _, __ = self._model_to_graph(ShapeModule(), (x, )) + graph, _, __ = self._model_to_graph(ShapeModule(), (x, ), input_names=['x'], dynamic_axes={'x': [0, 1]}) for node in graph.nodes(): assert node.kind() != "onnx::Shape" @@ -613,7 +640,8 @@ def forward(self, x): x = torch.randn(2, 3, 4) _set_opset_version(self.opset_version) graph, _, __ = self._model_to_graph(Module(), (x, ), - operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH) + operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, + input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) iter = graph.nodes() assert next(iter).kind() == "onnx::Constant" assert next(iter).kind() == "aten::triu" @@ -647,7 +675,9 @@ def forward(self, input, other): y = torch.randn(2, 3, 4, requires_grad=False) model = FooModel() graph, _, __ = self._model_to_graph(model, (x, y), - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH) + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, + input_names=['x', 'y'], + dynamic_axes={'x': [0, 1, 2], 'y': [0, 1, 2]}) iter = graph.nodes() assert next(iter).kind() == "custom_namespace::custom_op" @@ -658,7 +688,9 @@ def test_onnx_fallthrough(self): model = torch.nn.CosineSimilarity(dim=1, eps=1e-6) graph, _, __ = self._model_to_graph(model, (x, y), - operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH) + operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, + input_names=['x', 'y'], + dynamic_axes={'x': [0, 1], 'y': [0, 1]}) iter = graph.nodes() assert next(iter).kind() == "onnx::Constant" assert next(iter).kind() == "onnx::Constant" @@ -687,7 +719,9 @@ def forward(self, x): output = q_model(*pt_inputs) graph, _, __ = self._model_to_graph(q_model, pt_inputs, example_outputs=output, - operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH) + operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, + input_names=['pt_inputs'], + dynamic_axes={'pt_inputs': [0, 1, 2, 3]}) iter = graph.nodes() assert next(iter).kind() == "onnx::Constant" @@ -714,7 +748,8 @@ def forward(self, x): output = model(x) model.eval() graph, _, __ = self._model_to_graph(model, (x,), example_outputs=output, - operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH) + operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, + input_names=['x'], dynamic_axes={'x': [0]}) iter = graph.nodes() assert next(iter).kind() == "prim::ListConstruct" @@ -735,7 +770,8 @@ def forward(self, input): model = Custom() batch = torch.FloatTensor(1, 3) - graph, _, _ = utils._model_to_graph(model, batch) + graph, _, _ = self._model_to_graph(model, batch, + input_names=['batch'], dynamic_axes={'batch': [0, 1]}) iter = graph.nodes() assert next(iter).kind() == "CustomNamespace::Custom" @@ -753,8 +789,10 @@ def forward(self, x): x = torch.randn(20, 16, 50, 100) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - _, params_dict, __ = utils._model_to_graph(Model(), (x, ), do_constant_folding=False, - operator_export_type=OperatorExportTypes.ONNX) + _, params_dict, __ = self._model_to_graph(Model(), (x, ), do_constant_folding=False, + operator_export_type=OperatorExportTypes.ONNX, + input_names=['x'], + dynamic_axes={'x': [0, 1, 2, 3]}) assert len(params_dict) == 2 @@ -776,8 +814,9 @@ def forward(self, x): f = io.BytesIO() _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs, - operator_export_type=OperatorExportTypes.ONNX) + graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs, + operator_export_type=OperatorExportTypes.ONNX, + input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]}) graph_input_params = [param.debugName() for param in graph.inputs()] assert all(item in graph_input_params for item in dict(model.named_parameters())), \ @@ -810,7 +849,8 @@ def forward(self, x): x = torch.randn(2, 3, 2, 2, requires_grad=True) graph, _, __ = self._model_to_graph(Fuse(), (x, ), - training=TrainingMode.EVAL) + training=TrainingMode.EVAL, input_names=['x'], + dynamic_axes={'x': [0, 1, 2, 3]}) for node in graph.nodes(): assert node.kind() != "onnx::BatchNormalization" assert node.kind() == "onnx::Conv" @@ -820,7 +860,8 @@ def forward(self, x): def test_fuse_resnet18(self): model = torchvision.models.resnet18(pretrained=True) x = torch.randn(2, 3, 224, 224, requires_grad=True) - graph, _, __ = self._model_to_graph(model, (x, )) + graph, _, __ = self._model_to_graph(model, (x, ), + input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]}) for node in graph.nodes(): assert node.kind() != "onnx::BatchNormalization" @@ -844,8 +885,10 @@ def forward(self, x, y): input_2 = torch.tensor(12) _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) - graph, _, __ = utils._model_to_graph(MyModule(), (input_1, input_2), do_constant_folding=True, - operator_export_type=OperatorExportTypes.ONNX) + graph, _, __ = self._model_to_graph(MyModule(), (input_1, input_2), do_constant_folding=True, + operator_export_type=OperatorExportTypes.ONNX, + input_names=['input_1', 'input_2'], + dynamic_axes={'input_1': [0], 'input_2': [0]}) # Check that the prim::Constant node in the graph for representing the # scripted function `f` is removed and the following prim::CallFunction # is replced by inline graph, with onnx::Sub and onnx::Add nodes. diff --git a/test/onnx/verify.py b/test/onnx/verify.py index 652cb50b14ecaf..9c972253d1c3f3 100644 --- a/test/onnx/verify.py +++ b/test/onnx/verify.py @@ -228,7 +228,9 @@ def __exit__(self, exc_type, exc_value, traceback): def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode.EVAL, rtol=1e-3, atol=1e-7, test_args=2, do_constant_folding=True, example_outputs=None, opset_version=None, keep_initializers_as_inputs=True, add_node_names=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX): + operator_export_type=torch.onnx.OperatorExportTypes.ONNX, + input_names=None, dynamic_axes=None, + remained_onnx_input_idx=None): """ Export a model into ONNX, import it into a specified ONNX backend, and then on a few random inputs verify that PyTorch and the backend produced the same @@ -273,6 +275,9 @@ def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode operator_export_type (enum, default OperatorExportTypes.ONNX): the operator export type to use when exporting the model. The default value converts all operators to ONNX ops. + input_names (list of string): list of input names. + dynamic_axes (dict of (string, list)): dynamic_axes. + remained_onnx_input_idx (list of int, default None): The remained ONNX input index. """ def _nested_map(condition, fn, condition_msg=None): def _map(obj): @@ -353,13 +358,15 @@ def load_bytes(b): opset_version=opset_version, keep_initializers_as_inputs=keep_initializers_as_inputs, add_node_names=add_node_names, - operator_export_type=operator_export_type) + operator_export_type=operator_export_type, + input_names=input_names, + dynamic_axes=dynamic_axes) if isinstance(model, torch.jit.ScriptModule): torch_out = model(*args) proto = load_bytes(proto_bytes) prepared = backend.prepare(proto) - def run(args): + def run(args, remained_onnx_input_idx): alt_proto_bytes = io.BytesIO() torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose, do_constant_folding=do_constant_folding, @@ -367,7 +374,9 @@ def run(args): opset_version=opset_version, keep_initializers_as_inputs=keep_initializers_as_inputs, add_node_names=add_node_names, - operator_export_type=operator_export_type) + operator_export_type=operator_export_type, + input_names=input_names, + dynamic_axes=dynamic_axes) if isinstance(model, torch.jit.ScriptModule): torch_out = model(*args) alt_proto = load_bytes(alt_proto_bytes) @@ -434,11 +443,17 @@ def run(args): raise AssertionError() # TODO: test that the traced model also returns the same thing... - run_helper(torch_out, args) + run_helper(torch_out, args, remained_onnx_input_idx) # Factored out so we can avoid one run of the model - def run_helper(torch_out, args): - backend_out = prepared.run(backend_args(args)) + def run_helper(torch_out, args, remained_onnx_input_idx): + onnx_input = backend_args(args) + if remained_onnx_input_idx is not None: + input_onnx = [] + for idx in remained_onnx_input_idx: + input_onnx.append(onnx_input[idx]) + onnx_input = tuple(input_onnx) + backend_out = prepared.run(onnx_input) if isinstance(torch_out, torch.Tensor): torch_out = (torch_out,) torch_out, _ = torch._C._jit_flatten(torch_out) @@ -451,11 +466,11 @@ def run_helper(torch_out, args): for i, (x, y) in enumerate(zip(torch_out, backend_out)): errs.checkAlmostEqual(x.data.cpu().numpy(), y, "In output {}".format(i)) - run_helper(torch_out, args) + run_helper(torch_out, args, remained_onnx_input_idx) if isinstance(test_args, int): for i in range(test_args): - run(randomize_args(args)) + run(randomize_args(args), remained_onnx_input_idx) else: for test_arg in test_args: - run(test_arg) + run(test_arg, remained_onnx_input_idx) diff --git a/test/package/common.py b/test/package/common.py index 4e022be449f0a2..f522c37e178948 100644 --- a/test/package/common.py +++ b/test/package/common.py @@ -2,8 +2,8 @@ import sys from tempfile import NamedTemporaryFile -from torch.testing._internal.common_utils import IS_WINDOWS, TestCase import torch.package.package_exporter +from torch.testing._internal.common_utils import IS_WINDOWS, TestCase class PackageTestCase(TestCase): @@ -31,7 +31,6 @@ def setUp(self): sys.path.append(self.package_test_dir) torch.package.package_exporter._gate_torchscript_serialization = False - def tearDown(self): super().tearDown() sys.path = self.orig_sys_path diff --git a/test/package/package_a/fake_script_class.py b/test/package/package_a/fake_script_class.py index c171c3b684b775..d3278d3f9f5eaf 100644 --- a/test/package/package_a/fake_script_class.py +++ b/test/package/package_a/fake_script_class.py @@ -1,14 +1,17 @@ import torch + @torch.jit.script class MyScriptClass: """Intended to be scripted.""" + def __init__(self, x): self.foo = x def set_foo(self, x): self.foo = x + @torch.jit.script def uses_script_class(x): """Intended to be scripted.""" diff --git a/test/package/package_b/__init__.py b/test/package/package_b/__init__.py index b0dce250ef0111..3b6edeb41ae006 100644 --- a/test/package/package_b/__init__.py +++ b/test/package/package_b/__init__.py @@ -4,6 +4,7 @@ result = "package_b" + class PackageBObject: __slots__ = ["obj"] diff --git a/test/package/package_b/subpackage_0/subsubpackage_0/__init__.py b/test/package/package_b/subpackage_0/subsubpackage_0/__init__.py index 93252cf2300378..8189db3f6b0a88 100644 --- a/test/package/package_b/subpackage_0/subsubpackage_0/__init__.py +++ b/test/package/package_b/subpackage_0/subsubpackage_0/__init__.py @@ -2,5 +2,6 @@ result = "subsubpackage_0" + class PackageBSubsubpackage0Object_0: pass diff --git a/test/package/package_b/subpackage_1.py b/test/package/package_b/subpackage_1.py index d397753f288f3d..fb66acdfa00524 100644 --- a/test/package/package_b/subpackage_1.py +++ b/test/package/package_b/subpackage_1.py @@ -1,5 +1,6 @@ result = "subpackage_1" + class PackageBSubpackage1Object_0: __slots__ = ["obj"] diff --git a/test/package/package_c/test_module.py b/test/package/package_c/test_module.py index 73a6d1f8f4cabd..7516f7c3f5fd4e 100644 --- a/test/package/package_c/test_module.py +++ b/test/package/package_c/test_module.py @@ -1,4 +1,5 @@ import torch + try: from torchvision.models import resnet18 @@ -11,8 +12,10 @@ def forward(self, x): x = a_non_torch_leaf(x, x) return torch.relu(x + 3.0) + except ImportError: pass + def a_non_torch_leaf(a, b): return a + b diff --git a/test/package/test_analyze.py b/test/package/test_analyze.py index 80218ac00ad16f..59f0f9b8509337 100644 --- a/test/package/test_analyze.py +++ b/test/package/test_analyze.py @@ -1,6 +1,5 @@ import torch from torch.package import analyze - from torch.testing._internal.common_utils import run_tests try: @@ -23,5 +22,6 @@ def test_trace_dependencies(self): self.assertNotIn("yaml", used_modules) self.assertIn("test_trace_dep", used_modules) + if __name__ == "__main__": run_tests() diff --git a/test/package/test_importer.py b/test/package/test_importer.py index 95d92ae729803a..432392d431498a 100644 --- a/test/package/test_importer.py +++ b/test/package/test_importer.py @@ -106,7 +106,9 @@ class DummyClass: dummy_importer_foo = DummyImporter("foo") dummy_importer_bar = DummyImporter("bar") - dummy_importer_not_found = DummyImporter("__main__") # __main__ is used as a proxy for "not found" by CPython + dummy_importer_not_found = DummyImporter( + "__main__" + ) # __main__ is used as a proxy for "not found" by CPython foo_then_bar = OrderedImporter(dummy_importer_foo, dummy_importer_bar) self.assertEqual(foo_then_bar.whichmodule(DummyClass(), ""), "foo") @@ -114,7 +116,9 @@ class DummyClass: bar_then_foo = OrderedImporter(dummy_importer_bar, dummy_importer_foo) self.assertEqual(bar_then_foo.whichmodule(DummyClass(), ""), "bar") - notfound_then_foo = OrderedImporter(dummy_importer_not_found, dummy_importer_foo) + notfound_then_foo = OrderedImporter( + dummy_importer_not_found, dummy_importer_foo + ) self.assertEqual(notfound_then_foo.whichmodule(DummyClass(), ""), "foo") def test_package_importer_whichmodule_no_dunder_module(self): diff --git a/test/package/test_mangling.py b/test/package/test_mangling.py index abfa90e92ce5cb..ae8a7eb1562ea4 100644 --- a/test/package/test_mangling.py +++ b/test/package/test_mangling.py @@ -1,3 +1,5 @@ +from io import BytesIO + from torch.package import PackageExporter, PackageImporter from torch.package._mangling import ( PackageMangler, @@ -7,8 +9,6 @@ ) from torch.testing._internal.common_utils import run_tests -from io import BytesIO - try: from .common import PackageTestCase except ImportError: diff --git a/test/package/test_misc.py b/test/package/test_misc.py index 13ae46c97254a8..aa4cf1cb89edd6 100644 --- a/test/package/test_misc.py +++ b/test/package/test_misc.py @@ -75,7 +75,6 @@ def test_file_structure(self): he.save_pickle("obj", "obj.pkl", obj) he.save_text("main", "main", "my string") - buffer.seek(0) hi = PackageImporter(buffer) @@ -85,9 +84,7 @@ def test_file_structure(self): dedent("\n".join(str(file_structure).split("\n")[1:])), export_plain, ) - file_structure = hi.file_structure( - include=["**/subpackage.py", "**/*.pkl"] - ) + file_structure = hi.file_structure(include=["**/subpackage.py", "**/*.pkl"]) self.assertEqual( dedent("\n".join(str(file_structure).split("\n")[1:])), export_include, @@ -177,9 +174,7 @@ def test_dunder_package_present(self): buffer.seek(0) pi = PackageImporter(buffer) - mod = pi.import_module( - "package_a.subpackage" - ) + mod = pi.import_module("package_a.subpackage") self.assertTrue(hasattr(mod, "__torch_package__")) def test_dunder_package_works_from_package(self): @@ -198,9 +193,7 @@ def test_dunder_package_works_from_package(self): buffer.seek(0) pi = PackageImporter(buffer) - imported_mod = pi.import_module( - mod.__name__ - ) + imported_mod = pi.import_module(mod.__name__) self.assertTrue(imported_mod.is_from_package()) self.assertFalse(mod.is_from_package()) diff --git a/test/package/test_model.py b/test/package/test_model.py index 9af73507807533..104965584760c5 100644 --- a/test/package/test_model.py +++ b/test/package/test_model.py @@ -27,7 +27,10 @@ class ModelTest(PackageTestCase): """End-to-end tests packaging an entire model.""" - @skipIf(IS_FBCODE or IS_SANDCASTLE, "Tests that use temporary files are disabled in fbcode") + @skipIf( + IS_FBCODE or IS_SANDCASTLE, + "Tests that use temporary files are disabled in fbcode", + ) def test_resnet(self): resnet = resnet18() diff --git a/test/package/test_resources.py b/test/package/test_resources.py index 68768302eb5d18..204d5ed50227f1 100644 --- a/test/package/test_resources.py +++ b/test/package/test_resources.py @@ -120,6 +120,29 @@ def test_importer_access(self): self.assertEqual(m.t, "my string") self.assertEqual(m.b, "my string".encode("utf-8")) + def test_resource_access_by_path(self): + """ + Tests that packaged code can used importlib.resources.path. + """ + buffer = BytesIO() + with PackageExporter(buffer, verbose=False) as he: + he.save_binary("string_module", "my_string", "my string".encode("utf-8")) + src = dedent( + """\ + import importlib.resources + import string_module + + with importlib.resources.path(string_module, 'my_string') as path: + with open(path, mode='r', encoding='utf-8') as f: + s = f.read() + """ + ) + he.save_source_string("main", src, is_package=True) + buffer.seek(0) + hi = PackageImporter(buffer) + m = hi.import_module("main") + self.assertEqual(m.s, "my string") + if __name__ == "__main__": run_tests() diff --git a/test/package/test_save_load.py b/test/package/test_save_load.py index cda4bc94f90f45..729f35e8f8d124 100644 --- a/test/package/test_save_load.py +++ b/test/package/test_save_load.py @@ -20,7 +20,10 @@ class TestSaveLoad(PackageTestCase): """Core save_* and loading API tests.""" - @skipIf(IS_FBCODE or IS_SANDCASTLE, "Tests that use temporary files are disabled in fbcode") + @skipIf( + IS_FBCODE or IS_SANDCASTLE, + "Tests that use temporary files are disabled in fbcode", + ) def test_saving_string(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: @@ -39,7 +42,10 @@ def test_saving_string(self): my_mod = hi.import_module("my_mod") self.assertIs(my_mod.math, math) - @skipIf(IS_FBCODE or IS_SANDCASTLE, "Tests that use temporary files are disabled in fbcode") + @skipIf( + IS_FBCODE or IS_SANDCASTLE, + "Tests that use temporary files are disabled in fbcode", + ) def test_save_module(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: @@ -60,6 +66,7 @@ def test_dunder_imports(self): buffer = BytesIO() with PackageExporter(buffer, verbose=False) as he: import package_b + obj = package_b.PackageBObject he.intern("**") he.save_pickle("res", "obj.pkl", obj) @@ -103,7 +110,10 @@ def test_save_module_binary(self): self.assertEqual(package_a_i.result, "package_a") self.assertIsNot(package_a_i, package_a) - @skipIf(IS_FBCODE or IS_SANDCASTLE, "Tests that use temporary files are disabled in fbcode") + @skipIf( + IS_FBCODE or IS_SANDCASTLE, + "Tests that use temporary files are disabled in fbcode", + ) def test_pickle(self): import package_a.subpackage @@ -129,7 +139,10 @@ def test_pickle(self): package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject ) - @skipIf(IS_FBCODE or IS_SANDCASTLE, "Tests that use temporary files are disabled in fbcode") + @skipIf( + IS_FBCODE or IS_SANDCASTLE, + "Tests that use temporary files are disabled in fbcode", + ) def test_save_imported_module_fails(self): """ Directly saving/requiring an PackageImported module should raise a specific error message. @@ -151,7 +164,10 @@ def test_save_imported_module_fails(self): with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"): pe.save_module(loaded1.__module__) - @skipIf(IS_FBCODE or IS_SANDCASTLE, "Tests that use temporary files are disabled in fbcode") + @skipIf( + IS_FBCODE or IS_SANDCASTLE, + "Tests that use temporary files are disabled in fbcode", + ) def test_exporting_mismatched_code(self): """ If an object with the same qualified name is loaded from different diff --git a/test/package/test_torchscript.py b/test/package/test_torchscript.py index ba62d7ffb233c7..f59ef868b46582 100644 --- a/test/package/test_torchscript.py +++ b/test/package/test_torchscript.py @@ -1,10 +1,15 @@ -from unittest import skipIf from io import BytesIO from tempfile import TemporaryDirectory +from unittest import skipIf import torch from torch.package import PackageExporter, PackageImporter -from torch.testing._internal.common_utils import run_tests, IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS +from torch.testing._internal.common_utils import ( + run_tests, + IS_FBCODE, + IS_SANDCASTLE, + IS_WINDOWS, +) try: from torchvision.models import resnet18 diff --git a/test/package/test_trace_dep/__init__.py b/test/package/test_trace_dep/__init__.py index b375608ce8b4c3..d6fb2e8216770b 100644 --- a/test/package/test_trace_dep/__init__.py +++ b/test/package/test_trace_dep/__init__.py @@ -1,6 +1,7 @@ import torch import yaml + class SumMod(torch.nn.Module): def forward(self, inp): return torch.sum(inp) diff --git a/test/quantization/test_deprecated_jit_quant.py b/test/quantization/test_deprecated_jit_quant.py new file mode 100644 index 00000000000000..d98778f2d1dedf --- /dev/null +++ b/test/quantization/test_deprecated_jit_quant.py @@ -0,0 +1,355 @@ +import torch +from torch.testing._internal.common_quantization import ( + skipIfNoFBGEMM +) +from torch.testing._internal.common_utils import TemporaryFileName +from torch.testing._internal.common_utils import suppress_warnings +from torch.testing._internal.jit_utils import JitTestCase + +from typing import Tuple +import copy +import io + +# TODO: Move some tensor tests here like test_serialize_qtensor to test_quantize_tensor.py +class TestDeprecatedJitQuantized(JitTestCase): + @skipIfNoFBGEMM + def test_rnn_cell_quantized(self): + d_in, d_hid = 2, 2 + + for cell in [ + torch.nn.LSTMCell(d_in, d_hid).float(), + torch.nn.GRUCell(d_in, d_hid).float(), + torch.nn.RNNCell(d_in, d_hid).float(), + ]: + if isinstance(cell, torch.nn.LSTMCell): + num_chunks = 4 + elif isinstance(cell, torch.nn.GRUCell): + num_chunks = 3 + elif isinstance(cell, torch.nn.RNNCell): + num_chunks = 1 + + # Replace parameter values s.t. the range of values is exactly + # 255, thus we will have 0 quantization error in the quantized + # GEMM call. This i s for testing purposes. + # + # Note that the current implementation does not support + # accumulation values outside of the range representable by a + # 16 bit integer, instead resulting in a saturated value. We + # must take care that in our test we do not end up with a dot + # product that overflows the int16 range, e.g. + # (255*127+255*127) = 64770. So, we hardcode the test values + # here and ensure a mix of signedness. + vals = [[100, -155], + [100, -155], + [-155, 100], + [-155, 100], + [100, -155], + [-155, 100], + [-155, 100], + [100, -155]] + vals = vals[:d_hid * num_chunks] + cell.weight_ih = torch.nn.Parameter( + torch.tensor(vals, dtype=torch.float), + requires_grad=False) + cell.weight_hh = torch.nn.Parameter( + torch.tensor(vals, dtype=torch.float), + requires_grad=False) + + ref = copy.deepcopy(cell) + + cell = torch.jit.quantized.quantize_rnn_cell_modules(cell) + x = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float) + h0_vals = [[-155, 100], + [-155, 155], + [100, -155]] + hx = torch.tensor(h0_vals, dtype=torch.float) + if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): + cx = torch.tensor(h0_vals, dtype=torch.float) + hiddens = (hx, cx) + else: + hiddens = hx + + if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super(ScriptWrapper, self).__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x: torch.Tensor, + hiddens: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.cell(x, hiddens) + else: + + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super(ScriptWrapper, self).__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor: + return self.cell(x, hiddens) + + cell = ScriptWrapper(cell) + outs = cell(x, hiddens) + cell = self.getExportImportCopyWithPacking(cell) + + outs = cell(x, hiddens) + ref_outs = ref(x, hiddens) + + self.assertEqual(len(outs), len(ref_outs)) + for out, ref_out in zip(outs, ref_outs): + torch.testing.assert_allclose(out, ref_out) + + @skipIfNoFBGEMM + def test_rnn_quantized(self): + d_in, d_hid = 2, 2 + + for cell in [ + torch.nn.LSTM(d_in, d_hid).float(), + torch.nn.GRU(d_in, d_hid).float(), + ]: + + # Replace parameter values s.t. the range of values is exactly + # 255, thus we will have 0 quantization error in the quantized + # GEMM call. This i s for testing purposes. + # + # Note that the current implementation does not support + # accumulation values outside of the range representable by a + # 16 bit integer, instead resulting in a saturated value. We + # must take care that in our test we do not end up with a dot + # product that overflows the int16 range, e.g. + # (255*127+255*127) = 64770. So, we hardcode the test values + # here and ensure a mix of signedness. + vals = [[100, -155], + [100, -155], + [-155, 100], + [-155, 100], + [100, -155], + [-155, 100], + [-155, 100], + [100, -155]] + if isinstance(cell, torch.nn.LSTM): + num_chunks = 4 + elif isinstance(cell, torch.nn.GRU): + num_chunks = 3 + vals = vals[:d_hid * num_chunks] + cell.weight_ih_l0 = torch.nn.Parameter( + torch.tensor(vals, dtype=torch.float), + requires_grad=False) + cell.weight_hh_l0 = torch.nn.Parameter( + torch.tensor(vals, dtype=torch.float), + requires_grad=False) + + ref = copy.deepcopy(cell) + cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8) + cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16) + + niter = 10 + x = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) + h0_vals = [[-155, 100], + [-155, 155], + [100, -155]] + hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) + cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) + + if isinstance(ref, torch.nn.LSTM): + hiddens = (hx, cx) + elif isinstance(ref, torch.nn.GRU): + hiddens = hx + + ref_out, ref_hid = ref(x, hiddens) + + # Compare int8 quantized to unquantized + output_int8, final_hiddens_int8 = cell_int8(x, hiddens) + + torch.testing.assert_allclose(output_int8, ref_out) + for out, ref in zip(final_hiddens_int8, ref_hid): + torch.testing.assert_allclose(out, ref) + + # Compare fp16 quantized to unquantized + output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) + + torch.testing.assert_allclose(output_fp16, ref_out) + for out, ref in zip(final_hiddens_fp16, ref_hid): + torch.testing.assert_allclose(out, ref) + + def compare_quantized_unquantized(ScriptWrapper, cell): + wrapper = ScriptWrapper(cell) + + # Compare quantize scripted module to unquantized + script_out, script_hid = wrapper(x, hiddens) + torch.testing.assert_allclose(script_out, ref_out) + for out, ref in zip(script_hid, ref_hid): + torch.testing.assert_allclose(out, ref) + + # Compare export/import to unquantized + export_import_wrapper = self.getExportImportCopyWithPacking(wrapper) + ei_out, ei_hid = export_import_wrapper(x, hiddens) + torch.testing.assert_allclose(ei_out, ref_out) + for out, ref in zip(ei_hid, ref_hid): + torch.testing.assert_allclose(out, ref) + + if isinstance(cell, torch.jit.quantized.QuantizedGRU): + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super(ScriptWrapper, self).__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.cell(x, hiddens) + + compare_quantized_unquantized(ScriptWrapper, cell) + elif isinstance(cell, torch.jit.quantized.QuantizedLSTM): + for cell in [cell_int8, cell_fp16]: + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super(ScriptWrapper, self).__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x, hiddens): + # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) + # -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + return self.cell(x, hiddens) + compare_quantized_unquantized(ScriptWrapper, cell) + + if 'fbgemm' in torch.backends.quantized.supported_engines: + # Suppression: using deprecated quant api + @suppress_warnings + def test_quantization_modules(self): + K1, N1 = 2, 2 + + class FooBar(torch.nn.Module): + def __init__(self): + super(FooBar, self).__init__() + self.linear1 = torch.nn.Linear(K1, N1).float() + + def forward(self, x): + x = self.linear1(x) + return x + + fb = FooBar() + fb.linear1.weight = torch.nn.Parameter( + torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False) + fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False) + + x = (torch.rand(1, K1).float() - 0.5) / 10.0 + value = torch.tensor([[100, -150]], dtype=torch.float) + + y_ref = fb(value) + + fb_int8 = torch.jit.quantized.quantize_linear_modules(fb) + traced_int8 = torch.jit.trace(fb_int8, (x,)) + fb_int8 = self.getExportImportCopyWithPacking(traced_int8) + y_int8 = fb_int8(value) + + fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16) + traced_fp16 = torch.jit.trace(fb_fp16, (x,)) + fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16) + y_fp16 = fb_fp16(value) + + torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3) + torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3) + + def _test_pickle_checkpoint_qtensor(self, device): + with TemporaryFileName() as fname: + class M(torch.jit.ScriptModule): + __constants__ = ['fname'] + + def __init__(self): + super(M, self).__init__() + self.fname = fname + + @torch.jit.script_method + def forward(self, x, y): + torch.save((x, y), self.fname) + return y + + q = torch.quantize_per_tensor( + torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device) + qc = torch.quantize_per_channel( + torch.rand(2, 3, dtype=torch.float), + scales=torch.tensor([0.1, 0.5, 0.01]), + zero_points=torch.tensor([10, 0, 20]), + axis=1, dtype=torch.quint8).to(device) + m = M() + m(q, qc) + with open(fname, "rb") as handle: + loaded_q, loaded_qc = torch.load(fname) + self.assertEqual(loaded_q, q) + self.assertEqual(loaded_qc, qc) + + def test_pickle_checkpoint_qtensor(self): + self._test_pickle_checkpoint_qtensor('cpu') + + def test_serialize_qtensor(self): + class SimpleQTensor(torch.jit.ScriptModule): + def __init__(self, per_channel): + super(SimpleQTensor, self).__init__() + x = torch.rand(5, 5).float() + if not per_channel: + x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8) + else: + s = torch.rand(5, dtype=torch.float64) + 0.1 + zp = torch.randint(5, 15, (5,)) + x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) + self.register_buffer('x', x_q) + + @torch.jit.script_method + def forward(self): + return self.x + + for per_channel in [False, True]: + model = SimpleQTensor(per_channel) + buffer = io.BytesIO() + torch.jit.save(model, buffer) + buffer.seek(0) + model_loaded = torch.jit.load(buffer) + self.assertEqual(model_loaded(), model()) + + @skipIfNoFBGEMM + def test_erase_class_tensor_shapes(self): + class Linear(torch.nn.Module): + def __init__(self, in_features, out_features): + super(Linear, self).__init__() + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, + dtype=torch.qint8) + self._packed_weight = torch.ops.quantized.linear_prepack(qweight) + + @torch.jit.export + def __getstate__(self): + return (torch.ops.quantized.linear_unpack(self._packed_weight)[0], self.training) + + def forward(self): + return self._packed_weight + + @torch.jit.export + def __setstate__(self, state): + self._packed_weight = torch.ops.quantized.linear_prepack(state[0]) + self.training = state[1] + + @property + def weight(self): + return torch.ops.quantized.linear_unpack(self._packed_weight)[0] + + @weight.setter + def weight(self, w): + self._packed_weight = torch.ops.quantized.linear_prepack(w) + + with torch._jit_internal._disable_emit_hooks(): + x = torch.jit.script(Linear(10, 10)) + torch._C._jit_pass_erase_shape_information(x.graph) + + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_quantization.py TESTNAME\n\n" + "instead.") diff --git a/test/quantization/test_equalize.py b/test/quantization/test_equalize.py index d0bd1334a46bb3..d2c67101bcac3c 100644 --- a/test/quantization/test_equalize.py +++ b/test/quantization/test_equalize.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.quantization.fuse_modules import fuse_modules import torch.quantization._equalize as _equalize @@ -93,3 +94,95 @@ def forward(self, x): input = torch.randn(20, 3) self.assertEqual(chain1(input), chain2(input)) + + def test_equalize_fused_convrelu(self): + ''' Checks to see if eager mode equalization supports fused + ConvReLU2d models + + A model with 3 ConvReLU2d is constructed. Next, the conv2d and relu + layers are fused together and adjacent conv2d layers have cross-layer + equalization applied. Finally, we ensure that the channels have been + equalized and that the equalized and unequalized versions of the model + yield the same output given the same input + ''' + class M(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 3, 1).to(dtype=torch.float) + self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float) + self.conv2 = nn.Conv2d(3, 3, 1).to(dtype=torch.float) + self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float) + self.conv3 = nn.Conv2d(3, 3, 1).to(dtype=torch.float) + self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.relu2(x) + x = self.conv3(x) + x = self.relu3(x) + return x + + model = M() + + fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']]) + fused_model2 = copy.deepcopy(fused_model1) + + _equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6) + conv1 = self.getModule(fused_model1, 'conv1')[0] + conv2 = self.getModule(fused_model1, 'conv2')[0] + conv3 = self.getModule(fused_model1, 'conv3')[0] + + self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1) + self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1) + + input = torch.randn(3, 3, 1, 1) + self.assertEqual(fused_model1(input), fused_model2(input)) + self.assertEqual(fused_model1(input), model(input)) + + def test_equalize_fused_linearrelu(self): + ''' Checks to see if eager mode equalization supports fused + LinearReLU models + + A model with 3 LinearReLU is constructed. Next, the linear and relu + layers are fused together and adjacent linear layers have cross-layer + equalization applied. Finally, we ensure that the channels have been + equalized and that the equalized and unequalized versions of the model + yield the same output given the same input + ''' + class M(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(3, 4) + self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float) + self.linear2 = nn.Linear(4, 5) + self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float) + self.linear3 = nn.Linear(5, 6) + self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float) + + def forward(self, x): + x = self.linear1(x) + x = self.relu1(x) + x = self.linear2(x) + x = self.relu2(x) + x = self.linear3(x) + x = self.relu3(x) + return x + + model = M() + + fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']]) + fused_model2 = copy.deepcopy(fused_model1) + + _equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6) + linear1 = self.getModule(fused_model1, 'linear1')[0] + linear2 = self.getModule(fused_model1, 'linear2')[0] + linear3 = self.getModule(fused_model1, 'linear3')[0] + + self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1) + self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1) + + input = torch.randn(20, 3) + self.assertEqual(fused_model1(input), fused_model2(input)) + self.assertEqual(fused_model1(input), model(input)) diff --git a/test/quantization/test_fusion.py b/test/quantization/test_fusion.py new file mode 100644 index 00000000000000..81f362356f8a45 --- /dev/null +++ b/test/quantization/test_fusion.py @@ -0,0 +1,365 @@ +import torch +import torch.nn as nn +import torch.nn.quantized as nnq +import torch.nn.intrinsic as nni +import torch.nn.intrinsic.quantized as nniq +import torch.nn.intrinsic.qat as nniqat +from torch.quantization import ( + quantize, + prepare, + convert, + prepare_qat, + quantize_qat, + fuse_modules, + QConfig, + default_qconfig, + default_qat_qconfig, +) + +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + ModelForFusion, + ModelWithSequentialFusion, + ModelForLinearBNFusion, + ModelForFusionWithBias, + test_only_eval_fn, + test_only_train_fn, + skipIfNoFBGEMM, +) + +from torch.testing._internal.common_quantized import ( + override_quantized_engine, + supported_qengines, +) + + +@skipIfNoFBGEMM +class TestFusion(QuantizationTestCase): + def test_fuse_module_train(self): + model = ModelForFusion(default_qat_qconfig).train() + # Test step by step fusion + model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) + model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) + self.assertEqual(type(model.conv1), nni.ConvBnReLU2d, + msg="Fused Conv + BN + Relu first layer") + self.assertEqual(type(model.bn1), torch.nn.Identity, + msg="Fused Conv + BN + Relu (skipped BN)") + self.assertEqual(type(model.relu1), torch.nn.Identity, + msg="Fused Conv + BN + Relu (skipped Relu)") + + self.assertEqual(type(model.sub1.conv), nni.ConvBn2d, + msg="Fused submodule Conv + BN") + self.assertEqual(type(model.sub1.bn), torch.nn.Identity, + msg="Fused submodule Conv + BN (skipped BN)") + self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d, + msg="Non-fused submodule Conv") + self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, + msg="Non-fused submodule ReLU") + model = prepare_qat(model) + self.checkObservers(model) + + def checkQAT(model): + self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) + self.assertEqual(type(model.bn1), nn.Identity) + self.assertEqual(type(model.relu1), nn.Identity) + self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d) + self.assertEqual(type(model.sub1.bn), nn.Identity) + self.assertEqual(type(model.sub2.conv), nn.Conv2d) + self.assertEqual(type(model.sub2.relu), nn.ReLU) + + checkQAT(model) + test_only_train_fn(model, self.img_data_1d_train) + model = convert(model) + + def checkQuantized(model): + self.assertEqual(type(model.conv1), nniq.ConvReLU2d) + self.assertEqual(type(model.bn1), nn.Identity) + self.assertEqual(type(model.relu1), nn.Identity) + self.assertEqual(type(model.sub1.conv), nnq.Conv2d) + self.assertEqual(type(model.sub1.bn), nn.Identity) + self.assertEqual(type(model.sub2.conv), nn.Conv2d) + self.assertEqual(type(model.sub2.relu), nn.ReLU) + test_only_eval_fn(model, self.img_data_1d) + self.checkNoQconfig(model) + + with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): + checkQuantized(model) + + model = ModelForFusion(default_qat_qconfig).train() + model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], + ['sub1.conv', 'sub1.bn']]) + model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train]) + with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): + checkQuantized(model) + + + def test_fuse_module_eval(self): + model = ModelForFusion(default_qconfig) + model.eval() + model = fuse_modules(model, [['conv3', 'bn3', 'relu4'], + ['conv1', 'bn1', 'relu1'], + ['conv2', 'relu2'], + ['bn2', 'relu3'], + ['sub1.conv', 'sub1.bn']]) + self.assertEqual(type(model.conv1), nni.ConvReLU2d, + msg="Fused Conv + BN + Relu first layer (BN is folded)") + self.assertEqual(type(model.conv1[0]), nn.Conv2d, + msg="Fused Conv + BN + Relu (Conv + folded BN only)") + self.assertEqual(type(model.conv1[1]), nn.ReLU, + msg="Fused Conv + BN + Relu second layer (Relu only)") + self.assertEqual(type(model.bn1), nn.Identity, + msg="Fused Conv + BN + Relu second layer (Skipped BN)") + self.assertEqual(type(model.relu1), nn.Identity, + msg="Fused Conv + BN + Relu second layer (Skipped Relu)") + self.assertEqual(type(model.conv2), nni.ConvReLU3d, + msg="Fused Conv + BN + Relu first layer (BN is folded)") + self.assertEqual(type(model.bn2), nni.BNReLU3d, + msg="Fused BN + Relu first layer (Relu is folded))") + self.assertEqual(type(model.relu3), nn.Identity, + msg="Fused BN + Relu second layer (Skipped Relu)") + self.assertEqual(type(model.conv2[0]), nn.Conv3d, + msg="Fused Conv + BN + Relu (Conv + folded BN only)") + self.assertEqual(type(model.conv2[1]), nn.ReLU, + msg="Fused Conv + BN + Relu second layer (Relu only)") + self.assertEqual(type(model.relu2), nn.Identity, + msg="Fused Conv + BN + Relu second layer (Skipped Relu)") + + self.assertEqual(type(model.conv3), nni.ConvReLU1d, + msg="Fused Conv + Relu for Conv1d (folded BN)") + self.assertEqual(type(model.conv3[0]), nn.Conv1d, + msg="Fused Conv + Relu for Conv1d ") + self.assertEqual(type(model.conv3[1]), nn.ReLU, + msg="Fused Conv + Relu for Conv1d") + self.assertEqual(type(model.bn3), nn.Identity, + msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)") + + self.assertEqual(type(model.sub1.conv), nn.Conv2d, + msg="Fused submodule Conv + folded BN") + self.assertEqual(type(model.sub1.bn), nn.Identity, + msg="Fused submodule (skipped BN)") + self.assertEqual(type(model.sub2.conv), nn.Conv2d, + msg="Non-fused submodule Conv") + self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, + msg="Non-fused submodule ReLU") + + model = prepare(model) + self.checkObservers(model) + test_only_eval_fn(model, self.img_data_1d) + model = convert(model) + + def checkQuantized(model): + self.assertEqual(type(model.conv3), nniq.ConvReLU1d) + self.assertEqual(type(model.conv1), nniq.ConvReLU2d) + self.assertEqual(type(model.bn1), nn.Identity) + self.assertEqual(type(model.relu1), nn.Identity) + self.assertEqual(type(model.sub1.conv), nnq.Conv2d) + self.assertEqual(type(model.sub1.bn), nn.Identity) + self.assertEqual(type(model.sub2.conv), nn.Conv2d) + self.assertEqual(type(model.sub2.relu), nn.ReLU) + self.assertEqual(type(model.bn2), nniq.BNReLU3d) + test_only_eval_fn(model, self.img_data_1d) + self.checkNoQconfig(model) + + checkQuantized(model) + + model = ModelForFusion(default_qconfig).eval() + model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], + ['conv2', 'relu2'], + ['bn2', 'relu3'], + ['sub1.conv', 'sub1.bn'], + ['conv3', 'bn3', 'relu4']]) + model = quantize(model, test_only_eval_fn, [self.img_data_1d]) + checkQuantized(model) + + def test_fusion_sequential_model_train(self): + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = ModelWithSequentialFusion().train() + model.to(torch.float) + fuse_modules(model, [['conv1', 'relu1'] , + ['features.0.0', 'features.0.1', 'features.0.2'], + ['features.1.0', 'features.1.1', 'features.1.2'], + ['features.2.0', 'features.2.1', 'features.2.2'], + ['classifier.0', 'classifier.1']], inplace=True) + self.assertEqual(type(model.conv1), nni.ConvReLU2d, + msg="Fused Conv + Relu: nni.ConvReLU2d") + self.assertEqual(type(model.conv1[0]), nn.Conv2d, + msg="Fused Conv + Relu: Conv2d") + self.assertEqual(type(model.conv1[1]), nn.ReLU, + msg="Fused Conv + Relu: Relu") + self.assertEqual(type(model.relu1), nn.Identity, + msg="Fused Conv + Relu: Identity") + for i in range(3): + self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d, + msg="Fused submodule Conv + folded BN") + self.assertEqual(type(model.features[i][1]), nn.Identity, + msg="Fused submodule (skipped BN)") + self.assertEqual(type(model.features[i][2]), nn.Identity, + msg="Non-fused submodule Conv") + self.assertEqual(type(model.classifier[0]), nni.LinearReLU) + self.assertEqual(type(model.classifier[1]), nn.Identity) + model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) + prepare_qat(model, inplace=True) + self.checkObservers(model) + model(self.img_data_2d[0][0]) + + + def checkQAT(model): + self.assertEqual(type(model.conv1), nniqat.ConvReLU2d) + self.assertEqual(type(model.relu1), nn.Identity) + for i in range(3): + self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d, + msg="Fused submodule Conv + folded BN") + self.assertEqual(type(model.features[i][1]), nn.Identity, + msg="Fused submodule (skipped BN)") + self.assertEqual(type(model.features[i][2]), nn.Identity, + msg="Non-fused submodule Conv") + self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU) + self.assertEqual(type(model.classifier[1]), nn.Identity) + + checkQAT(model) + model(self.img_data_2d[1][0]) + convert(model, inplace=True) + model(self.img_data_2d[1][0]) + self.checkModelWithSequentialQuantized(model) + + def test_fusion_sequential_model_eval(self): + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = ModelWithSequentialFusion().eval() + model.to(torch.float) + fuse_modules(model, [['conv1', 'relu1'] , + ['features.0.0', 'features.0.1', 'features.0.2'], + ['features.1.0', 'features.1.1', 'features.1.2'], + ['features.2.0', 'features.2.1', 'features.2.2'], + ['classifier.0', 'classifier.1']], inplace=True) + self.assertEqual(type(model.conv1), nni.ConvReLU2d, + msg="Fused Conv + Relu: nni.ConvReLU2d") + self.assertEqual(type(model.conv1[0]), nn.Conv2d, + msg="Fused Conv + Relu: Conv2d") + self.assertEqual(type(model.conv1[1]), nn.ReLU, + msg="Fused Conv + Relu: Relu") + self.assertEqual(type(model.relu1), nn.Identity, + msg="Fused Conv + Relu: Identity") + for i in range(3): + self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d, + msg="Fused submodule Conv + folded BN") + self.assertEqual(type(model.features[i][1]), nn.Identity, + msg="Fused submodule (skipped BN)") + self.assertEqual(type(model.features[i][2]), nn.Identity, + msg="Non-fused submodule Conv") + self.assertEqual(type(model.classifier[0]), nni.LinearReLU) + self.assertEqual(type(model.classifier[1]), nn.Identity) + model.qconfig = torch.quantization.get_default_qconfig(qengine) + prepare(model, inplace=True) + self.checkObservers(model) + model(self.img_data_2d[0][0]) + convert(model, inplace=True) + model(self.img_data_2d[1][0]) + self.checkModelWithSequentialQuantized(model) + + def checkModelWithSequentialQuantized(self, model): + self.assertEqual(type(model.conv1), nniq.ConvReLU2d) + self.assertEqual(type(model.relu1), nn.Identity) + for i in range(3): + self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d) + self.assertEqual(type(model.features[i][1]), nn.Identity) + self.assertEqual(type(model.features[i][2]), nn.Identity) + self.assertEqual(type(model.classifier[0]), nniq.LinearReLU) + self.assertEqual(type(model.classifier[1]), nn.Identity) + + def test_fusion_conv_with_bias(self): + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = ModelForFusionWithBias().train() + # output with no fusion. + out_ref = model(self.img_data_2d[0][0]) + + model.qconfig = QConfig(activation=torch.nn.Identity, + weight=torch.nn.Identity) + model = fuse_modules(model, [["conv1", "bn1", "relu1"], + ["conv2", "bn2"]]) + prep_model = prepare_qat(model, inplace=False) + # output with fusion but no observers. + out_fused = prep_model(self.img_data_2d[0][0]) + self.assertEqual(out_ref, out_fused) + + model.qconfig = torch.quantization.get_default_qconfig(qengine) + prepare_qat(model, inplace=True) + + model(self.img_data_2d[0][0]) + + def checkQAT(model): + self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) + self.assertEqual(type(model.bn1), nn.Identity) + self.assertEqual(type(model.relu1), nn.Identity) + self.assertEqual(type(model.conv2), nniqat.ConvBn2d) + self.assertEqual(type(model.bn2), nn.Identity) + + checkQAT(model) + + + def test_fusion_linear_bn_eval(self): + model = ModelForLinearBNFusion().train() + inp1 = torch.randn(8, 20) + inp2 = torch.randn(8, 20) + + # Get some interesting values into the running mean and variance. + model(inp1) + model.eval() + golden = model(inp2) + + model = fuse_modules(model, [["fc", "bn"]]) + self.assertEqual(type(model.bn), nn.Identity) + self.assertEqual(golden, model(inp2)) + + def test_forward_hooks_preserved(self): + r"""Test case that checks whether forward pre hooks of the first module and + post forward hooks of the last module in modules list passed to fusion function preserved. + (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)] + after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity]) + """ + model = ModelForFusion(default_qat_qconfig).train() + + counter = { + 'pre_forwards': 0, + 'forwards': 0, + } + fused = False + + def fw_pre_hook(fused_module_class, h_module, input): + if fused: + self.assertEqual(type(h_module), fused_module_class, + "After fusion owner of the first module's forward pre hook is not a fused module") + counter['pre_forwards'] += 1 + + def fw_hook(fused_module_class, h_module, input, output): + if fused: + self.assertEqual(type(h_module), fused_module_class, + "After fusion owner of the last module's forward hook is not a fused module") + counter['forwards'] += 1 + + # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference + model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)) + model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args)) + model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args)) + model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args)) + + test_only_eval_fn(model, self.img_data_1d) + self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d)) + self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d)) + + model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) + model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) + + fused = True + before_fusion_pre_count = counter['pre_forwards'] + before_fusion_post_count = counter['forwards'] + test_only_eval_fn(model, self.img_data_1d) + self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d)) + self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d)) + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_quantization.py TESTNAME\n\n" + "instead.") diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize_eager_ptq.py similarity index 56% rename from test/quantization/test_quantize.py rename to test/quantization/test_quantize_eager_ptq.py index 32b2e36636bcba..1df2dc5ce6edec 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize_eager_ptq.py @@ -1,24 +1,18 @@ + import torch import torch.nn as nn import torch.nn.quantized as nnq -import torch.nn.intrinsic as nni -import torch.nn.intrinsic.quantized as nniq -import torch.nn.intrinsic.qat as nniqat from torch.nn.utils.rnn import PackedSequence from torch.quantization import ( quantize, prepare, convert, prepare_qat, - quantize_qat, - fuse_modules, quantize_dynamic, QuantWrapper, QuantStub, DeQuantStub, - QConfig, default_qconfig, - default_qat_qconfig, default_dynamic_qconfig, per_channel_dynamic_qconfig, float16_dynamic_qconfig, @@ -26,18 +20,12 @@ PerChannelMinMaxObserver, QConfigDynamic, default_dynamic_quant_observer, - FixedQParamsFakeQuantize, ) from torch.testing._internal.common_quantization import ( QuantizationTestCase, AnnotatedSingleLayerLinearModel, QuantStubModel, - ModelForFusion, - ModelWithSequentialFusion, - ModelForLinearBNFusion, - ManualLinearQATModel, - ManualConvLinearQATModel, ModelWithFunctionals, ModelMultipleOps, ModelMultipleOpsNoAvgPool, @@ -47,11 +35,9 @@ ResNetBase, RNNDynamicModel, RNNCellDynamicModel, - ModelForFusionWithBias, ActivationsTestModel, NormalizationTestModel, test_only_eval_fn, - test_only_train_fn, prepare_dynamic, convert_dynamic, skipIfNoFBGEMM, @@ -74,8 +60,6 @@ supported_qengines, override_qengines, ) -from torch.testing._internal.common_utils import TemporaryFileName -from torch.testing._internal.common_utils import suppress_warnings from torch.testing._internal.jit_utils import JitTestCase from hypothesis import given from hypothesis import strategies as st @@ -84,7 +68,6 @@ # Standard library from typing import Tuple -import copy import io import unittest import numpy as np @@ -1150,208 +1133,7 @@ def checkHooksIsPresent(model): convert_dynamic(model) checkHooksIsPresent(model) - - -class TestQuantizationAwareTraining(QuantizationTestCase): - def test_manual(self): - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = ManualLinearQATModel(qengine) - model = prepare_qat(model) - self.checkObservers(model) - test_only_train_fn(model, self.train_data) - model = convert(model) - - def checkQuantized(model): - self.assertEqual(type(model.fc1), nnq.Linear) - self.assertEqual(type(model.fc2), nnq.Linear) - test_only_eval_fn(model, self.calib_data) - self.checkScriptable(model, self.calib_data) - self.checkNoQconfig(model) - - checkQuantized(model) - - model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn, - [self.train_data]) - checkQuantized(model) - - def test_eval_only_fake_quant(self): - r"""Using FakeQuant in evaluation only mode, - this is useful for estimating accuracy loss when we quantize the - network - """ - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = ManualLinearQATModel(qengine) - - model = prepare_qat(model) - self.checkObservers(model) - - model.eval() - test_only_eval_fn(model, self.calib_data) - - def test_conv_linear(self): - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = ManualConvLinearQATModel() - - model = prepare_qat(model) - self.checkObservers(model) - - test_only_train_fn(model, self.img_data_2d_train) - model = convert(model) - - def checkQuantized(model): - self.assertEqual(type(model.conv), nnq.Conv2d) - self.assertEqual(type(model.fc1), nnq.Linear) - self.assertEqual(type(model.fc2), nnq.Linear) - test_only_eval_fn(model, self.img_data_2d) - self.checkScriptable(model, self.img_data_2d) - self.checkNoQconfig(model) - - checkQuantized(model) - - model = ManualConvLinearQATModel() - model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) - checkQuantized(model) - - def test_train_save_load_eval(self): - r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict - During eval, we first call prepare_qat and conver on the model and then load the state_dict - and compare results against original model - """ - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = TwoLayerLinearModel() - model = torch.quantization.QuantWrapper(model) - model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) - model = prepare_qat(model) - - fq_state_dict = model.state_dict() - - test_only_train_fn(model, self.train_data) - model = convert(model) - - quant_state_dict = model.state_dict() - - x = torch.rand(2, 5, dtype=torch.float) - ref = model(x) - - # Create model again for eval. Check result using quantized state_dict - model = TwoLayerLinearModel() - model = torch.quantization.QuantWrapper(model) - model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) - torch.quantization.prepare_qat(model, inplace=True) - new_state_dict = model.state_dict() - - # Check to make sure the model after prepare_qat has the same state_dict as original. - self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys())) - - torch.quantization.convert(model, inplace=True) - model.eval() - model.load_state_dict(quant_state_dict) - out = model(x) - self.assertEqual(ref, out) - - # Check model created using prepare has same state dict as quantized state_dict - model = TwoLayerLinearModel() - model.eval() - model = torch.quantization.QuantWrapper(model) - model.qconfig = torch.quantization.get_default_qconfig(qengine) - torch.quantization.prepare(model, inplace=True) - torch.quantization.convert(model, inplace=True) - self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys())) - model.eval() - model.load_state_dict(quant_state_dict) - out = model(x) - self.assertEqual(ref, out) - - @override_qengines - def test_forward_hooks_preserved(self): - r"""Test QAT on preserving pre forward and post forward hooks of original model - """ - qengine = torch.backends.quantized.engine - model = QuantStubModel() - counter = { - 'pre_forwards': 0, - 'forwards': 0, - } - - def fw_pre_hook(h_module, input): - counter['pre_forwards'] += 1 - - def fw_hook(h_module, input, output): - counter['forwards'] += 1 - - model.fc.register_forward_pre_hook(fw_pre_hook) - model.fc.register_forward_hook(fw_hook) - - model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) - model = prepare_qat(model) - - def checkHooksIsPresent(model, before_convert=True): - forward_hooks = 1 - if before_convert: - self.assertEqual(len(model.quant._forward_hooks.values()), 1, - "Quantization observer hook has disappeared") - forward_hooks = 2 - self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values()) - self.assertObjectIn(fw_hook, model.fc._forward_hooks.values()) - self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1, - "Extra pre forward hooks have appeared on a layer") - self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks, - "Extra post forward hooks have appeared on a layer") - - checkHooksIsPresent(model, True) - x = torch.rand(2, 5, dtype=torch.float) - model(x) - torch.quantization.convert(model, inplace=True) - checkHooksIsPresent(model, False) - - def test_add_scalar_uses_input_qparams(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant = torch.quantization.QuantStub() - self.ff = torch.nn.quantized.FloatFunctional() - - def forward(self, x): - x = self.quant(x) - x = self.ff.add_scalar(x, 1.0) - return x - - m = M() - m.qconfig = torch.quantization.default_qconfig - mp = torch.quantization.prepare_qat(m) - mp(torch.randn(4, 4)) - mq = torch.quantization.convert(mp) - res = mq(torch.randn(4, 4)) - eps = 1e-5 - self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps) - - def test_mul_scalar_uses_input_qparams(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant = torch.quantization.QuantStub() - self.ff = torch.nn.quantized.FloatFunctional() - - def forward(self, x): - x = self.quant(x) - x = self.ff.mul_scalar(x, 2.0) - return x - - m = M() - m.qconfig = torch.quantization.default_qconfig - mp = torch.quantization.prepare_qat(m) - mp(torch.randn(4, 4)) - mq = torch.quantization.convert(mp) - res = mq(torch.randn(4, 4)) - eps = 1e-5 - self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps) - - -class TestEagerModeOps(QuantizationTestCase): +class TestEagerModeActivationOps(QuantizationTestCase): def _test_activation_op_impl( self, float_module_class, quantized_module_class, extra_module_kwargs): """ Implementation for testing common activation ops like leaky relu @@ -1384,102 +1166,6 @@ def test_leaky_relu(self): def test_relu(self): self._test_activation_op_impl(nn.ReLU, nn.ReLU, {'inplace': False}) - -class TestEagerModeQATOps(QuantizationTestCase): - def _test_activation_convert_numerics_impl(self, Act, data): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.act = Act() - self.quant = QuantStub() - self.dequant = DeQuantStub() - - def forward(self, x): - x = self.quant(x) - x = self.act(x) - x = self.dequant(x) - return x - - m = M().train() - m.qconfig = default_qat_qconfig - m = prepare_qat(m) - before_convert = m(data) - m = convert(m) - after_convert = m(data) - self.assertEqual(before_convert, after_convert) - - def test_fixed_qparam_ops(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - self.hardsigmoid = torch.nn.Hardsigmoid() - self.tanh = torch.nn.Tanh() - self.quant = QuantStub() - self.dequant = DeQuantStub() - - def forward(self, x): - x = self.quant(x) - x = self.sigmoid(x) - x = self.hardsigmoid(x) - x = self.tanh(x) - x = self.dequant(x) - return x - - m = M().train() - m.qconfig = default_qat_qconfig - m = prepare_qat(m) - for attr in ['sigmoid', 'hardsigmoid', 'tanh']: - self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize) - data = torch.randn(1, 3, 2, 4) - before_convert = m(data) - m = convert(m) - after_convert = m(data) - self.assertEqual(before_convert, after_convert) - # make sure activation post process is removed - for attr in ['sigmoid', 'hardsigmoid', 'tanh']: - # verify fake quant module is removd - self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process')) - # verify that hooks are removed - self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) - - # make sure no fake quantize module is inserted for eval mode - - def checkNoFQModule(m): - for attr in ['sigmoid', 'hardsigmoid', 'tanh']: - self.assertFalse(hasattr(getattr(m, attr), "activation_post_process")) - self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) - - m = M().eval() - m.qconfig = default_qconfig - m = prepare(m) - checkNoFQModule(m) - m = convert(m) - checkNoFQModule(m) - - def test_leaky_relu(self): - data = torch.randn(1, 3, 2, 4) - self._test_activation_convert_numerics_impl(nn.LeakyReLU, data) - - def test_relu(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = nn.ReLU() - - def forward(self, x): - x = self.relu(x) - return x - - m = M().train() - m.qconfig = default_qconfig - m = prepare_qat(m) - # make sure no activation_post_process is inserted for relu - self.assertFalse(hasattr(m, "activation_post_process")) - m = convert(m) - # make sure ReLU module is not changed - self.assertTrue(type(m.relu), nn.ReLU) - class TestFunctionalModule(QuantizationTestCase): # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out @given(train_mode=st.booleans()) @@ -1511,332 +1197,6 @@ def checkQuantized(model): checkQuantized(model) self.checkScriptable(model, [[xq]], check_save_load=True) -@skipIfNoFBGEMM -class TestFusion(QuantizationTestCase): - def test_fuse_module_train(self): - model = ModelForFusion(default_qat_qconfig).train() - # Test step by step fusion - model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) - model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) - self.assertEqual(type(model.conv1), nni.ConvBnReLU2d, - msg="Fused Conv + BN + Relu first layer") - self.assertEqual(type(model.bn1), torch.nn.Identity, - msg="Fused Conv + BN + Relu (skipped BN)") - self.assertEqual(type(model.relu1), torch.nn.Identity, - msg="Fused Conv + BN + Relu (skipped Relu)") - - self.assertEqual(type(model.sub1.conv), nni.ConvBn2d, - msg="Fused submodule Conv + BN") - self.assertEqual(type(model.sub1.bn), torch.nn.Identity, - msg="Fused submodule Conv + BN (skipped BN)") - self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d, - msg="Non-fused submodule Conv") - self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, - msg="Non-fused submodule ReLU") - model = prepare_qat(model) - self.checkObservers(model) - - def checkQAT(model): - self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) - self.assertEqual(type(model.bn1), nn.Identity) - self.assertEqual(type(model.relu1), nn.Identity) - self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d) - self.assertEqual(type(model.sub1.bn), nn.Identity) - self.assertEqual(type(model.sub2.conv), nn.Conv2d) - self.assertEqual(type(model.sub2.relu), nn.ReLU) - - checkQAT(model) - test_only_train_fn(model, self.img_data_1d_train) - model = convert(model) - - def checkQuantized(model): - self.assertEqual(type(model.conv1), nniq.ConvReLU2d) - self.assertEqual(type(model.bn1), nn.Identity) - self.assertEqual(type(model.relu1), nn.Identity) - self.assertEqual(type(model.sub1.conv), nnq.Conv2d) - self.assertEqual(type(model.sub1.bn), nn.Identity) - self.assertEqual(type(model.sub2.conv), nn.Conv2d) - self.assertEqual(type(model.sub2.relu), nn.ReLU) - test_only_eval_fn(model, self.img_data_1d) - self.checkNoQconfig(model) - - with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): - checkQuantized(model) - - model = ModelForFusion(default_qat_qconfig).train() - model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], - ['sub1.conv', 'sub1.bn']]) - model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train]) - with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): - checkQuantized(model) - - - def test_fuse_module_eval(self): - model = ModelForFusion(default_qconfig) - model.eval() - model = fuse_modules(model, [['conv3', 'bn3', 'relu4'], - ['conv1', 'bn1', 'relu1'], - ['conv2', 'relu2'], - ['bn2', 'relu3'], - ['sub1.conv', 'sub1.bn']]) - self.assertEqual(type(model.conv1), nni.ConvReLU2d, - msg="Fused Conv + BN + Relu first layer (BN is folded)") - self.assertEqual(type(model.conv1[0]), nn.Conv2d, - msg="Fused Conv + BN + Relu (Conv + folded BN only)") - self.assertEqual(type(model.conv1[1]), nn.ReLU, - msg="Fused Conv + BN + Relu second layer (Relu only)") - self.assertEqual(type(model.bn1), nn.Identity, - msg="Fused Conv + BN + Relu second layer (Skipped BN)") - self.assertEqual(type(model.relu1), nn.Identity, - msg="Fused Conv + BN + Relu second layer (Skipped Relu)") - self.assertEqual(type(model.conv2), nni.ConvReLU3d, - msg="Fused Conv + BN + Relu first layer (BN is folded)") - self.assertEqual(type(model.bn2), nni.BNReLU3d, - msg="Fused BN + Relu first layer (Relu is folded))") - self.assertEqual(type(model.relu3), nn.Identity, - msg="Fused BN + Relu second layer (Skipped Relu)") - self.assertEqual(type(model.conv2[0]), nn.Conv3d, - msg="Fused Conv + BN + Relu (Conv + folded BN only)") - self.assertEqual(type(model.conv2[1]), nn.ReLU, - msg="Fused Conv + BN + Relu second layer (Relu only)") - self.assertEqual(type(model.relu2), nn.Identity, - msg="Fused Conv + BN + Relu second layer (Skipped Relu)") - - self.assertEqual(type(model.conv3), nni.ConvReLU1d, - msg="Fused Conv + Relu for Conv1d (folded BN)") - self.assertEqual(type(model.conv3[0]), nn.Conv1d, - msg="Fused Conv + Relu for Conv1d ") - self.assertEqual(type(model.conv3[1]), nn.ReLU, - msg="Fused Conv + Relu for Conv1d") - self.assertEqual(type(model.bn3), nn.Identity, - msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)") - - self.assertEqual(type(model.sub1.conv), nn.Conv2d, - msg="Fused submodule Conv + folded BN") - self.assertEqual(type(model.sub1.bn), nn.Identity, - msg="Fused submodule (skipped BN)") - self.assertEqual(type(model.sub2.conv), nn.Conv2d, - msg="Non-fused submodule Conv") - self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, - msg="Non-fused submodule ReLU") - - model = prepare(model) - self.checkObservers(model) - test_only_eval_fn(model, self.img_data_1d) - model = convert(model) - - def checkQuantized(model): - self.assertEqual(type(model.conv3), nniq.ConvReLU1d) - self.assertEqual(type(model.conv1), nniq.ConvReLU2d) - self.assertEqual(type(model.bn1), nn.Identity) - self.assertEqual(type(model.relu1), nn.Identity) - self.assertEqual(type(model.sub1.conv), nnq.Conv2d) - self.assertEqual(type(model.sub1.bn), nn.Identity) - self.assertEqual(type(model.sub2.conv), nn.Conv2d) - self.assertEqual(type(model.sub2.relu), nn.ReLU) - self.assertEqual(type(model.bn2), nniq.BNReLU3d) - test_only_eval_fn(model, self.img_data_1d) - self.checkNoQconfig(model) - - checkQuantized(model) - - model = ModelForFusion(default_qconfig).eval() - model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], - ['conv2', 'relu2'], - ['bn2', 'relu3'], - ['sub1.conv', 'sub1.bn'], - ['conv3', 'bn3', 'relu4']]) - model = quantize(model, test_only_eval_fn, [self.img_data_1d]) - checkQuantized(model) - - def test_fusion_sequential_model_train(self): - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = ModelWithSequentialFusion().train() - model.to(torch.float) - fuse_modules(model, [['conv1', 'relu1'] , - ['features.0.0', 'features.0.1', 'features.0.2'], - ['features.1.0', 'features.1.1', 'features.1.2'], - ['features.2.0', 'features.2.1', 'features.2.2'], - ['classifier.0', 'classifier.1']], inplace=True) - self.assertEqual(type(model.conv1), nni.ConvReLU2d, - msg="Fused Conv + Relu: nni.ConvReLU2d") - self.assertEqual(type(model.conv1[0]), nn.Conv2d, - msg="Fused Conv + Relu: Conv2d") - self.assertEqual(type(model.conv1[1]), nn.ReLU, - msg="Fused Conv + Relu: Relu") - self.assertEqual(type(model.relu1), nn.Identity, - msg="Fused Conv + Relu: Identity") - for i in range(3): - self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d, - msg="Fused submodule Conv + folded BN") - self.assertEqual(type(model.features[i][1]), nn.Identity, - msg="Fused submodule (skipped BN)") - self.assertEqual(type(model.features[i][2]), nn.Identity, - msg="Non-fused submodule Conv") - self.assertEqual(type(model.classifier[0]), nni.LinearReLU) - self.assertEqual(type(model.classifier[1]), nn.Identity) - model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) - prepare_qat(model, inplace=True) - self.checkObservers(model) - model(self.img_data_2d[0][0]) - - - def checkQAT(model): - self.assertEqual(type(model.conv1), nniqat.ConvReLU2d) - self.assertEqual(type(model.relu1), nn.Identity) - for i in range(3): - self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d, - msg="Fused submodule Conv + folded BN") - self.assertEqual(type(model.features[i][1]), nn.Identity, - msg="Fused submodule (skipped BN)") - self.assertEqual(type(model.features[i][2]), nn.Identity, - msg="Non-fused submodule Conv") - self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU) - self.assertEqual(type(model.classifier[1]), nn.Identity) - - checkQAT(model) - model(self.img_data_2d[1][0]) - convert(model, inplace=True) - model(self.img_data_2d[1][0]) - self.checkModelWithSequentialQuantized(model) - - def test_fusion_sequential_model_eval(self): - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = ModelWithSequentialFusion().eval() - model.to(torch.float) - fuse_modules(model, [['conv1', 'relu1'] , - ['features.0.0', 'features.0.1', 'features.0.2'], - ['features.1.0', 'features.1.1', 'features.1.2'], - ['features.2.0', 'features.2.1', 'features.2.2'], - ['classifier.0', 'classifier.1']], inplace=True) - self.assertEqual(type(model.conv1), nni.ConvReLU2d, - msg="Fused Conv + Relu: nni.ConvReLU2d") - self.assertEqual(type(model.conv1[0]), nn.Conv2d, - msg="Fused Conv + Relu: Conv2d") - self.assertEqual(type(model.conv1[1]), nn.ReLU, - msg="Fused Conv + Relu: Relu") - self.assertEqual(type(model.relu1), nn.Identity, - msg="Fused Conv + Relu: Identity") - for i in range(3): - self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d, - msg="Fused submodule Conv + folded BN") - self.assertEqual(type(model.features[i][1]), nn.Identity, - msg="Fused submodule (skipped BN)") - self.assertEqual(type(model.features[i][2]), nn.Identity, - msg="Non-fused submodule Conv") - self.assertEqual(type(model.classifier[0]), nni.LinearReLU) - self.assertEqual(type(model.classifier[1]), nn.Identity) - model.qconfig = torch.quantization.get_default_qconfig(qengine) - prepare(model, inplace=True) - self.checkObservers(model) - model(self.img_data_2d[0][0]) - convert(model, inplace=True) - model(self.img_data_2d[1][0]) - self.checkModelWithSequentialQuantized(model) - - def checkModelWithSequentialQuantized(self, model): - self.assertEqual(type(model.conv1), nniq.ConvReLU2d) - self.assertEqual(type(model.relu1), nn.Identity) - for i in range(3): - self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d) - self.assertEqual(type(model.features[i][1]), nn.Identity) - self.assertEqual(type(model.features[i][2]), nn.Identity) - self.assertEqual(type(model.classifier[0]), nniq.LinearReLU) - self.assertEqual(type(model.classifier[1]), nn.Identity) - - def test_fusion_conv_with_bias(self): - for qengine in supported_qengines: - with override_quantized_engine(qengine): - model = ModelForFusionWithBias().train() - # output with no fusion. - out_ref = model(self.img_data_2d[0][0]) - - model.qconfig = QConfig(activation=torch.nn.Identity, - weight=torch.nn.Identity) - model = fuse_modules(model, [["conv1", "bn1", "relu1"], - ["conv2", "bn2"]]) - prep_model = prepare_qat(model, inplace=False) - # output with fusion but no observers. - out_fused = prep_model(self.img_data_2d[0][0]) - self.assertEqual(out_ref, out_fused) - - model.qconfig = torch.quantization.get_default_qconfig(qengine) - prepare_qat(model, inplace=True) - - model(self.img_data_2d[0][0]) - - def checkQAT(model): - self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) - self.assertEqual(type(model.bn1), nn.Identity) - self.assertEqual(type(model.relu1), nn.Identity) - self.assertEqual(type(model.conv2), nniqat.ConvBn2d) - self.assertEqual(type(model.bn2), nn.Identity) - - checkQAT(model) - - - def test_fusion_linear_bn_eval(self): - model = ModelForLinearBNFusion().train() - inp1 = torch.randn(8, 20) - inp2 = torch.randn(8, 20) - - # Get some interesting values into the running mean and variance. - model(inp1) - model.eval() - golden = model(inp2) - - model = fuse_modules(model, [["fc", "bn"]]) - self.assertEqual(type(model.bn), nn.Identity) - self.assertEqual(golden, model(inp2)) - - def test_forward_hooks_preserved(self): - r"""Test case that checks whether forward pre hooks of the first module and - post forward hooks of the last module in modules list passed to fusion function preserved. - (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)] - after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity]) - """ - model = ModelForFusion(default_qat_qconfig).train() - - counter = { - 'pre_forwards': 0, - 'forwards': 0, - } - fused = False - - def fw_pre_hook(fused_module_class, h_module, input): - if fused: - self.assertEqual(type(h_module), fused_module_class, - "After fusion owner of the first module's forward pre hook is not a fused module") - counter['pre_forwards'] += 1 - - def fw_hook(fused_module_class, h_module, input, output): - if fused: - self.assertEqual(type(h_module), fused_module_class, - "After fusion owner of the last module's forward hook is not a fused module") - counter['forwards'] += 1 - - # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference - model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)) - model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args)) - model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args)) - model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args)) - - test_only_eval_fn(model, self.img_data_1d) - self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d)) - self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d)) - - model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) - model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) - - fused = True - before_fusion_pre_count = counter['pre_forwards'] - before_fusion_post_count = counter['forwards'] - test_only_eval_fn(model, self.img_data_1d) - self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d)) - self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d)) - class TestModelNumerics(QuantizationTestCase): def test_float_quant_compare_per_tensor(self): for qengine in supported_qengines: @@ -1989,343 +1349,6 @@ def test_lower_graph_conv3d(self): self._test_lower_graph_impl(model, data) -class TestDeprecatedJitQuantized(JitTestCase): - @skipIfNoFBGEMM - def test_rnn_cell_quantized(self): - d_in, d_hid = 2, 2 - - for cell in [ - torch.nn.LSTMCell(d_in, d_hid).float(), - torch.nn.GRUCell(d_in, d_hid).float(), - torch.nn.RNNCell(d_in, d_hid).float(), - ]: - if isinstance(cell, torch.nn.LSTMCell): - num_chunks = 4 - elif isinstance(cell, torch.nn.GRUCell): - num_chunks = 3 - elif isinstance(cell, torch.nn.RNNCell): - num_chunks = 1 - - # Replace parameter values s.t. the range of values is exactly - # 255, thus we will have 0 quantization error in the quantized - # GEMM call. This i s for testing purposes. - # - # Note that the current implementation does not support - # accumulation values outside of the range representable by a - # 16 bit integer, instead resulting in a saturated value. We - # must take care that in our test we do not end up with a dot - # product that overflows the int16 range, e.g. - # (255*127+255*127) = 64770. So, we hardcode the test values - # here and ensure a mix of signedness. - vals = [[100, -155], - [100, -155], - [-155, 100], - [-155, 100], - [100, -155], - [-155, 100], - [-155, 100], - [100, -155]] - vals = vals[:d_hid * num_chunks] - cell.weight_ih = torch.nn.Parameter( - torch.tensor(vals, dtype=torch.float), - requires_grad=False) - cell.weight_hh = torch.nn.Parameter( - torch.tensor(vals, dtype=torch.float), - requires_grad=False) - - ref = copy.deepcopy(cell) - - cell = torch.jit.quantized.quantize_rnn_cell_modules(cell) - x = torch.tensor([[100, -155], - [-155, 100], - [100, -155]], dtype=torch.float) - h0_vals = [[-155, 100], - [-155, 155], - [100, -155]] - hx = torch.tensor(h0_vals, dtype=torch.float) - if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): - cx = torch.tensor(h0_vals, dtype=torch.float) - hiddens = (hx, cx) - else: - hiddens = hx - - if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): - class ScriptWrapper(torch.jit.ScriptModule): - def __init__(self, cell): - super(ScriptWrapper, self).__init__() - self.cell = cell - - @torch.jit.script_method - def forward(self, x: torch.Tensor, - hiddens: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - return self.cell(x, hiddens) - else: - - class ScriptWrapper(torch.jit.ScriptModule): - def __init__(self, cell): - super(ScriptWrapper, self).__init__() - self.cell = cell - - @torch.jit.script_method - def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor: - return self.cell(x, hiddens) - - cell = ScriptWrapper(cell) - outs = cell(x, hiddens) - cell = self.getExportImportCopyWithPacking(cell) - - outs = cell(x, hiddens) - ref_outs = ref(x, hiddens) - - self.assertEqual(len(outs), len(ref_outs)) - for out, ref_out in zip(outs, ref_outs): - torch.testing.assert_allclose(out, ref_out) - - @skipIfNoFBGEMM - def test_rnn_quantized(self): - d_in, d_hid = 2, 2 - - for cell in [ - torch.nn.LSTM(d_in, d_hid).float(), - torch.nn.GRU(d_in, d_hid).float(), - ]: - - # Replace parameter values s.t. the range of values is exactly - # 255, thus we will have 0 quantization error in the quantized - # GEMM call. This i s for testing purposes. - # - # Note that the current implementation does not support - # accumulation values outside of the range representable by a - # 16 bit integer, instead resulting in a saturated value. We - # must take care that in our test we do not end up with a dot - # product that overflows the int16 range, e.g. - # (255*127+255*127) = 64770. So, we hardcode the test values - # here and ensure a mix of signedness. - vals = [[100, -155], - [100, -155], - [-155, 100], - [-155, 100], - [100, -155], - [-155, 100], - [-155, 100], - [100, -155]] - if isinstance(cell, torch.nn.LSTM): - num_chunks = 4 - elif isinstance(cell, torch.nn.GRU): - num_chunks = 3 - vals = vals[:d_hid * num_chunks] - cell.weight_ih_l0 = torch.nn.Parameter( - torch.tensor(vals, dtype=torch.float), - requires_grad=False) - cell.weight_hh_l0 = torch.nn.Parameter( - torch.tensor(vals, dtype=torch.float), - requires_grad=False) - - ref = copy.deepcopy(cell) - cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8) - cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16) - - niter = 10 - x = torch.tensor([[100, -155], - [-155, 100], - [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) - h0_vals = [[-155, 100], - [-155, 155], - [100, -155]] - hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) - cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) - - if isinstance(ref, torch.nn.LSTM): - hiddens = (hx, cx) - elif isinstance(ref, torch.nn.GRU): - hiddens = hx - - ref_out, ref_hid = ref(x, hiddens) - - # Compare int8 quantized to unquantized - output_int8, final_hiddens_int8 = cell_int8(x, hiddens) - - torch.testing.assert_allclose(output_int8, ref_out) - for out, ref in zip(final_hiddens_int8, ref_hid): - torch.testing.assert_allclose(out, ref) - - # Compare fp16 quantized to unquantized - output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) - - torch.testing.assert_allclose(output_fp16, ref_out) - for out, ref in zip(final_hiddens_fp16, ref_hid): - torch.testing.assert_allclose(out, ref) - - def compare_quantized_unquantized(ScriptWrapper, cell): - wrapper = ScriptWrapper(cell) - - # Compare quantize scripted module to unquantized - script_out, script_hid = wrapper(x, hiddens) - torch.testing.assert_allclose(script_out, ref_out) - for out, ref in zip(script_hid, ref_hid): - torch.testing.assert_allclose(out, ref) - - # Compare export/import to unquantized - export_import_wrapper = self.getExportImportCopyWithPacking(wrapper) - ei_out, ei_hid = export_import_wrapper(x, hiddens) - torch.testing.assert_allclose(ei_out, ref_out) - for out, ref in zip(ei_hid, ref_hid): - torch.testing.assert_allclose(out, ref) - - if isinstance(cell, torch.jit.quantized.QuantizedGRU): - class ScriptWrapper(torch.jit.ScriptModule): - def __init__(self, cell): - super(ScriptWrapper, self).__init__() - self.cell = cell - - @torch.jit.script_method - def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - return self.cell(x, hiddens) - - compare_quantized_unquantized(ScriptWrapper, cell) - elif isinstance(cell, torch.jit.quantized.QuantizedLSTM): - for cell in [cell_int8, cell_fp16]: - class ScriptWrapper(torch.jit.ScriptModule): - def __init__(self, cell): - super(ScriptWrapper, self).__init__() - self.cell = cell - - @torch.jit.script_method - def forward(self, x, hiddens): - # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) - # -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - return self.cell(x, hiddens) - compare_quantized_unquantized(ScriptWrapper, cell) - - if 'fbgemm' in torch.backends.quantized.supported_engines: - # Suppression: using deprecated quant api - @suppress_warnings - def test_quantization_modules(self): - K1, N1 = 2, 2 - - class FooBar(torch.nn.Module): - def __init__(self): - super(FooBar, self).__init__() - self.linear1 = torch.nn.Linear(K1, N1).float() - - def forward(self, x): - x = self.linear1(x) - return x - - fb = FooBar() - fb.linear1.weight = torch.nn.Parameter( - torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False) - fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False) - - x = (torch.rand(1, K1).float() - 0.5) / 10.0 - value = torch.tensor([[100, -150]], dtype=torch.float) - - y_ref = fb(value) - - fb_int8 = torch.jit.quantized.quantize_linear_modules(fb) - traced_int8 = torch.jit.trace(fb_int8, (x,)) - fb_int8 = self.getExportImportCopyWithPacking(traced_int8) - y_int8 = fb_int8(value) - - fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16) - traced_fp16 = torch.jit.trace(fb_fp16, (x,)) - fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16) - y_fp16 = fb_fp16(value) - - torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3) - torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3) - - def _test_pickle_checkpoint_qtensor(self, device): - with TemporaryFileName() as fname: - class M(torch.jit.ScriptModule): - __constants__ = ['fname'] - - def __init__(self): - super(M, self).__init__() - self.fname = fname - - @torch.jit.script_method - def forward(self, x, y): - torch.save((x, y), self.fname) - return y - - q = torch.quantize_per_tensor( - torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device) - qc = torch.quantize_per_channel( - torch.rand(2, 3, dtype=torch.float), - scales=torch.tensor([0.1, 0.5, 0.01]), - zero_points=torch.tensor([10, 0, 20]), - axis=1, dtype=torch.quint8).to(device) - m = M() - m(q, qc) - with open(fname, "rb") as handle: - loaded_q, loaded_qc = torch.load(fname) - self.assertEqual(loaded_q, q) - self.assertEqual(loaded_qc, qc) - - def test_pickle_checkpoint_qtensor(self): - self._test_pickle_checkpoint_qtensor('cpu') - - def test_serialize_qtensor(self): - class SimpleQTensor(torch.jit.ScriptModule): - def __init__(self, per_channel): - super(SimpleQTensor, self).__init__() - x = torch.rand(5, 5).float() - if not per_channel: - x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8) - else: - s = torch.rand(5, dtype=torch.float64) + 0.1 - zp = torch.randint(5, 15, (5,)) - x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) - self.register_buffer('x', x_q) - - @torch.jit.script_method - def forward(self): - return self.x - - for per_channel in [False, True]: - model = SimpleQTensor(per_channel) - buffer = io.BytesIO() - torch.jit.save(model, buffer) - buffer.seek(0) - model_loaded = torch.jit.load(buffer) - self.assertEqual(model_loaded(), model()) - - @skipIfNoFBGEMM - def test_erase_class_tensor_shapes(self): - class Linear(torch.nn.Module): - def __init__(self, in_features, out_features): - super(Linear, self).__init__() - qweight = torch._empty_affine_quantized( - [out_features, in_features], scale=1, zero_point=0, - dtype=torch.qint8) - self._packed_weight = torch.ops.quantized.linear_prepack(qweight) - - @torch.jit.export - def __getstate__(self): - return (torch.ops.quantized.linear_unpack(self._packed_weight)[0], self.training) - - def forward(self): - return self._packed_weight - - @torch.jit.export - def __setstate__(self, state): - self._packed_weight = torch.ops.quantized.linear_prepack(state[0]) - self.training = state[1] - - @property - def weight(self): - return torch.ops.quantized.linear_unpack(self._packed_weight)[0] - - @weight.setter - def weight(self, w): - self._packed_weight = torch.ops.quantized.linear_prepack(w) - - with torch._jit_internal._disable_emit_hooks(): - x = torch.jit.script(Linear(10, 10)) - torch._C._jit_pass_erase_shape_information(x.graph) - if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" diff --git a/test/quantization/test_qat_module.py b/test/quantization/test_quantize_eager_qat.py similarity index 63% rename from test/quantization/test_qat_module.py rename to test/quantization/test_quantize_eager_qat.py index 4a7832773403c7..ed15a25dd615a8 100644 --- a/test/quantization/test_qat_module.py +++ b/test/quantization/test_quantize_eager_qat.py @@ -1,18 +1,339 @@ import math import torch import torch.nn as nn +import torch.backends.mkldnn from torch.nn import Conv2d, BatchNorm2d, ReLU, init from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d from torch.nn.modules.utils import _pair -from torch.quantization.qconfig import default_qat_qconfig -import torch.backends.mkldnn +import torch.nn.quantized as nnq +from torch.quantization import ( + prepare, + convert, + prepare_qat, + quantize_qat, + QuantStub, + DeQuantStub, + default_qconfig, + default_qat_qconfig, + FixedQParamsFakeQuantize, +) from torch.testing._internal.common_utils import TestCase + +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + QuantStubModel, + ManualLinearQATModel, + ManualConvLinearQATModel, + TwoLayerLinearModel, + test_only_eval_fn, + test_only_train_fn, +) + +from torch.testing._internal.common_quantized import ( + override_quantized_engine, + supported_qengines, + override_qengines, +) + from hypothesis import given from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() from functools import reduce +class TestQuantizationAwareTraining(QuantizationTestCase): + def test_manual(self): + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = ManualLinearQATModel(qengine) + model = prepare_qat(model) + self.checkObservers(model) + test_only_train_fn(model, self.train_data) + model = convert(model) + + def checkQuantized(model): + self.assertEqual(type(model.fc1), nnq.Linear) + self.assertEqual(type(model.fc2), nnq.Linear) + test_only_eval_fn(model, self.calib_data) + self.checkScriptable(model, self.calib_data) + self.checkNoQconfig(model) + + checkQuantized(model) + + model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn, + [self.train_data]) + checkQuantized(model) + + def test_eval_only_fake_quant(self): + r"""Using FakeQuant in evaluation only mode, + this is useful for estimating accuracy loss when we quantize the + network + """ + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = ManualLinearQATModel(qengine) + + model = prepare_qat(model) + self.checkObservers(model) + + model.eval() + test_only_eval_fn(model, self.calib_data) + + def test_conv_linear(self): + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = ManualConvLinearQATModel() + + model = prepare_qat(model) + self.checkObservers(model) + + test_only_train_fn(model, self.img_data_2d_train) + model = convert(model) + + def checkQuantized(model): + self.assertEqual(type(model.conv), nnq.Conv2d) + self.assertEqual(type(model.fc1), nnq.Linear) + self.assertEqual(type(model.fc2), nnq.Linear) + test_only_eval_fn(model, self.img_data_2d) + self.checkScriptable(model, self.img_data_2d) + self.checkNoQconfig(model) + + checkQuantized(model) + + model = ManualConvLinearQATModel() + model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) + checkQuantized(model) + + def test_train_save_load_eval(self): + r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict + During eval, we first call prepare_qat and conver on the model and then load the state_dict + and compare results against original model + """ + for qengine in supported_qengines: + with override_quantized_engine(qengine): + model = TwoLayerLinearModel() + model = torch.quantization.QuantWrapper(model) + model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) + model = prepare_qat(model) + + fq_state_dict = model.state_dict() + + test_only_train_fn(model, self.train_data) + model = convert(model) + + quant_state_dict = model.state_dict() + + x = torch.rand(2, 5, dtype=torch.float) + ref = model(x) + + # Create model again for eval. Check result using quantized state_dict + model = TwoLayerLinearModel() + model = torch.quantization.QuantWrapper(model) + model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) + torch.quantization.prepare_qat(model, inplace=True) + new_state_dict = model.state_dict() + + # Check to make sure the model after prepare_qat has the same state_dict as original. + self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys())) + + torch.quantization.convert(model, inplace=True) + model.eval() + model.load_state_dict(quant_state_dict) + out = model(x) + self.assertEqual(ref, out) + + # Check model created using prepare has same state dict as quantized state_dict + model = TwoLayerLinearModel() + model.eval() + model = torch.quantization.QuantWrapper(model) + model.qconfig = torch.quantization.get_default_qconfig(qengine) + torch.quantization.prepare(model, inplace=True) + torch.quantization.convert(model, inplace=True) + self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys())) + model.eval() + model.load_state_dict(quant_state_dict) + out = model(x) + self.assertEqual(ref, out) + + @override_qengines + def test_forward_hooks_preserved(self): + r"""Test QAT on preserving pre forward and post forward hooks of original model + """ + qengine = torch.backends.quantized.engine + model = QuantStubModel() + counter = { + 'pre_forwards': 0, + 'forwards': 0, + } + + def fw_pre_hook(h_module, input): + counter['pre_forwards'] += 1 + + def fw_hook(h_module, input, output): + counter['forwards'] += 1 + + model.fc.register_forward_pre_hook(fw_pre_hook) + model.fc.register_forward_hook(fw_hook) + + model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) + model = prepare_qat(model) + + def checkHooksIsPresent(model, before_convert=True): + forward_hooks = 1 + if before_convert: + self.assertEqual(len(model.quant._forward_hooks.values()), 1, + "Quantization observer hook has disappeared") + forward_hooks = 2 + self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values()) + self.assertObjectIn(fw_hook, model.fc._forward_hooks.values()) + self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1, + "Extra pre forward hooks have appeared on a layer") + self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks, + "Extra post forward hooks have appeared on a layer") + + checkHooksIsPresent(model, True) + x = torch.rand(2, 5, dtype=torch.float) + model(x) + torch.quantization.convert(model, inplace=True) + checkHooksIsPresent(model, False) + + def test_add_scalar_uses_input_qparams(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.ff = torch.nn.quantized.FloatFunctional() + + def forward(self, x): + x = self.quant(x) + x = self.ff.add_scalar(x, 1.0) + return x + + m = M() + m.qconfig = torch.quantization.default_qconfig + mp = torch.quantization.prepare_qat(m) + mp(torch.randn(4, 4)) + mq = torch.quantization.convert(mp) + res = mq(torch.randn(4, 4)) + eps = 1e-5 + self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps) + + def test_mul_scalar_uses_input_qparams(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.ff = torch.nn.quantized.FloatFunctional() + + def forward(self, x): + x = self.quant(x) + x = self.ff.mul_scalar(x, 2.0) + return x + + m = M() + m.qconfig = torch.quantization.default_qconfig + mp = torch.quantization.prepare_qat(m) + mp(torch.randn(4, 4)) + mq = torch.quantization.convert(mp) + res = mq(torch.randn(4, 4)) + eps = 1e-5 + self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps) + +class TestQATActivationOps(QuantizationTestCase): + def _test_activation_convert_numerics_impl(self, Act, data): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.act = Act() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.act(x) + x = self.dequant(x) + return x + + m = M().train() + m.qconfig = default_qat_qconfig + m = prepare_qat(m) + before_convert = m(data) + m = convert(m) + after_convert = m(data) + self.assertEqual(before_convert, after_convert) + + def test_fixed_qparam_ops(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + self.hardsigmoid = torch.nn.Hardsigmoid() + self.tanh = torch.nn.Tanh() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.sigmoid(x) + x = self.hardsigmoid(x) + x = self.tanh(x) + x = self.dequant(x) + return x + + m = M().train() + m.qconfig = default_qat_qconfig + m = prepare_qat(m) + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize) + data = torch.randn(1, 3, 2, 4) + before_convert = m(data) + m = convert(m) + after_convert = m(data) + self.assertEqual(before_convert, after_convert) + # make sure activation post process is removed + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + # verify fake quant module is removd + self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process')) + # verify that hooks are removed + self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) + + # make sure no fake quantize module is inserted for eval mode + + def checkNoFQModule(m): + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + self.assertFalse(hasattr(getattr(m, attr), "activation_post_process")) + self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) + + m = M().eval() + m.qconfig = default_qconfig + m = prepare(m) + checkNoFQModule(m) + m = convert(m) + checkNoFQModule(m) + + def test_leaky_relu(self): + data = torch.randn(1, 3, 2, 4) + self._test_activation_convert_numerics_impl(nn.LeakyReLU, data) + + def test_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(x) + return x + + m = M().train() + m.qconfig = default_qconfig + m = prepare_qat(m) + # make sure no activation_post_process is inserted for relu + self.assertFalse(hasattr(m, "activation_post_process")) + m = convert(m) + # make sure ReLU module is not changed + self.assertTrue(type(m.relu), nn.ReLU) + class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd): """ @@ -203,7 +524,7 @@ def __init__(self, padding, dilation, False, _pair(0), groups, bias, padding_mode, eps, momentum, freeze_bn, qconfig) -class TestQATModule(TestCase): +class TestConvBNQATModule(TestCase): @given(batch_size=st.integers(2, 4), input_channels_per_group=st.sampled_from([2, 3, 4]), @@ -506,3 +827,9 @@ def test_conv_bn_folded_vs_unfolded( qat_op_optim.step() qat_ref_op_optim.step() + + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_quantization.py TESTNAME\n\n" + "instead.") diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 3d3dcec0d8e0ea..51e5bfe7cd18ee 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -76,6 +76,8 @@ from torch.testing._internal.common_quantization import NodeSpec as ns +from torch.testing._internal.common_quantization import ConvModel + from torch.testing import FileCheck import copy @@ -265,6 +267,24 @@ def forward(self, x): ] self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) + def test_fuse_custom_config_dict_validity(self): + r""" + Verifies that if a user passes an invalid key or makes a typo when + constructing a fuse_custom_config_dict, an error will be thrown and + users will be notified of what keys are supported. + """ + m = ConvModel().eval() + from torch.quantization.quantize_fx import fuse_fx + fuse_custom_config_dict = {"typo": None} + + with self.assertRaises(ValueError) as context: + m = fuse_fx(m, fuse_custom_config_dict=fuse_custom_config_dict) + self.assertTrue( + 'Expected fuse_custom_config_dict to have the following keys:' + in str(context.exception) + ) + self.assertTrue('But found \'typo\' instead.' in str(context.exception)) + @skipIfNoFBGEMM class TestQuantizeFx(QuantizationTestCase): def test_pattern_match(self): @@ -1013,6 +1033,59 @@ def forward(self, x): self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig) self.assertEqual(m.module_conv2.qconfig, module_name_qconfig) + def test_qconfig_dict_validity(self): + r""" + Verifies that if a user passes an invalid key or makes a typo when + constructing a qconfig_dict, an error will be thrown and users will be + notified of what keys are supported. + """ + m = ConvModel().eval() + qconfig_dict = {"object_typo": [(torch.nn.Conv2d, default_qconfig)]} + + with self.assertRaises(ValueError) as context: + m = prepare_fx(m, qconfig_dict) + self.assertTrue( + 'Expected qconfig_dict to have the following keys:' in str(context.exception) + ) + self.assertTrue('But found \'object_typo\' instead.' in str(context.exception)) + + def test_prepare_custom_config_dict_validity(self): + r""" + Verifies that if a user passes an invalid key or makes a typo when + constructing a prepare_custom_config_dict, an error will be thrown and + users will be notified of what keys are supported. + """ + m = ConvModel().eval() + qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} + prepare_custom_config_dict = {"typo": None} + + with self.assertRaises(ValueError) as context: + m = prepare_fx(m, qconfig_dict, prepare_custom_config_dict) + self.assertTrue( + 'Expected prepare_custom_config_dict to have the following keys:' + in str(context.exception) + ) + self.assertTrue('But found \'typo\' instead.' in str(context.exception)) + + def test_convert_custom_config_dict_validity(self): + r""" + Verifies that if a user passes an invalid key or makes a typo when + constructing a convert_custom_config_dict, an error will be thrown and + users will be notified of what keys are supported. + """ + m = ConvModel().eval() + qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} + m = prepare_fx(m, qconfig_dict) + convert_custom_config_dict = {"typo": None} + + with self.assertRaises(ValueError) as context: + m = convert_fx(m, convert_custom_config_dict=convert_custom_config_dict) + self.assertTrue( + 'Expected convert_custom_config_dict to have the following keys:' + in str(context.exception) + ) + self.assertTrue('But found \'typo\' instead.' in str(context.exception)) + def test_remove_qconfig(self): class M(torch.nn.Module): def __init__(self): @@ -2484,15 +2557,12 @@ def forward(self, x): ns.call_module(torch.quantization.PlaceholderObserver): 3 + int(use_bias) } convert_node_occurrence = { - # we don't support static fp16 ops, so the linear functino + # we don't support static fp16 ops, so the linear function # is unfused linear_fun: 1, # activation, weight, bias and output ns.call_method("to"): 3 + int(use_bias), - # TODO: because CopyNode is not handled properly currently, there is - # a dequantize that is missing, will need to fix and - # remove (- int(not has relu)) - ns.call_method("dequantize"): 3 + int(use_bias) - int(not has_relu) + ns.call_method("dequantize"): 3 + int(use_bias) } self.checkGraphModeFxOp( model, data, QuantType.DYNAMIC, linear_fun, @@ -3288,6 +3358,24 @@ def test_silu_reference(self): self._test_default_node_quant_handler_ops( module, functional, qconfig, is_reference, node_list) + def test_mish_reference(self): + module = torch.nn.Mish + functional = torch.nn.functional.mish + qconfig = float16_static_qconfig + is_reference = True + node_list = [ + ns.call_method("to"), + ns.call_method("dequantize"), + ns.call_module(module), + ns.call_method("to"), + ns.call_method('dequantize'), + ns.call_function(functional), + ns.call_method("to"), + ns.call_method('dequantize') + ] + self._test_default_node_quant_handler_ops( + module, functional, qconfig, is_reference, node_list) + def test_bmm_int_reference(self): class M(torch.nn.Module): def __init__(self): @@ -3960,16 +4048,17 @@ def forward(self, x): expected_occurrence = { # input and weight of first and second linear, output of first and second linear ns.call_module(torch.quantization.MinMaxObserver): 6, + ns.call_module(torch.quantization.PlaceholderObserver): 1 } self.checkGraphModuleNodes( m, expected_node_occurrence=expected_occurrence ) - # make sure it runs m = convert_fx(m) expected_occurrence = { ns.call_function(torch.quantize_per_tensor): 2, - ns.call_method("dequantize"): 2, + # dequantize after first linear, before reshape and before output + ns.call_method("dequantize"): 3, ns.call_method("to"): 1, ns.call_function(torch.ops.quantized.linear): 2 } @@ -3977,6 +4066,8 @@ def forward(self, x): m, expected_node_occurrence=expected_occurrence ) + # make sure it runs + m(torch.randn(2, 4)) def test_multiple_qconfigs_for_single_value(self): """ Test multiple qconfigs for a single value""" diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py index 9db4b99dfe0822..f5946d891331a8 100644 --- a/test/quantization/test_quantize_jit.py +++ b/test/quantization/test_quantize_jit.py @@ -349,7 +349,10 @@ def __init__(self, weight, bias): self.bias = bias def forward(self, x): - return F.linear(x, self.weight, self.bias) + res = torch.matmul(x, self.weight.t()) + if self.bias is not None: + res.add_(self.bias) + return res x1 = torch.rand(3) w1 = torch.rand(5, 3) @@ -367,12 +370,19 @@ def forward(self, x): ): bias = b if has_bias else None model = torch.jit.trace(FunctionalLinear(weight, bias), [x]) + for node in model.graph.nodes(): + if node.kind() == "aten::matmul": + source_range_1 = node.sourceRange() torch._C._jit_pass_fuse_linear(model.graph) + for node in model.graph.nodes(): + if node.kind() == "aten::linear": + source_range_2 = node.sourceRange() FileCheck().check("aten::linear").run(model.graph) check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("] for cn in check_not: FileCheck().check_not(cn).run(model.graph) # make sure it runs + self.assertTrue(source_range_1 == source_range_2) model(x) # check matmuls are not fused diff --git a/test/run_test.py b/test/run_test.py index 0374b9f2561833..5b5970aef26318 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -81,6 +81,7 @@ 'test_xnnpack_integration', 'test_vulkan', 'test_sparse', + 'test_sparse_csr', 'test_quantization', 'test_pruning_op', 'test_spectral_ops', @@ -160,6 +161,7 @@ 'distributed/elastic/utils/util_test', 'distributed/elastic/utils/distributed_test', 'distributed/elastic/multiprocessing/api_test', + 'distributed/_sharding_spec/test_sharding_spec', ] # Tests need to be run with pytest. diff --git a/test/test_autocast.py b/test/test_autocast.py new file mode 100644 index 00000000000000..a722c1a04d54bd --- /dev/null +++ b/test/test_autocast.py @@ -0,0 +1,123 @@ +import collections +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists + +class TestAutocastCPU(TestCase): + def setUp(self): + super(TestAutocastCPU, self).setUp() + self.autocast_lists = AutocastCPUTestLists(torch.device('cpu')) + + def tearDown(self): + del self.autocast_lists + super(TestAutocastCPU, self).tearDown() + + def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None): + # helper to cast args + def cast(val, to_type): + if isinstance(val, torch.Tensor): + return val.to(to_type) if val.is_floating_point() else val + elif isinstance(val, collections.abc.Iterable): + return type(val)(cast(v, to_type) for v in val) + else: + return val + + if add_kwargs is None: + add_kwargs = {} + + self.assertFalse(torch.is_autocast_cpu_enabled()) + with torch.cpu.amp.autocast(): + self.assertTrue(torch.is_autocast_cpu_enabled()) + out_type = out_type if out_type is not None else run_as_type + output = output_method = None + + # Try module.* variant, if requested: + if module is not None and hasattr(module, op): + output = getattr(module, op)(*args, **add_kwargs) + if isinstance(output, torch.Tensor): + self.assertTrue(out_type == output.dtype, + "autocast for torch.{} produced {}, should produce {}" + .format(op, output.dtype, out_type)) + # Try Tensor.* variant: + if hasattr(torch.Tensor, op): + output_method = getattr(args[0], op)(*args[1:], **add_kwargs) + if isinstance(output_method, torch.Tensor): + self.assertTrue(out_type == output_method.dtype, + "autocast for torch.{} produced {}, should produce torch.{}" + .format(op, output_method.dtype, out_type)) + + self.assertTrue((output is not None) or (output_method is not None), + "{} not found as an attribute on either Tensor or the requested module {}".format( + op, module)) + + # Accounts for ops that return Tensors, iterables, and other non-Tensors. + # For example, lstm_cell returns a tuple and equal returns bool. + def compare(first, second): + if isinstance(first, torch.Tensor): + return torch.equal(first, second) + elif isinstance(first, collections.abc.Iterable): + return all(compare(f, s) for f, s in zip(first, second)) + else: + return first == second + + # If both torch.* and Tensor.* variants were found, check outputs are identical + if (output is not None) and (output_method is not None): + self.assertTrue(type(output) == type(output_method)) + comparison = compare(output, output_method) + self.assertTrue(comparison, "torch.{0} result did not match Tensor.{0} result".format(op)) + + # Compare numerics to Python-side "autocasting" that (we expect) does the same thing + # as the C++-side autocasting, and should be bitwise accurate. + output_to_compare = output if output is not None else output_method + with torch.cpu.amp.autocast(enabled=False): + self.assertFalse(torch.is_autocast_cpu_enabled()) + + if module is not None and hasattr(module, op): + control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs) + else: + control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs) + self.assertTrue(type(output_to_compare) == type(control)) + comparison = compare(output_to_compare, control) + self.assertTrue(comparison, "torch.{} result did not match control".format(op)) + self.assertTrue(torch.is_autocast_cpu_enabled()) + self.assertFalse(torch.is_autocast_cpu_enabled()) + + def args_maybe_kwargs(self, op_with_args): + if len(op_with_args) == 2: + return op_with_args[0], op_with_args[1], {} + else: + return op_with_args[0], op_with_args[1], op_with_args[2] + + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: + self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + + def test_autocast_methods_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote: + self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type) + + def test_autocast_torch_bf16(self): + for op_with_args in self.autocast_lists.torch_bf16: + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs) + + def test_autocast_nn_bf16(self): + for op, args in self.autocast_lists.nn_bf16: + self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn) + + def test_autocast_torch_fp32(self): + for op_with_args in self.autocast_lists.torch_fp32: + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs) + + def test_autocast_nn_fp32(self): + for op_with_args in self.autocast_lists.nn_fp32: + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs) + + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.autocast_lists.torch_need_autocast_promote: + self._run_autocast_outofplace(op, args, torch.float32) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_autograd.py b/test/test_autograd.py index 785d5da255c1b1..4425777c78a2ff 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -5420,6 +5420,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, # deny list for batched grad computation EXCLUDE_BATCHED_GRAD_TESTS = set([ 'test_to_sparse', + 'test_to_sparse_dim', ]) def add_test( diff --git a/test/test_cuda.py b/test/test_cuda.py index ae93b6bd50b36d..cae7f6c3127c53 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1908,13 +1908,16 @@ def test_grad_scaling_unscale(self, dtype=torch.float): for grad in grads: self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) - # Passing lists with mismatched devices or dtypes to a raw + # When passing lists with mismatched dtypes to a raw + # _amp_foreach_non_finite_check_and_unscale_ call, + # it's expected to fall back to single-tensor TensorIterator kernel. + grads = [g.clone(), g.to(dtype=torch.float16)] + torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale) + for grad in grads: + self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) + + # Passing lists with mismatched devices to a raw # _amp_foreach_non_finite_check_and_unscale_ call should raise errors. - with self.assertRaisesRegex(RuntimeError, r"must have the same dtype"): - torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(dtype=torch.float16)], - found_inf, - inv_scale) - if TEST_MULTIGPU: with self.assertRaisesRegex(RuntimeError, r"Expected all tensors to be on the same device"): torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")], diff --git a/test/test_datapipe.py b/test/test_datapipe.py index ee6c02ad3d6f15..0c8985b4f1b36e 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -13,13 +13,15 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import (TestCase, run_tests) -from torch.utils.data import \ - (IterDataPipe, MapDataPipe, RandomSampler, DataLoader, - argument_validation, runtime_validation_disabled, runtime_validation) +from torch.utils.data import ( + IterDataPipe, MapDataPipe, RandomSampler, DataLoader, + argument_validation, runtime_validation_disabled, runtime_validation +) -from typing import \ - (Any, Dict, Generic, Iterator, List, NamedTuple, Optional, Tuple, Type, - TypeVar, Set, Union) +from typing import ( + Any, Awaitable, Dict, Generic, Iterator, List, NamedTuple, Optional, Tuple, + Type, TypeVar, Set, Union +) import torch.utils.data.datapipes as dp from torch.utils.data.datapipes.utils.decoder import ( @@ -786,7 +788,7 @@ def __iter__(self) -> Iterator[Tuple[int, str]]: dp2 = DP1(5) self.assertEqual(dp1.type, dp2.type) - with self.assertRaisesRegex(TypeError, r"Can not subclass a DataPipe"): + with self.assertRaisesRegex(TypeError, r"is not a generic class"): class InvalidDP5(DP1[tuple]): # type: ignore[type-arg] def __iter__(self) -> Iterator[tuple]: # type: ignore[override] yield (0, ) @@ -833,7 +835,8 @@ def __iter__(self) -> Iterator[str]: self.assertTrue(issubclass(DP5, IterDataPipe)) dp = DP5() # type: ignore[assignment] - self.assertTrue(dp.type.param == Any) + from torch.utils.data._typing import issubtype + self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param)) class DP6(IterDataPipe[int]): r""" DataPipe with plain Iterator""" @@ -844,6 +847,19 @@ def __iter__(self) -> Iterator: dp = DP6() # type: ignore[assignment] self.assertTrue(dp.type.param == int) + class DP7(IterDataPipe[Awaitable[T_co]]): + r""" DataPipe with abstract base class""" + + self.assertTrue(issubclass(DP6, IterDataPipe)) + self.assertTrue(DP7.type.param == Awaitable[T_co]) + + class DP8(DP7[str]): + r""" DataPipe subclass from a DataPipe with abc type""" + + self.assertTrue(issubclass(DP8, IterDataPipe)) + self.assertTrue(DP8.type.param == Awaitable[str]) + + def test_construct_time(self): class DP0(IterDataPipe[Tuple]): @argument_validation diff --git a/test/test_determination.py b/test/test_determination.py index a95ee534bc696d..e561ba1d61182c 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -74,7 +74,7 @@ def test_test_file(self): ["distributed/rpc/test_process_group_agent"], ) self.assertEqual( - self.determined_tests(["test/quantization/test_quantize.py"]), + self.determined_tests(["test/quantization/test_quantize_eager_ptq.py"]), ["test_quantization"], ) diff --git a/test/test_foreach.py b/test/test_foreach.py index 8d1c26c531c6bd..b15420cbe57769 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1,8 +1,9 @@ +import re import torch import unittest from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, skipCUDAIfRocm, ops) + (instantiate_device_type_tests, dtypes, skipCUDAIfRocm, skipMeta, ops) from torch._six import inf, nan from torch.testing._internal.common_methods_invocations import foreach_unary_op_db @@ -178,7 +179,7 @@ def test_addcdiv(self, device, dtype): return self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv) - @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False)) def test_min_max(self, device, dtype): for N in N_values: tensors1 = self._get_test_data(device, dtype, N) @@ -672,10 +673,19 @@ def test_add_scalar_with_overlapping_tensors(self, device, dtype): res = torch._foreach_add(tensors, 1) self.assertEqual(res, expected) + # note(mkozuki): this test case fails with Meta at least in my local environment. + # The message was + # `AssertionError: NotImplementedError("Could not run 'aten::_foreach_add.Scalar' with arguments from the 'Meta' backend.` + @skipMeta def test_bin_op_scalar_with_different_tensor_dtypes(self, device): tensors = [torch.tensor([1.1], dtype=torch.float, device=device), torch.tensor([1], dtype=torch.long, device=device)] - self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, 1)) + runtime_error = None + try: + torch._foreach_add(tensors, 1) + except RuntimeError as e: + runtime_error = e + self.assertIsNone(runtime_error) # # Ops with list @@ -706,15 +716,6 @@ def test_bin_op_list_error_cases(self, device): with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): bin_op_(tensors1, tensors2) - # Different dtypes - tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)] - tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)] - - with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): - bin_op(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): - bin_op_(tensors1, tensors2) - # different devices if torch.cuda.is_available() and torch.cuda.device_count() > 1: tensor1 = torch.zeros(10, 10, device="cuda:0") @@ -790,6 +791,13 @@ def test_add_list_different_sizes(self, device, dtype): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") @dtypes(*torch.testing.get_all_dtypes()) def test_add_list_slow_path(self, device, dtype): + # 0-strides + tensor1 = torch.rand(10, 10, device=device) + tensor2 = torch.rand(1, device=device).expand_as(tensor1) + res = torch._foreach_add([tensor1], [tensor2]) + torch._foreach_add_([tensor1], [tensor2]) + self.assertEqual(res, [tensor1]) + # different strides tensor1 = torch.zeros(10, 10, device=device, dtype=dtype) tensor2 = torch.ones(10, 10, device=device, dtype=dtype) @@ -806,6 +814,101 @@ def test_add_list_slow_path(self, device, dtype): torch._foreach_add_([tensor1], [tensor2]) self.assertEqual(res, [tensor1]) + # sliced tensor + tensor1 = torch.randn(5, 2, 1, 3, device=device).to(dtype) + tensor2 = torch.randn(5, 2, 1, 3 * 7, device=device).to(dtype)[:, :, :, ::7] + res = torch._foreach_add([tensor1], [tensor2]) + torch._foreach_add_([tensor1], [tensor2]) + self.assertEqual(res, [tensor1]) + + # note: Below three tests (postfixed with `_tensors_on_different_devices`) + # checks whether foreach works with lists of tensors on different devices + # but tensors of the same index are on the same device, e.g., ['cuda', 'cpu]. + @ops(foreach_unary_op_db) + def test_unary_op_tensors_on_different_devices(self, device, dtype, op): + if self.device_type != 'cuda': + self.skipTest('CUDA is necessary for tests with tensors on different devices') + # tensors: ['cuda', 'cpu] + tensors = op.sample_inputs(device, dtype, 2) + tensors[1] = tensors[1].to('cpu') + try: + actual = op.get_method()(tensors) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), str(e)): + [op.ref(t) for t in tensors] + else: + expected = [op.ref(t) for t in tensors] + self.assertEqual(expected, actual) + + try: + op.get_inplace()(tensors) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), str(e)): + [getattr(t, op.ref.__name__ + '_')() for t in tensors] + else: + self.assertEqual(expected, tensors) + + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True)) + def test_binary_op_tensors_on_different_devices(self, device, dtype): + if self.device_type != 'cuda': + self.skipTest('CUDA is necessary for tests with tensors on different devices') + for foreach_op, foreach_op_, native_op in self.bin_ops: + # `tensors1`: ['cuda', 'cpu'] + # `tensors2`: ['cuda', 'cpu'] + _cuda_tensors = self._get_test_data(device, dtype, 2) + _cpu_tensors = self._get_test_data('cpu', dtype, 2) + tensors1, tensors2 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors)) + + try: + actual = foreach_op(tensors1, tensors2) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), re.escape(str(e))): + [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)] + else: + expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)] + self.assertEqual(expected, actual) + try: + foreach_op_(tensors1, tensors2) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), re.escape(str(e))): + [getattr(t1, native_op.__name__ + '_')(t2) for t1, t2 in zip(tensors1, tensors2)] + else: + self.assertEqual(actual, tensors1) + + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True)) + def test_pointwise_op_tensors_on_different_devices(self, device, dtype): + if self.device_type != 'cuda': + self.skipTest('CUDA is necessary for tests with tensors on different devices') + + pointwise_ops = [ + (torch._foreach_addcmul, torch._foreach_addcmul_, torch.addcmul), + (torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv), + ] + for foreach_op, foreach_op_, native_op in pointwise_ops: + # tensors1: ['cuda', 'cpu] + # tensors2: ['cuda', 'cpu] + # tensors3: ['cuda', 'cpu] + _cuda_tensors = self._get_test_data(device, dtype, 3) + _cpu_tensors = self._get_test_data('cpu', dtype, 3) + tensors1, tensors2, tensors3 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors)) + + try: + actual = foreach_op(tensors1, tensors2, tensors3) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), re.escape(str(e))): + expected = [native_op(t1, t2, t3) for t1, t2, t3 in zip(tensors1, tensors2, tensors3)] + else: + expected = [native_op(t1, t2, t3) for t1, t2, t3 in zip(tensors1, tensors2, tensors3)] + self.assertEqual(expected, actual) + try: + foreach_op_(tensors1, tensors2, tensors3) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), re.escape(str(e))): + [getattr(t1, native_op.__name__ + '_')(t2, t3) for t1, t2, t3 in zip(tensors1, tensors3, tensors3)] + else: + self.assertEqual(expected, tensors1) + + instantiate_device_type_tests(TestForeach, globals()) if __name__ == '__main__': diff --git a/test/test_fx.py b/test/test_fx.py index 6414318dbaac45..083d39b9e42984 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1520,6 +1520,7 @@ def test_typename_print(self): b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), type_expr=List[float]) output : torch.fx.Node = graph.output(b) + self.assertTrue('typing.List[float]' in str(graph)) def test_ellipsis(self): @@ -2442,6 +2443,30 @@ def test_annotation_with_future(self): finally: del sys.modules["__future__"] + def test_annotations_empty_tuple(self): + class Foo(torch.nn.Module): + def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): + return "foo" + + traced = torch.fx.symbolic_trace(Foo()) + + x = () + y = ("bar", ()) + + traced(x, y) + + FileCheck().check("_Tuple[()]") \ + .check("typing_Tuple[str,typing_Tuple[()]]") \ + .run(traced.code) + + scripted = torch.jit.script(traced) + + scripted(x, y) + + FileCheck().check("Tuple[()]") \ + .check("Tuple[str, Tuple[()]]") \ + .run(scripted.code) + @skipIfNoTorchVision def test_cpatcher(self): @@ -2604,6 +2629,7 @@ class TestOperatorSignatures(JitTestCase): def test_get_torch_func_signature_exhaustive(self, device, dtype, op): # Sorted and one entry on each line to minimize merge conflicts. known_no_schema = {'cdist', + 'contiguous', 'dstack', 'einsum', 'expand', @@ -2772,6 +2798,7 @@ class TestFunctionalTracing(JitTestCase): "rrelu": CONTROL_FLOW, "selu": CONTROL_FLOW, "silu": CONTROL_FLOW, + "mish": CONTROL_FLOW, "smooth_l1_loss": CONTROL_FLOW, "soft_margin_loss": CONTROL_FLOW, "threshold": CONTROL_FLOW, diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index e3c1023d138f85..4ad874cdf8e1a3 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1313,7 +1313,8 @@ class TestNormalizeOperators(JitTestCase): @ops(op_db, allowed_dtypes=(torch.float,)) def test_normalize_operator_exhaustive(self, device, dtype, op): # Sorted and one entry on each line to minimize merge conflicts. - op_skip = {'einsum', + op_skip = {'contiguous', + 'einsum', 'expand', 'expand_as', 'gradient', diff --git a/test/test_gen_backend_stubs.py b/test/test_gen_backend_stubs.py index 0b70b90f4fe8c7..e1a66c69fe6f51 100644 --- a/test/test_gen_backend_stubs.py +++ b/test/test_gen_backend_stubs.py @@ -161,7 +161,7 @@ def test_backend_autograd_kernel_mismatch_out_functional(self): autograd: - add.out''' output_error = self.get_errors_from_gen_backend_stubs(yaml_str) - self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They can not be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_out is listed under "autograd".''') # noqa: B950 + self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_out is listed under "autograd".''') # noqa: B950 # in an operator group, currently all operators must either be registered to the backend or autograd kernel. # Here, functional and inplace mismatch @@ -174,7 +174,7 @@ def test_backend_autograd_kernel_mismatch_functional_inplace(self): autograd: - add_.Tensor''' output_error = self.get_errors_from_gen_backend_stubs(yaml_str) - self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They can not be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_ is listed under "autograd".''') # noqa: B950 + self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_ is listed under "autograd".''') # noqa: B950 # Currently, the same operator can't be listed under both 'supported' and 'autograd', which would # involve registering the same kernel to both the XLA and AutogradXLA keys. @@ -188,7 +188,7 @@ def test_op_appears_in_supported_and_autograd_lists(self): autograd: - add.Tensor''' output_error = self.get_errors_from_gen_backend_stubs(yaml_str) - self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They can not be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add is listed under "autograd".''') # noqa: B950 + self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add is listed under "autograd".''') # noqa: B950 # unrecognized extra yaml key def test_unrecognized_key(self): @@ -199,7 +199,7 @@ def test_unrecognized_key(self): - abs invalid_key: invalid_val''' output_error = self.get_errors_from_gen_backend_stubs(yaml_str) - self.assertExpectedInline(output_error, ''' contains unexpected keys: invalid_key. Only the following keys are supported: backend, cpp_namespace, supported, autograd''') # noqa: B950 + self.assertExpectedInline(output_error, ''' contains unexpected keys: invalid_key. Only the following keys are supported: backend, cpp_namespace, extra_headers, supported, autograd''') # noqa: B950 if __name__ == '__main__': diff --git a/test/test_jit.py b/test/test_jit.py index 6cd254a5b1d5b7..ed27fb4ab466f1 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -16,9 +16,11 @@ from jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing # noqa: F401 from jit.test_custom_operators import TestCustomOperators # noqa: F401 from jit.test_export_modes import TestExportModes # noqa: F401 +from jit.test_graph_rewrite_passes import TestGraphRewritePasses # noqa: F401 from jit.test_class_type import TestClassType # noqa: F401 from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401 from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401 +from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis # noqa: F401 from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401 from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401 from jit.test_peephole import TestPeephole # noqa: F401 @@ -52,6 +54,7 @@ from jit.test_pdt import TestPDT # noqa: F401 from jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401 from jit.test_module_apis import TestModuleAPIs # noqa: F401 +from jit.test_script_profile import TestScriptProfile # noqa: F401 # Torch from torch import Tensor @@ -61,6 +64,7 @@ from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401 from torch.nn.utils.rnn import PackedSequence from torch.testing import FileCheck +from torch.testing._internal.common_utils import make_tensor import torch.autograd.profiler import torch.cuda import torch.jit @@ -1270,6 +1274,158 @@ def forward(self, x): FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph) + def test_pattern_based_rewrite_with_source_range_preserved(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super(TestModule1, self).__init__() + + def forward(self, x, y, z, w): + x = x + y + x = x * z + return w - x + + input_pattern = """ + graph(%x, %y, %z, %const): + %t = aten::add(%x, %y, %const) + %o = aten::mul(%t, %z) + return (%o)""" + replacement_pattern = """ + graph(%x, %y, %z, %const): + %o = my::add_mul(%x, %y, %z, %const) + return (%o)""" + scripted_model = torch.jit.script(TestModule1()) + graph = scripted_model.graph + value_mappings = [("o", "t")] + for node in graph.nodes(): + if node.kind() == "aten::add": + source_range_1 = node.sourceRange() + torch._C._jit_pass_custom_pattern_based_rewrite_graph( + input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings) + graph = scripted_model.graph + for node in graph.nodes(): + if node.kind() == "my::add_mul": + source_range_2 = node.sourceRange() + self.assertTrue(source_range_1 == source_range_2) + + class TestModule2(torch.nn.Module): + def __init__(self): + super(TestModule2, self).__init__() + + def forward(self, x, y, z, w): + x = x + y + x = x + z + x = x * z + x = x * w + return x - 2 + + # Check source range preservation for two node transforms add -> my_add + input_pattern = """ + graph(%x, %y, %const): + %o = aten::add(%x, %y, %const) + return (%o)""" + replacement_pattern = """ + graph(%x, %y, %const): + %o = my::add(%x, %y, %const) + return (%o)""" + scripted_model = copy.deepcopy(torch.jit.script(TestModule2())) + graph_copy = scripted_model.graph.copy() + value_mappings = [("o", "o")] + source_range_add_1 = None + for node in graph_copy.nodes(): + if source_range_add_1 is None and node.kind() == "aten::add": + source_range_add_1 = node.sourceRange() + if source_range_add_1 is not None and node.kind() == "aten::add": + source_range_add_2 = node.sourceRange() + torch._C._jit_pass_custom_pattern_based_rewrite_graph( + input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) + source_range_my_add_1 = None + for node in graph_copy.nodes(): + if source_range_my_add_1 is None and node.kind() == "my::add": + source_range_my_add_1 = node.sourceRange() + if source_range_my_add_1 is not None and node.kind() == "my::add": + source_range_my_add_2 = node.sourceRange() + self.assertTrue(source_range_add_1 == source_range_my_add_1) + self.assertTrue(source_range_add_2 == source_range_my_add_2) + + # Check source range preservation for add-add -> double_add transform + # fuse nodes + input_pattern = """ + graph(%x, %y, %z, %const): + %t = aten::add(%x, %y, %const) + %o = aten::add(%t, %z, %const) + return (%o)""" + replacement_pattern = """ + graph(%x, %y, %z, %const): + %o = my::double_add(%x, %y, %z, %const) + return (%o)""" + scripted_model = torch.jit.script(TestModule2()) + graph_copy = scripted_model.graph.copy() + value_mappings = [("o", "t")] + source_range_1 = None + source_range_2 = None + for node in graph_copy.nodes(): + if node.kind() == "aten::add": + source_range_1 = node.sourceRange() + break + torch._C._jit_pass_custom_pattern_based_rewrite_graph( + input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) + for node in graph_copy.nodes(): + if node.kind() == "my::double_add": + source_range_2 = node.sourceRange() + self.assertTrue(source_range_1 == source_range_2) + + # Check source range preservation for mul -> add + add transform + # split node + input_pattern = """ + graph(%x, %y): + %t = aten::mul(%x, %y) + return (%t)""" + replacement_pattern = """ + graph(%x, %y): + %t = my::add(%x, %y) + %o = my::add(%t, %y) + return (%o)""" + scripted_model = torch.jit.script(TestModule2()) + graph_copy = scripted_model.graph.copy() + value_mappings = [("t", "t"), ("o", "t")] + source_range_mul_1 = None + for node in graph_copy.nodes(): + if source_range_mul_1 is None and node.kind() == "aten::mul": + source_range_mul_1 = node.sourceRange() + if source_range_mul_1 is not None and node.kind() == "aten::mul": + source_range_mul_2 = node.sourceRange() + torch._C._jit_pass_custom_pattern_based_rewrite_graph( + input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) + source_range_add_1 = None + for node in graph_copy.nodes(): + if source_range_add_1 is None and node.kind() == "my::add": + source_range_add_1 = node.sourceRange() + if source_range_add_1 is not None and node.kind() == "my::add": + source_range_add_2 = node.sourceRange() + self.assertTrue(source_range_mul_1 == source_range_add_1) + self.assertTrue(source_range_mul_2 == source_range_add_2) + + # Check lack of source range preservation for mul-mul-> double_mul transform + input_pattern = """ + graph(%x, %y, %z): + %t = aten::mul(%x, %y) + %o = aten::mul(%t, %z) + return (%o)""" + replacement_pattern = """ + graph(%x, %y, %z): + %o = my::double_mul(%x, %y, %z) + return (%o)""" + scripted_model = torch.jit.script(TestModule2()) + graph_copy = scripted_model.graph.copy() + for node in graph_copy.nodes(): + if node.kind() == "aten::mul": + source_range_1 = node.sourceRange() + torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy) + for node in graph_copy.nodes(): + if node.kind() == "my::double_mul": + source_range_2 = node.sourceRange() + self.assertFalse(source_range_1 == source_range_2) + def test_expand_quantlint(self): pass @@ -1580,15 +1736,27 @@ def fn_out(real, img, out): self.checkScript(fn_out, (real, img, out, )) def test_einsum(self): - def outer(x, y): + def check(fn, jitted, *args): + self.assertGraphContains(jitted.graph, kind='aten::einsum') + self.assertEqual(fn(*args), jitted(*args)) + + def equation_format(x, y): return torch.einsum('i,j->ij', (x, y)) - traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5))) - script = torch.jit.script(outer) - x, y = torch.randn(10), torch.randn(2) - for fn in [traced, script]: - self.assertGraphContains(fn.graph, kind='aten::einsum') - self.assertEqual(fn(x, y), outer(x, y)) + def sublist_format(x, y): + return torch.einsum(x, [0], y, [1], [0, 1]) + + # Sublist format cannot be scripted because it is + # a NumPy API only feature + with self.assertRaises(RuntimeError): + torch.jit.script(sublist_format) + + x = make_tensor((5,), 'cpu', torch.float32) + y = make_tensor((10,), 'cpu', torch.float32) + + check(equation_format, torch.jit.script(equation_format), x, y) + check(equation_format, torch.jit.trace(equation_format, (x, y)), x, y) + check(sublist_format, torch.jit.trace(sublist_format, (x, y)), x, y) def test_python_ivalue(self): # Test if pure python object can be hold as IValue and conversion @@ -15337,6 +15505,7 @@ def forward(self, x): # jit doesn't support sparse tensors. 'test_to_sparse', + 'test_to_sparse_dim', } EXCLUDE_TYPE_CHECK = { diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index a1b42bfcc06cb0..26e30d33f85411 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1,9 +1,11 @@ import operator import unittest import contextlib +import math import torch import torch.nn.functional as F from torch.testing import FileCheck +from typing import List # these needs to be set before `common_utils` # infers `GRAPH_EXECUTOR`. @@ -15,10 +17,13 @@ torch._C._jit_set_profiling_mode(True) from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ - enable_profiling_mode_for_profiling_tests + enable_profiling_mode_for_profiling_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase, \ RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests + from textwrap import dedent from itertools import product, permutations @@ -1267,6 +1272,7 @@ def apply(fn): torch.trunc, torch.frac, F.hardshrink, + F.leaky_relu, lambda x: torch.threshold(x, 0, -10), lambda x: torch.clamp(x, -10, 10), ] @@ -1803,5 +1809,194 @@ def eager(x, y, z: int): z = 2 script = self.checkScript(eager, (x, y, z)) + def test_dynamic_cat(self): + with inline_fusion_groups(): + @torch.jit.script + def repro(xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor]): + return [ + torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1) + for x, y, z in zip(xs, ys, zs) + ] + for _ in range(3): + N = 3 + xs = [torch.ones(21) for _ in range(N)] + # Note: concat of ys and zs will have the same size for each + # pair, even though the individual ys and zs do not. + ys = [torch.ones(N - i) for i in range(N)] + zs = [torch.ones(i) for i in range(N)] + repro(xs, ys, zs) + + +works_list = [ + '__radd__', + '__rdiv__', + '__rmul__', + 'abs', + 'acos', + 'add', + 'addcmul', + 'asin', + 'atan2', + 'atan', + 'ceil', + 'clamp', + 'clamp.scalar', + 'cos', + 'cosh', + 'div.no_rounding_mode', + 'div.true_rounding', + 'eq', + 'erf', + 'erfc', + 'exp', + 'expand', + 'expm1', + 'floor', + 'ge', + 'gt', + 'le', + 'lerp', + 'lgamma', + 'log10', + 'log1p', + 'log2', + 'log', + 'lt', + 'masked_fill', + 'max.binary', + 'min.binary', + 'mm', + 'mul', + 'ne', + 'neg', + 'nn.functional.gelu', + 'nn.functional.hardshrink', + 'nn.functional.hardswish', + 'nn.functional.hardtanh', + 'nn.functional.leaky_relu', + 'pow', + 'reciprocal', + 'round', + 'rsqrt', + 'sigmoid', + 'sin', + 'sinh', + 'sqrt', + 'sub', + 'sum', + 'tan', + 'tanh', + 'transpose', + 'trunc', + 'unsqueeze', +] + +known_failures = [ + 'matmul', + 'permute', + 'frac', + '__rmatmul__' +] + +# If your OpInfo test causes this test to fail, add it here +skip_ops = [ +] + +def get_name(op): + l = [op.name] + if op.variant_test_name != '': + l.append(op.variant_test_name) + return '.'.join(l) + +class TestNNCOpInfo(TestCase): + def te_compile(self, device, dtype, op): + # If adding new OpInfo tests cause this test to fail, add it into here + skip_ops = [] + if op.name in skip_ops: + return + sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) + is_compiling = False + for sample_input in sample_inputs_itr: + arg_values = [sample_input.input] + list(sample_input.args) + kwarg_values = sample_input.kwargs + param_names = [] + param_values = [] + fx_args = [] + for idx, v in enumerate(arg_values): + if isinstance(v, torch.Tensor): + param_names.append(f"arg_{idx}") + param_values.append(v) + fx_args.append(param_names[-1]) + else: + fx_args.append(f'{repr(v)}') + + for k, v in kwarg_values.items(): + if isinstance(v, torch.Tensor): + param_names.append(k) + param_values.append(v) + fx_args.append(f'{k} = {k}') + else: + fx_args.append(f'{k} = {repr(v)}') + + code = f""" +def f({', '.join(param_names)}): + return op.op({', '.join(fx_args)})""" + g = {'torch': torch, 'inf' : math.inf, 'op': op} + exec(code, g) + f = g['f'] + f.__module__ = 'test' + out = f(*param_values) + + # NNC currently oftens segfault when asked to lower ops with 0-dim tensor outputs + if isinstance(out, torch.Tensor) and out.dim() == 0: + continue + else: + is_compiling = True + + ts_g = torch.jit.trace(f, param_values) + kernel = torch._C._te.TensorExprKernel(ts_g.graph) + correct_val = f(*param_values) + self.assertEqual(kernel.run(tuple(param_values)), correct_val) + self.assertEqual(kernel.fallback(tuple(param_values)), correct_val) + + # If all sample inputs have scalar output, we won't have tested it and + # we consider the op to be not working + if not is_compiling: + raise RuntimeError("Skipped all inputs") + + @onlyCPU + @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") + @ops([op for op in op_db if get_name(op) in works_list], allowed_dtypes=(torch.float,)) + def test_working(self, device, dtype, op): + self.te_compile(device, dtype, op) + + @onlyCPU + @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") + @ops([op for op in op_db if get_name(op) in known_failures], allowed_dtypes=(torch.float,)) + def test_failures(self, device, dtype, op): + try: + self.te_compile(device, dtype, op) + except Exception as e: + pass + else: + raise RuntimeError("Expected test to fail. If it now works, move op into works_list") + + @onlyCPU + @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") + @ops([op for op in op_db if get_name(op) not in works_list + known_failures], allowed_dtypes=(torch.float,)) + def test_unsupported(self, device, dtype, op): + if get_name(op) in skip_ops: + return + try: + self.te_compile(device, dtype, op) + except Exception as e: + pass + else: + raise RuntimeError("Expected test to fail. If it now works, move op into works_list") + + +only_for = ("cpu", "cuda") +instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) + if __name__ == '__main__': run_tests() diff --git a/test/test_linalg.py b/test/test_linalg.py index a0651e124ab14b..0fdd82218be188 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -33,6 +33,14 @@ import scipy class TestLinalg(TestCase): + def setUp(self): + super(self.__class__, self).setUp() + torch.backends.cuda.matmul.allow_tf32 = False + + def tearDown(self): + torch.backends.cuda.matmul.allow_tf32 = True + super(self.__class__, self).tearDown() + exact_dtype = True @dtypes(torch.float, torch.cfloat) @@ -4445,23 +4453,24 @@ def test_qr_error_cases(self, device, dtype): @dtypes(torch.double, torch.cdouble) def test_einsum(self, device, dtype): - def check(equation, *operands): - ref = np.einsum(equation, *[operand.cpu().numpy() for operand in operands]) - res = torch.einsum(equation, operands) - self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref))) + def check(*args): + np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args] + ref = np.einsum(*np_args) + res = torch.einsum(*args) + self.assertEqual(torch.from_numpy(np.array(ref)), res) # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f - x = torch.rand(5, device=device, dtype=dtype) - y = torch.rand(7, device=device, dtype=dtype) - A = torch.randn(3, 5, device=device, dtype=dtype) - B = torch.randn(2, 5, device=device, dtype=dtype) - C = torch.randn(2, 3, 5, device=device, dtype=dtype) - D = torch.randn(2, 5, 7, device=device, dtype=dtype) - E = torch.randn(7, 9, device=device, dtype=dtype) - F = torch.randn(2, 3, 3, 5, device=device, dtype=dtype) - G = torch.randn(5, 4, 6, device=device, dtype=dtype) - H = torch.randn(4, 4, device=device, dtype=dtype) - I = torch.rand(2, 3, 2, device=device, dtype=dtype) + x = make_tensor((5,), device, dtype) + y = make_tensor((7,), device, dtype) + A = make_tensor((3, 5), device, dtype) + B = make_tensor((2, 5), device, dtype) + C = make_tensor((2, 3, 5), device, dtype) + D = make_tensor((2, 5, 7), device, dtype) + E = make_tensor((7, 9), device, dtype) + F = make_tensor((2, 3, 3, 5), device, dtype) + G = make_tensor((5, 4, 6), device, dtype) + H = make_tensor((4, 4), device, dtype) + I = make_tensor((2, 3, 2), device, dtype) # Vector operations check('i->', x) # sum @@ -4492,20 +4501,20 @@ def check(equation, *operands): check("ii", H) # trace check("ii->i", H) # diagonal check('iji->j', I) # non-contiguous trace - check('ngrg...->nrg...', torch.rand((2, 1, 3, 1, 4), device=device, dtype=dtype)) + check('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), device, dtype)) # Test ellipsis check("i...->...", H) check("ki,...k->i...", A.t(), B) check("k...,jk->...", A.t(), B) check('...ik, ...j -> ...ij', C, x) - check('Bik,k...j->i...j', C, torch.rand(5, 3, device=device, dtype=dtype)) - check('i...j, ij... -> ...ij', C, torch.rand(2, 5, 2, 3, device=device, dtype=dtype)) + check('Bik,k...j->i...j', C, make_tensor((5, 3), device, dtype)) + check('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), device, dtype)) # torch.bilinear with noncontiguous tensors - l = torch.randn(10, 5, device=device, dtype=dtype).transpose(0, 1) - r = torch.randn(20, 5, device=device, dtype=dtype).transpose(0, 1) - w = torch.randn(15, 10, 20, device=device, dtype=dtype) + l = make_tensor((5, 10), device, dtype, noncontiguous=True) + r = make_tensor((5, 20), device, dtype, noncontiguous=True) + w = make_tensor((15, 10, 20), device, dtype) check("bn,anm,bm->ba", l, w, r) # with strided tensors @@ -4545,7 +4554,7 @@ def check(equation, *operands): labels.insert(ell_index, "...") equation += ''.join(labels) + ',' - ops.append(torch.rand(sizes, device=device, dtype=dtype)) + ops.append(make_tensor(sizes, device, dtype)) equation = equation[:-1] # Test with implicit output @@ -4563,8 +4572,8 @@ def check(equation, *operands): def test_einsum_corner_cases(self, device): def check(equation, *operands, expected_output): - tensors = [torch.tensor(operand, dtype=torch.float32, device=device) if not isinstance(operand, tuple) - else torch.rand(operand, dtype=torch.float32, device=device) for operand in operands] + tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple) + else make_tensor(operand, device, torch.float32) for operand in operands] output = torch.einsum(equation, tensors) self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) @@ -4602,33 +4611,68 @@ def check(equation, *operands, expected_output): check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]]) def test_einsum_error_cases(self, device): - def check(equation, operands, regex, exception=RuntimeError): - with self.assertRaisesRegex(exception, r'einsum\(\): ' + regex): - torch.einsum(equation, operands) - - x = torch.rand(2) - y = torch.rand(2, 3) - - check('', [], r'must provide at least one operand') - check('. ..', [x], r'found \'.\' for operand 0 that is not part of any ellipsis') - check('... ...', [x], r'found \'.\' for operand 0 for which an ellipsis was already found') - check('1', [x], r'operand subscript must be in \[a-zA-Z\] but found 1 for operand 0') - check(',', [x], r'fewer operands were provided than specified in the equation') - check('', [x, x], r'more operands were provided than specified in the equation') - check('', [x], r'the number of subscripts in the equation \(0\) does not match the number ' - r'of dimensions \(1\) for operand 0 and no ellipsis was given') - check('ai', [x], r'the number of subscripts in the equation \(2\) does not match the number ' - r'of dimensions \(1\) for operand 0 and no ellipsis was given') - check('ai...', [x], r'the number of subscripts in the equation \(2\) is more than the number ' - r'of dimensions \(1\) for operand 0') - check('a->... .', [x], r'found \'.\' for output but an ellipsis \(...\) was already found') - check('a->..', [x], r'found \'.\' for output that is not part of any ellipsis \(...\)') - check('a->1', [x], r'subscripts must be in \[a-zA-Z\] but found 1 for the output') - check('a->aa', [x], r'output subscript a appears more than once in the output') - check('a->i', [x], r'output subscript i does not appear in the equation for any input operand') - check('aa', [y], r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') - check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: ' - r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') + def check(*args, regex, exception=RuntimeError): + with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex): + torch.einsum(*args) + + x = make_tensor((2,), device, torch.float32) + y = make_tensor((2, 3), device, torch.float32) + + check('', [], regex=r'at least one operand', exception=ValueError) + check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis') + check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found') + check('1', [x], regex=r'invalid subscript given at index 0') + check(',', [x], regex=r'fewer operands were provided than specified in the equation') + check('', [x, x], regex=r'more operands were provided than specified in the equation') + check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number ' + r'of dimensions \(1\) for operand 0 and no ellipsis was given') + check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number ' + r'of dimensions \(1\) for operand 0 and no ellipsis was given') + check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number ' + r'of dimensions \(1\) for operand 0') + check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found') + check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)') + check('a->1', [x], regex=r'invalid subscript given at index 3') + check('a->aa', [x], regex=r'output subscript a appears more than once in the output') + check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand') + check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') + check('a, ba', [x, y], regex=r'operands do not broadcast with remapped shapes \[original->remapped\]: ' + r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') + + check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError) + check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError) + + @dtypes(torch.double, torch.cdouble) + def test_einsum_sublist_format(self, device, dtype): + def check(*args): + np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args] + ref = np.einsum(*np_args) + res = torch.einsum(*args) + self.assertEqual(torch.from_numpy(np.array(ref)), res) + + x = make_tensor((5,), device, dtype) + y = make_tensor((7,), device, dtype) + A = make_tensor((3, 5), device, dtype) + B = make_tensor((2, 5), device, dtype) + C = make_tensor((2, 1, 3, 1, 4), device, dtype) + + check(x, [0]) + check(x, [0], []) + check(x, [0], y, [1], [0, 1]) + check(A, [0, 1], [1, 0]) + check(A, [0, 1], x, [1], [0]) + check(A, [0, 1], B, [2, 1]) + check(A, [0, 1], B, [2, 1], [0, 2]) + check(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis]) + check(A.t(), [0, 1], B, [Ellipsis, 0]) + check(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis]) + check(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis]) + + # torch.bilinear with noncontiguous tensors + l = make_tensor((5, 10), device, dtype, noncontiguous=True) + r = make_tensor((5, 20), device, dtype, noncontiguous=True) + w = make_tensor((15, 10, 20), device, dtype) + check(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2]) def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, device, dtype): diff --git a/test/test_module_init.py b/test/test_module_init.py index 64846ae20ea15f..80bb7c7c6662fe 100644 --- a/test/test_module_init.py +++ b/test/test_module_init.py @@ -104,6 +104,7 @@ def build_constructor_arg_db(): torch.nn.MaxUnpool1d: ((5,), {}), torch.nn.MaxUnpool2d: ((5,), {}), torch.nn.MaxUnpool3d: ((5,), {}), + torch.nn.Mish: ((), {}), torch.nn.ModuleDict: ((), {}), torch.nn.ModuleList: ((), {}), torch.nn.MultiLabelMarginLoss: ((), {}), diff --git a/test/test_nn.py b/test/test_nn.py index 88e0c669ac5dbb..f43e9af0a97a91 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5504,6 +5504,7 @@ def test_Conv2d_groups_nobias(self): # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16 # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686 # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 + @skipIfRocm def test_Conv2d_groups_nobias_v2(self): torch.manual_seed(123) dev_dtypes = [("cpu", torch.float)] @@ -12866,55 +12867,56 @@ def test_linear_empty(self, device): inp = torch.randn(0, 7, device=device) self._test_module_empty_input(mod, inp) - def test_one_hot(self, device): + @dtypes(torch.uint8, torch.long, torch.bool) + def test_one_hot(self, device, dtype): if self.device_type != 'cuda': # cuda throws device assert for invalid data with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) + torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device, dtype=dtype), -1) with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3) + torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device, dtype=dtype), 3) - t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device)) + t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), dtype=dtype) expected = torch.tensor([[0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], - [1, 0, 0, 0, 0]], device=device) + [1, 0, 0, 0, 0]], device=device, dtype=dtype) self.assertEqual(t, expected) - t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1) + t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1, dtype=dtype) expected = torch.tensor([[0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], - [1, 0, 0, 0, 0]], device=device) + [1, 0, 0, 0, 0]], device=device, dtype=dtype) self.assertEqual(t, expected) - t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6) + t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6, dtype=dtype) expected = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 0]], device=device) + [1, 0, 0, 0, 0, 0]], device=device, dtype=dtype) self.assertEqual(t, expected) - t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device)) + t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device), dtype=dtype) expected = torch.tensor([[[0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [[0, 1, 0, 0, 0], - [1, 0, 0, 0, 0]]], device=device) + [1, 0, 0, 0, 0]]], device=device, dtype=dtype) self.assertEqual(t, expected) - t = torch.nn.functional.one_hot(torch.tensor(4, device=device)) - expected = torch.tensor([0, 0, 0, 0, 1], device=device) + t = torch.nn.functional.one_hot(torch.tensor(4, device=device), dtype=dtype) + expected = torch.tensor([0, 0, 0, 0, 1], device=device, dtype=dtype) self.assertEqual(t, expected) - t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100) - expected = torch.empty([4, 0, 100], dtype=torch.long) + t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100, dtype=dtype) + expected = torch.empty([4, 0, 100], dtype=dtype) self.assertEqual(t, expected) with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device)) + torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), dtype=dtype) with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2) + torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2, dtype=dtype) def test_nn_scalars(self, device): # One off tests to ensure scalars from nn.yaml are properly applied @@ -16082,6 +16084,12 @@ def test_silu_inplace_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.silu(x, inplace=True) + @onlyOnCPUAndCUDA + def test_mish_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.mish(x, inplace=True) + def test_softplus_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -16222,8 +16230,7 @@ def forward(self, x): # Test meta module instantiation. input = torch.randn(5, 10, device=device, dtype=dtype) m = MyModule(10, 1, device='meta', dtype=dtype) - with self.assertRaises(NotImplementedError): - m(input) + m(input) # Test materializing meta module on a real device. m.to_empty(device=device) @@ -16234,8 +16241,19 @@ def forward(self, x): # Test creating meta module from materialized module. m.to_empty(device='meta') - with self.assertRaises(NotImplementedError): - m(input) + m(input) + + @skipMeta + def test_skip_init(self, device): + torch.manual_seed(1) + m_initialized = torch.nn.Linear(5, 1) + m_initialized.to(device) + + torch.manual_seed(1) + m_uninitialized = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device=device) + + self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device) + self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight)) class TestModuleGlobalHooks(TestCase): diff --git a/test/test_ops.py b/test/test_ops.py index f36fc8bd514ce7..ea160b65ae1b88 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -321,7 +321,10 @@ def _test_consistency_helper(samples, variants): cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input if variant in inplace_ops and sample.broadcasts_input: - with self.assertRaises(RuntimeError): + with self.assertRaises(RuntimeError, + msg=('inplace variant either incorrectly allowed ' + 'resizing or you have marked the sample {}' + ' incorrectly with `broadcasts_self=True'.format(sample.summary()))): variant_forward = variant(cloned, *sample.args, **sample.kwargs) diff --git a/test/test_overrides.py b/test/test_overrides.py index eaea915e381b6d..cf18420f88491b 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -593,7 +593,7 @@ def instance_gen(): func_args.append(None) elif t == 'ScalarType': func_args.append(torch.float32) - elif t == 'std::string': + elif t == 'c10::string_view': func_args.append('') else: raise RuntimeError(f"Unsupported argument type {t} for {arg['name']} of function {func}") diff --git a/test/test_quantization.py b/test/test_quantization.py index d30854c027f833..68ea8b1618d5a7 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -23,9 +23,6 @@ from quantization.test_quantized_module import TestStaticQuantizedModule # noqa: F401 from quantization.test_quantized_module import TestDynamicQuantizedModule # noqa: F401 -# Quantization Aware Training -from quantization.test_qat_module import TestQATModule # noqa: F401 - # Quantization specific fusion passes from quantization.test_fusion_passes import TestFusionPasses # noqa: F401 @@ -41,19 +38,23 @@ from quantization.test_workflow_module import TestDistributed # noqa: F401 # Workflow -# 1. Eager mode quantization -from quantization.test_quantize import TestPostTrainingStatic # noqa: F401 -from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401 -from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401 -from quantization.test_quantize import TestEagerModeOps # noqa: F401 -from quantization.test_quantize import TestEagerModeQATOps # noqa: F401 - -# TODO: merge with other tests in test_quantize.py? -from quantization.test_quantize import TestFunctionalModule # noqa: F401 -from quantization.test_quantize import TestFusion # noqa: F401 -from quantization.test_quantize import TestModelNumerics # noqa: F401 -from quantization.test_quantize import TestQuantizeONNXExport # noqa: F401 -from quantization.test_quantize import TestDeprecatedJitQuantized # noqa: F401 +# 1. Eager mode post training quantization +from quantization.test_quantize_eager_ptq import TestPostTrainingStatic # noqa: F401 +from quantization.test_quantize_eager_ptq import TestPostTrainingDynamic # noqa: F401 +from quantization.test_quantize_eager_ptq import TestEagerModeActivationOps # noqa: F401 +from quantization.test_quantize_eager_ptq import TestFunctionalModule # noqa: F401 +from quantization.test_quantize_eager_ptq import TestModelNumerics # noqa: F401 +from quantization.test_quantize_eager_ptq import TestQuantizeONNXExport # noqa: F401 + +# Eager mode quantization aware training +from quantization.test_quantize_eager_qat import TestQuantizationAwareTraining # noqa: F401 +from quantization.test_quantize_eager_qat import TestQATActivationOps # noqa: F401 +from quantization.test_quantize_eager_qat import TestConvBNQATModule # noqa: F401 + +# Eager mode fusion passes +from quantization.test_fusion import TestFusion # noqa: F401 + +from quantization.test_deprecated_jit_quant import TestDeprecatedJitQuantized # noqa: F401 # 2. Graph mode quantization from quantization.test_quantize_jit import TestQuantizeJit # noqa: F401 diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 6b8c010024d5f6..ce858b6263b74a 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -57,6 +57,14 @@ def test_sort(self, device): x = torch.rand(4, SIZE, device=device) res1val, res1ind = torch.sort(x) + # Test inplace + y = x.clone() + y_inds = torch.tensor((), dtype=torch.int64, device=device) + torch.sort(y, out=(y, y_inds)) + x_vals, x_inds = torch.sort(x) + self.assertEqual(x_vals, y) + self.assertEqual(x_inds, y_inds) + # Test use of result tensor res2val = torch.tensor((), device=device) res2ind = torch.tensor((), device=device, dtype=torch.long) diff --git a/test/test_sparse.py b/test/test_sparse.py index c201704cdf5fa6..5b9b873fe646d2 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -246,6 +246,20 @@ def test_sparse_sum(): ref = test_sparse_sum() self.assertTrue(ref.expired()) + @dtypes(torch.double) + def test_ctor_large_sizes(self, device, dtype): + # Test that integer overflow is detected when computing numel + # of a sparse tensor with large dimensions (gh-57416). Notice + # that numel is computed internally when constructing a + # tensor, hence the overflow may appear during the tensor + # construction step. + N = 100000 + indices = torch.tensor([[N, N - 1]] * 4, dtype=torch.int64, device=device) + values = torch.tensor([1, 2], dtype=dtype, device=device) + self.assertRaises(RuntimeError, + lambda: torch.sparse_coo_tensor( + indices, values, (N + 1,) * 4, device=device)) + @dtypes(torch.double, torch.cdouble) def test_ctor_size_checks(self, device, dtype): indices = self.index_tensor([ diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 949b5fc44d6e5b..4c80e922d249b6 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -1,6 +1,9 @@ import torch import warnings -from torch.testing._internal.common_utils import TestCase, run_tests, load_tests, coalescedonoff +import unittest +import random +from torch.testing._internal.common_utils import \ + (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, onlyCPU) @@ -81,7 +84,10 @@ def test_factory_size_check(self, device, dtype): size, dtype=dtype, device=device) @onlyCPU + @unittest.skip("see: https://github.com/pytorch/pytorch/issues/58762") def test_sparse_csr_print(self, device): + orig_maxDiff = self.maxDiff + self.maxDiff = None shape_nnz = [ ((10, 10), 10), ((100, 10), 10), @@ -112,6 +118,7 @@ def test_sparse_csr_print(self, device): printed.append('') printed.append('') self.assertExpected('\n'.join(printed)) + self.maxDiff = orig_maxDiff @onlyCPU def test_sparse_csr_from_dense(self, device): @@ -185,6 +192,7 @@ def test_coo_to_csr_convert(self, device, dtype, coalesced): self.assertEqual(coo.matmul(vec), csr.matmul(vec)) @onlyCPU + @unittest.skipIf(IS_MACOS or IS_WINDOWS, "MKL doesn't work on windows or mac") @dtypes(torch.float, torch.double) def test_mkl_matvec_warnings(self, device, dtype): if torch.has_mkl: @@ -226,6 +234,34 @@ def test_csr_matvec(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "mv: expected"): csr.matmul(bad_vec) + @onlyCPU + @dtypes(torch.double) + def test_mm(self, device, dtype): + def test_shape(di, dj, dk, nnz): + x = self.genSparseCSRTensor((di, dj), nnz, device=device, dtype=dtype, index_dtype=torch.int32) + t = torch.randn(di, dk, dtype=dtype, device=device) + y = torch.randn(dj, dk, dtype=dtype, device=device) + alpha = random.random() + beta = random.random() + + # res = beta * t + alpha * (x @ y) + res = torch.addmm(t, x, y, beta=beta, alpha=alpha) + expected = torch.addmm(t, x.to_dense(), y, beta=beta, alpha=alpha) + self.assertEqual(res, expected) + + res = torch.addmm(t, x, y) + expected = torch.addmm(t, x.to_dense(), y) + self.assertEqual(res, expected) + + res = torch.mm(x, y) + expected = torch.mm(x.to_dense(), y) + self.assertEqual(res, expected) + + for i in range(2, 5): + for j in range(2, 8): + for k in range(2, 8): + test_shape(i, j, k, i * j // 2) + @onlyCPU @dtypes(*torch.testing.floating_types()) def test_coo_csr_conversion(self, device, dtype): diff --git a/test/test_torch.py b/test/test_torch.py index c063fe2e20c136..7f277d2a130a77 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2398,8 +2398,11 @@ def test_has_internal_overlap(self): # Check for zero strided, size 1 axis, in non-contiguous storage (gh-33812) c = torch.randn(10).as_strided([2, 1, 5], [1, 0, 2]) + self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_NO) + c = torch.randn(2, 1, 10)[::2].as_strided((2, 1, 5), (10, 0, 2)) self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_TOO_HARD) + def test_allow_tensor_metadata_change(self): def do_test(t): with self.assertRaisesRegex( @@ -3946,10 +3949,10 @@ def backward_func(slf, device): def test_nondeterministic_alert_scatter_add(self, device): def test_func(op_call): - input = torch.randn(10, device=device) + input = torch.randn(5, 4, device=device) dim = 0 - index = torch.tensor([3], device=device) - src = torch.randn(1, device=device) + index = torch.tensor([[3]], device=device) + src = torch.tensor([[1.0]], device=device) @expectedAlertNondeterministic('scatter_add_cuda_kernel', 'cuda') def forward_func(slf, device): @@ -4166,6 +4169,24 @@ def test_gather_backward_deterministic_path(self, device) -> None: def test_gather_backward_one_dim(self, device) -> None: self._test_gather_backward_one_dim(device, False) + @onlyOnCPUAndCUDA + def test_scatter_add_one_dim_deterministic(self, device) -> None: + with DeterministicGuard(True): + m = random.randint(20, 30) + elems = random.randint(2000 * m, 3000 * m) + dim = 0 + src = torch.randn(elems, device=device) + idx = torch.randint(m, (elems,), device=device) + + x = torch.zeros(m, device=device) + res = x.scatter_add(dim, idx, src) + + expected = torch.zeros(m, device=device) + for i in range(elems): + expected[idx[i]] += src[i] + + self.assertEqual(res, expected, atol=0, rtol=0) + @dtypes(*torch.testing.get_all_fp_dtypes()) def test_log_normal(self, device, dtype): a = torch.tensor([10], dtype=dtype, device=device).log_normal_() @@ -4180,15 +4201,30 @@ def test_geometric(self, device, dtype): def test_repeat_interleave(self, device): y = torch.tensor([[1, 2], [3, 4]], device=device) + # exercise single argument function signature + temp = y.repeat_interleave(2) + self.assertEqual(torch.Size([8]), temp.size()) + for dtype in [torch.int, torch.long]: + lengths = torch.tensor([1, 2], dtype=dtype, device=device) + output_size = torch.sum(lengths) a = torch.repeat_interleave( y, - torch.tensor([1, 2], dtype=dtype, device=device), + lengths, dim=0, ) self.assertEqual(a.dtype, y.dtype) self.assertEqual(a.size(), torch.Size([3, 2])) + a_with_output = torch.repeat_interleave( + y, + lengths, + dim=0, + output_size=output_size, + ) + self.assertEqual(a_with_output.dtype, y.dtype) + self.assertEqual(a_with_output.size(), torch.Size([3, 2])) + @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False))) def test_bernoulli_p(self, device, dtype): @@ -5679,11 +5715,21 @@ def test_masked_scatter(self, device, dtype): with self.assertRaises(RuntimeError): dest.masked_scatter_(mask, src) + # empty tensor + dest = torch.empty((5, 0, 5), dtype=dt, device=device) + mask = torch.ones_like(dest, dtype=maskType, device=device) + src = torch.empty((0,), dtype=dt, device=device) + dest.masked_scatter_(mask, src) + + dest = torch.empty((5, 0, 5), dtype=dt, device=device) + mask = torch.ones((5, 1, 5), dtype=maskType, device=device) + src = torch.empty((0,), dtype=dt, device=device) + dest.masked_scatter_(mask, src) if self.device_type != 'cuda': - self.assertEqual(len(w), 3) + self.assertEqual(len(w), 5) else: - self.assertEqual(len(w), 2) + self.assertEqual(len(w), 4) warn = 'masked_scatter_ received a mask with dtype torch.uint8,' for wi in w: @@ -8021,9 +8067,6 @@ def tmp(dtype, device): ('size', '', _new_t((1, 2, 3, 4)), lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('size', 'dim', _new_t((1, 2, 3, 4)), lambda t, d: [1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('size', 'neg_dim', _new_t((1, 2, 3, 4)), lambda t, d: [-2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('split', '', _small_3d, lambda t, d: [2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('split', 'dim', _small_3d, lambda t, d: [2, 1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('split', 'neg_dim', _small_3d, lambda t, d: [2, -3], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('t', '', _new_t((1, 2)), lambda t, d: [],), ('take', '', _new_t((3, 4)), lambda t, d: [torch.LongTensor([[0], [-2]]).to(device=d)], @@ -8201,6 +8244,10 @@ class TestTensorDeviceOps(TestCase): class TestTorch(AbstractTestCases._TestTorchMixin): exact_dtype = True + def test_tensor_ctor_scalar(self): + x = torch.Tensor(torch.tensor(1.0)) + self.assertEqual(x, torch.tensor(1.0)) + def test_deepcopy_gradient(self): from copy import deepcopy a = torch.zeros(10) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 02dc4919449246..29f62c234e04d6 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1018,6 +1018,35 @@ def test_silu(self, device, dtype): input_noncontig, inplace=True), expected_output_noncontig, atol=atol, rtol=rtol) + @skipIfNoSciPy + @dtypes(torch.float, torch.double) + def test_mish(self, device, dtype): + input_np = np.random.randn(5, 8) + special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] + input_np = np.concatenate((input_np, special_input), axis=0).astype( + torch_to_numpy_dtype_dict[dtype]) + expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np))) + + expected_output = torch.from_numpy(expected_output_np).to(device) + expected_output_noncontig = expected_output.transpose(0, 1) + + atol = 1e-6 + rtol = 1e-6 + + input = torch.from_numpy(input_np).clone().contiguous().to(device) + self.assertEqual(torch.nn.functional.mish(input), expected_output, + atol=atol, rtol=rtol) + self.assertEqual(torch.nn.functional.mish(input, inplace=True), + expected_output, atol=atol, rtol=rtol) + + input = torch.from_numpy(input_np).clone().to(device) + input_noncontig = input.transpose(0, 1) + self.assertEqual(torch.nn.functional.mish(input_noncontig), + expected_output_noncontig, atol=atol, rtol=rtol) + self.assertEqual(torch.nn.functional.mish( + input_noncontig, inplace=True), expected_output_noncontig, + atol=atol, rtol=rtol) + # do ops like threshold need a test_unary(_nonufunc) test suite? @onlyCPU @dtypes(*torch.testing.get_all_math_dtypes('cpu')) diff --git a/third_party/fbgemm b/third_party/fbgemm index 7794b2950b35dd..5bbb70d35dc9b8 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 7794b2950b35ddfa7426091e7fb2f991b1407557 +Subproject commit 5bbb70d35dc9b810f426eec3848bdceb301e5afe diff --git a/third_party/gloo b/third_party/gloo index 6f7095f6e9860c..c22a5cfba94edf 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit 6f7095f6e9860ce4fd682a7894042e6eba0996f1 +Subproject commit c22a5cfba94edf8ea4f53a174d38aa0c629d070f diff --git a/third_party/nccl/nccl b/third_party/nccl/nccl index 033d799524fb97..ca8485b0d01ca6 160000 --- a/third_party/nccl/nccl +++ b/third_party/nccl/nccl @@ -1 +1 @@ -Subproject commit 033d799524fb97629af5ac2f609de367472b2696 +Subproject commit ca8485b0d01ca6dfa02f4454932011e68b461175 diff --git a/tools/actions_local_runner.py b/tools/actions_local_runner.py index 9e298138866955..369f7138d59a6b 100755 --- a/tools/actions_local_runner.py +++ b/tools/actions_local_runner.py @@ -13,7 +13,7 @@ import shlex import configparser -from typing import List, Dict, Any, Optional, Tuple, Union +from typing import List, Dict, Any, Optional, Union, NamedTuple, Set REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -87,20 +87,28 @@ def find_changed_files() -> List[str]: def print_results(job_name: str, passed: bool, streams: List[str]) -> None: - header(job_name, passed) + icon = color(col.GREEN, "✓") if passed else color(col.RED, "x") + print(f"{icon} {color(col.BLUE, job_name)}") + for stream in streams: stream = stream.strip() if stream != "": print(stream) +class CommandResult(NamedTuple): + passed: bool + stdout: str + stderr: str + + async def shell_cmd( cmd: Union[str, List[str]], env: Optional[Dict[str, Any]] = None, redirect: bool = True, -) -> Tuple[bool, str, str]: +) -> CommandResult: if isinstance(cmd, list): - cmd_str = shlex.join(cmd) # type: ignore[attr-defined] + cmd_str = " ".join(shlex.quote(arg) for arg in cmd) else: cmd_str = cmd @@ -117,140 +125,191 @@ async def shell_cmd( passed = proc.returncode == 0 if not redirect: - return passed, "", "" + return CommandResult(passed, "", "") + + return CommandResult(passed, stdout.decode().strip(), stderr.decode().strip()) + + +class Check: + name: str + + def __init__(self, files: Optional[List[str]], quiet: bool): + self.quiet = quiet + self.files = files + + async def run(self) -> bool: + result = await self.run_helper() + if result is None: + return True - return passed, stdout.decode().strip(), stderr.decode().strip() + streams = [] + if not result.passed: + streams = [ + result.stderr, + result.stdout, + ] + print_results(self.name, result.passed, streams) + return result.passed + async def run_helper(self) -> Optional[CommandResult]: + if self.files is not None: + relevant_files = self.filter_files(self.files) + if len(relevant_files) == 0: + # No files, do nothing + return CommandResult(passed=True, stdout="", stderr="") -def header(name: str, passed: bool) -> None: - PASS = color(col.GREEN, "✓") - FAIL = color(col.RED, "x") - icon = PASS if passed else FAIL - print(f"{icon} {color(col.BLUE, name)}") + return await self.quick(relevant_files) + return await self.full() -def get_flake_excludes() -> List[str]: - config = configparser.ConfigParser() - config.read(os.path.join(REPO_ROOT, ".flake8")) + def filter_ext(self, files: List[str], extensions: Set[str]) -> List[str]: + def passes(filename: str) -> bool: + return os.path.splitext(filename)[1] in extensions - excludes = re.split(r',\s*', config["flake8"]["exclude"].strip()) - excludes = [e.strip() for e in excludes if e.strip() != ""] - return excludes + return [f for f in files if passes(f)] + def filter_files(self, files: List[str]) -> List[str]: + return files -async def run_flake8(files: Optional[List[str]], quiet: bool) -> bool: - cmd = ["flake8"] + async def quick(self, files: List[str]) -> CommandResult: + raise NotImplementedError - excludes = get_flake_excludes() + async def full(self) -> Optional[CommandResult]: + raise NotImplementedError - def should_include(name: str) -> bool: - for exclude in excludes: - if fnmatch.fnmatch(name, pat=exclude): - return False - if name.startswith(exclude) or ("./" + name).startswith(exclude): - return False - return True - if files is not None: - files = [f for f in files if should_include(f)] +class Flake8(Check): + name = "flake8" - if len(files) == 0: - print_results("flake8", True, []) + def filter_files(self, files: List[str]) -> List[str]: + config = configparser.ConfigParser() + config.read(os.path.join(REPO_ROOT, ".flake8")) + + excludes = re.split(r",\s*", config["flake8"]["exclude"].strip()) + excludes = [e.strip() for e in excludes if e.strip() != ""] + + def should_include(name: str) -> bool: + for exclude in excludes: + if fnmatch.fnmatch(name, pat=exclude): + return False + if name.startswith(exclude) or f"./{name}".startswith(exclude): + return False return True - # Running quicklint, pass in an explicit list of files (unlike mypy, - # flake8 will still use .flake8 to filter this list by the 'exclude's - # in the config - cmd += files + files = self.filter_ext(files, {".py"}) + return [f for f in files if should_include(f)] - passed, stdout, stderr = await shell_cmd(cmd) - print_results("flake8", passed, [stdout, stderr]) - return passed + async def quick(self, files: List[str]) -> CommandResult: + return await shell_cmd(["flake8"] + files) + async def full(self) -> CommandResult: + return await shell_cmd(["flake8"]) -async def run_mypy(files: Optional[List[str]], quiet: bool) -> bool: - env = os.environ.copy() - if should_color(): - # Secret env variable: https://github.com/python/mypy/issues/7771 - env["MYPY_FORCE_COLOR"] = "1" - if files is not None: - # Running quick lint, use mypy-wrapper instead so it checks that the files - # actually should be linted +class Mypy(Check): + name = "mypy (skipped typestub generation)" + + def filter_files(self, files: List[str]) -> List[str]: + return self.filter_ext(files, {".py", ".pyi"}) + + def env(self) -> Dict[str, Any]: + env = os.environ.copy() + if should_color(): + # Secret env variable: https://github.com/python/mypy/issues/7771 + env["MYPY_FORCE_COLOR"] = "1" + return env - passed, stdout, stderr = await shell_cmd( - [sys.executable, "tools/mypy_wrapper.py"] + [ - os.path.join(REPO_ROOT, f) for f in files + async def quick(self, files: List[str]) -> CommandResult: + return await shell_cmd( + [sys.executable, "tools/mypy_wrapper.py"] + + [os.path.join(REPO_ROOT, f) for f in files], + env=self.env(), + ) + + async def full(self) -> None: + env = self.env() + # hackily change the name + self.name = "mypy" + + await shell_cmd( + [ + sys.executable, + "tools/actions_local_runner.py", + "--job", + "mypy", + "--file", + ".github/workflows/lint.yml", + "--step", + "Run autogen", ], + redirect=False, env=env, ) - print_results("mypy (skipped typestub generation)", passed, [ - stdout + "\n", - stderr + "\n", - ]) - return passed - - # Not running quicklint, so use lint.yml - _, _, _ = await shell_cmd( - [ - sys.executable, - "tools/actions_local_runner.py", - "--job", - "mypy", - "--file", - ".github/workflows/lint.yml", - "--step", - "Run autogen", - ], - redirect=False, - env=env, - ) - passed, _, _ = await shell_cmd( - [ - sys.executable, - "tools/actions_local_runner.py", - "--job", - "mypy", - "--file", - ".github/workflows/lint.yml", - "--step", - "Run mypy", - ], - redirect=False, - env=env, - ) - return passed + await shell_cmd( + [ + sys.executable, + "tools/actions_local_runner.py", + "--job", + "mypy", + "--file", + ".github/workflows/lint.yml", + "--step", + "Run mypy", + ], + redirect=False, + env=env, + ) -async def run_step( - step: Dict[str, Any], job_name: str, files: Optional[List[str]], quiet: bool -) -> bool: - env = os.environ.copy() - env["GITHUB_WORKSPACE"] = "/tmp" - script = step["run"] +class ShellCheck(Check): + name = "shellcheck: Run ShellCheck" + + def filter_files(self, files: List[str]) -> List[str]: + return self.filter_ext(files, {".sh"}) + + async def quick(self, files: List[str]) -> CommandResult: + return await shell_cmd( + ["tools/run_shellcheck.sh"] + [os.path.join(REPO_ROOT, f) for f in files], + ) + + async def full(self) -> None: + await shell_cmd( + [ + sys.executable, + "tools/actions_local_runner.py", + "--job", + "shellcheck", + "--file", + ".github/workflows/lint.yml", + "--step", + "Run ShellCheck", + ], + redirect=False, + ) - if quiet: - # TODO: Either lint that GHA scripts only use 'set -eux' or make this more - # resilient - script = script.replace("set -eux", "set -eu") - script = re.sub(r"^time ", "", script, flags=re.MULTILINE) - name = f'{job_name}: {step["name"]}' - passed, stderr, stdout = await shell_cmd(script, env=env) - print_results(name, passed, [stdout, stderr]) +class YamlStep(Check): + def __init__(self, step: Dict[str, Any], job_name: str, quiet: bool): + super().__init__(files=None, quiet=quiet) + self.step = step + self.name = f'{job_name}: {self.step["name"]}' - return passed + async def full(self) -> CommandResult: + env = os.environ.copy() + env["GITHUB_WORKSPACE"] = "/tmp" + script = self.step["run"] + if self.quiet: + # TODO: Either lint that GHA scripts only use 'set -eux' or make this more + # resilient + script = script.replace("set -eux", "set -eu") + script = re.sub(r"^time ", "", script, flags=re.MULTILINE) -async def run_steps( - steps: List[Dict[str, Any]], job_name: str, files: Optional[List[str]], quiet: bool -) -> bool: - coros = [run_step(step, job_name, files, quiet) for step in steps] - return all(await asyncio.gather(*coros)) + return await shell_cmd(script, env=env) -def relevant_changed_files(file_filters: Optional[List[str]]) -> Optional[List[str]]: +def changed_files() -> Optional[List[str]]: changed_files: Optional[List[str]] = None try: changed_files = sorted(find_changed_files()) @@ -262,16 +321,7 @@ def relevant_changed_files(file_filters: Optional[List[str]]) -> Optional[List[s ) return None - if file_filters is None: - return changed_files - else: - relevant_files = [] - for f in changed_files: - for file_filter in file_filters: - if f.endswith(file_filter): - relevant_files.append(f) - break - return relevant_files + return changed_files def grab_specific_steps( @@ -285,7 +335,7 @@ def grab_specific_steps( break if len(relevant_steps) != len(steps_to_grab): - raise RuntimeError("Missing steps") + raise RuntimeError(f"Missing steps:\n{relevant_steps}\n{steps_to_grab}") return relevant_steps @@ -295,11 +345,6 @@ def main() -> None: description="Pull shell scripts out of GitHub actions and run them" ) parser.add_argument("--file", help="YAML file with actions") - parser.add_argument( - "--file-filter", - help="only pass through files with this extension", - nargs="*", - ) parser.add_argument( "--changed-only", help="only run on changed files", @@ -313,12 +358,8 @@ def main() -> None: parser.add_argument("--step", action="append", help="steps to run (in order)") args = parser.parse_args() - relevant_files = None quiet = not args.no_quiet - if args.changed_only: - relevant_files = relevant_changed_files(args.file_filter) - if args.file is None: # If there is no .yml file provided, fall back to the list of known # jobs. We use this for flake8 and mypy since they run different @@ -327,7 +368,12 @@ def main() -> None: raise RuntimeError( f"Job {args.job} not found and no .yml file was provided" ) - future = ad_hoc_steps[args.job](relevant_files, quiet) + + files = None + if args.changed_only: + files = changed_files() + + checks = [ad_hoc_steps[args.job](files, quiet)] else: if args.step is None: raise RuntimeError("1+ --steps must be provided") @@ -344,20 +390,21 @@ def main() -> None: # Pull the relevant sections out of the provided .yml file and run them relevant_steps = grab_specific_steps(args.step, job) - future = run_steps(relevant_steps, args.job, relevant_files, quiet) + checks = [ + YamlStep(step=step, job_name=args.job, quiet=quiet) + for step in relevant_steps + ] - if sys.version_info >= (3, 8): - loop = asyncio.get_event_loop() - loop.run_until_complete(future) - else: - raise RuntimeError("Only Python >=3.8 is supported") + loop = asyncio.get_event_loop() + loop.run_until_complete(asyncio.gather(*[check.run() for check in checks])) # These are run differently locally in order to enable quicklint, so dispatch # out to special handlers instead of using lint.yml ad_hoc_steps = { - "mypy": run_mypy, - "flake8-py3": run_flake8, + "mypy": Mypy, + "flake8-py3": Flake8, + "shellcheck": ShellCheck, } if __name__ == "__main__": diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 56f9ffc491d873..b0a056bf08d08f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -182,6 +182,7 @@ - name: acos(Tensor self) -> Tensor self: grad * -((-self * self + 1).rsqrt()).conj() + result: auto_element_wise - name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) @@ -190,26 +191,31 @@ - name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) + result: self_t - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor self: maybe_multiply(grad, beta.conj()) batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2).conj()) * alpha.conj() batch2: batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha.conj() + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha) - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj()) tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj()) + result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value) - name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj()) tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj()) + result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value) - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor self: maybe_multiply(grad, beta.conj()) mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha) mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha) - name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor self: maybe_multiply(grad, beta) @@ -220,20 +226,24 @@ self: maybe_multiply(grad, beta.conj()) mat: grad.ger(vec.conj()) * alpha.conj() vec: mat.t().conj().mv(grad) * alpha.conj() + result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha) - name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor self: maybe_multiply(grad, beta.conj()) vec1: grad.mv(vec2.conj()) * alpha.conj() vec2: grad.t().mv(vec1.conj()) * alpha.conj() + result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha) - name: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor theta: affine_grid_generator_backward(grad, size, align_corners) - name: alias(Tensor(a) self) -> Tensor(a) self: grad + result: self_t - name: angle(Tensor self) -> Tensor self: angle_backward(grad, self) + result: handle_r_to_c(result.scalar_type(), angle_backward(self_t, self_p)) # The four items below are necessary because TensorIterator doesn't work on # Variables (codegen does not unwrap the input Tensor for all() and any() ). @@ -251,18 +261,21 @@ - name: acosh(Tensor self) -> Tensor self: grad * (self.pow(2) - 1).rsqrt().conj() + result: auto_element_wise - name: acosh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of acosh") - name: asinh(Tensor self) -> Tensor self: grad * (self.pow(2) + 1).rsqrt().conj() + result: auto_element_wise - name: asinh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of asinh") - name: atanh(Tensor self) -> Tensor self: grad * 1 / (1 - self.pow(2)).conj() + result: auto_element_wise - name: atanh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of atanh") @@ -272,9 +285,11 @@ - name: asin(Tensor self) -> Tensor self: grad * (-self * self + 1).rsqrt().conj() + result: auto_element_wise - name: atan(Tensor self) -> Tensor self: grad / (self * self + 1).conj() + result: auto_element_wise - name: atan2(Tensor self, Tensor other) -> Tensor self, other: atan2_backward(grad, self, other, grad_input_mask) @@ -362,6 +377,7 @@ - name: _conj(Tensor self) -> Tensor self: grad.conj() + result: self_t.conj() - name: copysign.Tensor(Tensor self, Tensor other) -> Tensor self: copysign_tensor_self_backward(grad, self, result) @@ -1154,7 +1170,7 @@ - name: t(Tensor(a) self) -> Tensor(a) self: grad.t() -- name: one_hot(Tensor self, int num_classes=-1) -> Tensor +- name: one_hot(Tensor self, int num_classes=-1, ScalarType dtype=long) -> Tensor self: non_differentiable - name: flip(Tensor self, int[] dims) -> Tensor @@ -1207,6 +1223,9 @@ - name: to_sparse(Tensor self) -> Tensor self: grad.to_dense() +- name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + self: grad.to_dense() + - name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor self: to_mkldnn_backward(grad, self) @@ -1397,6 +1416,9 @@ - name: silu(Tensor self) -> Tensor self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" +- name: mish(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)" + - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index f93c6456da59c2..4f1cb242ac8bdf 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -240,8 +240,6 @@ BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE), OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE), BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL), - BaseCType(stringT): (GETTER_DEFINITION, GETTER_BODY_STRING), - OptionalCType(BaseCType(stringT)): (GETTER_DEFINITION_OPT, GETTER_BODY_STRING), BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR), OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR), } @@ -369,6 +367,14 @@ def save_var(var: SavedAttribute, is_output: bool) -> None: saved_variables.append(f'{type.cpp_type()} {name} = 0;') getter_definitions.append(GETTER_DEFINITION.substitute( op=info.op, name=name, body=GETTER_BODY_INT64_T)) + elif type == BaseCType(stringT): + saved_variables.append(f'std::string {name};') + getter_definitions.append(GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_STRING)) + elif type == OptionalCType(BaseCType(stringT)): + saved_variables.append(f'c10::optional {name};') + getter_definitions.append(GETTER_DEFINITION_OPT.substitute( + op=info.op, name=name, body=GETTER_BODY_STRING)) else: saved_variables.append(f'{type.cpp_type()} {name};') diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 5be8d9c02ce1bb..3862e6f596159d 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -35,7 +35,7 @@ from tools.codegen.api.types import (Binding, DispatcherSignature, BaseCType, intArrayRefT, tensorT, tensorListT, MutRefCType, OptionalCType, - ListCType, SpecialArgName, scalarT) + ListCType, SpecialArgName, scalarT, stringT) from tools.codegen.api.autograd import ( DifferentiableInput, NativeFunctionWithDifferentiabilityInfo, SavedAttribute, dispatch_strategy, gen_differentiable_outputs, @@ -610,6 +610,10 @@ def save_variables( name += '_' elif type == BaseCType(intArrayRefT): expr = expr + ".vec()" + elif type == BaseCType(stringT): + expr = f'std::string({expr})' + elif type == OptionalCType(BaseCType(stringT)): + expr = f'{expr}.has_value() ? c10::optional(std::string({expr}.value())) : c10::nullopt' guard = guard_for(arg) if guard is None: stmts.append(f'grad_fn->{name} = {expr};') @@ -748,10 +752,10 @@ def emit_fw_derivatives() -> List[str]: # Handle functions like stack # For these, we don't unpack anything and always call the user function if not (len(differentiable_inputs) == 1 and is_tensor_list_type(differentiable_inputs[0].type)): - raise RuntimeError(f'No differentiable input to "{name}" is a differentiable Tensor even though a ' - 'forward gradient formula has been defined for it. This case should only happen ' - 'for function that take a single TensorList as input. All other cases are not ' - 'supported right now.') + raise RuntimeError(f'No differentiable input to "{name}" is a differentiable Tensor (as the provided' + 'forward AD formula does not use any input tangent) even though a forward gradient ' + 'formula has been defined for it. This case should only happen for function that ' + 'take a single TensorList as input. All other cases are not supported right now.') requires_fw_grad = "true" unpacked_arguments = "" for inp in differentiable_inputs: diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 060f2eaefa83dc..821ffbdf818760 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -11,7 +11,8 @@ SavedAttribute, ForwardDerivative) from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType, intArrayRefT, tensorOptionsT, typeAndSizeT, intT, - tensorGeometryT, scalarTypeT, SpecialArgName) + tensorGeometryT, scalarTypeT, SpecialArgName, + OptionalCType, stringT) from tools.codegen.api import cpp from tools.codegen.gen import parse_native_yaml from tools.codegen.context import with_native_function @@ -169,21 +170,23 @@ def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]: "forward definition of gradient as element_wise but it does not " "defines the gradient formula for its argument which is required.") # This transformation is based on the observation that for element-wise functions, the Jacobian - # matrix is diagonal and thus doing J * v or v * J gives the same result. + # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions) + # For the complex case, we use hermitian transpose and get (v.conj() J).conj() # So here we are going to re-use the backward formula and replace two things: - # 1) all occurrences of "grad" with "foo_t", where foo is the name of the unique differentiable input. + # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input. # 2) all usage of an original input "foo" with its primal value "foo_p". + # 3) conjugate the final result # For example, for abs, the backward formula is: # grad * self.sgn() # And this function generates a forward formula that is: - # self_t * self_p.sgn() + # (self_t.conj() * self_p.sgn()).conj() backward_formula = derivatives[0].original_formula input_name = args_with_derivatives[0].name # Do replacement 1) of the grad def repl(m: Any) -> str: - return f"{m.group(1)}{input_name}_t{m.group(2)}" + return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}" fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula) # Do replacement 2) of the input variables @@ -194,6 +197,9 @@ def repl(m: Any) -> str: return f"{m.group(1)}{arg_name}_p{m.group(2)}" fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula) + # Do the final conjugate 3) + fw_formula = f"({fw_formula}).conj()" + # Since there is a single differentiable inputs and we necessarily need its tangent we can # simply require all differentiable input's tangent. required_inputs_tangent = tuple(all_arg_names) @@ -516,6 +522,15 @@ def repl(m: Match[str]) -> str: formula = re.sub(regex.format(name), repl, formula) + # c10::optional types stored in Backward nodes must be + # converted to c10::optional before being passed into + # the backward function + if nctype.type == OptionalCType(BaseCType(stringT)): + formula = re.sub( + rf'\b{name}\b', + f'{name}.has_value() ? c10::optional({name}.value()) : c10::nullopt', + formula) + # Find any variables which remain in the formula and save them if re.search(IDENT_REGEX.format(name), formula): saved.append(SavedAttribute( diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 87068200b8dc2a..5168d211694bfc 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -138,9 +138,11 @@ core_sources_full_mobile = [ "torch/csrc/jit/api/module.cpp", "torch/csrc/jit/api/object.cpp", "torch/csrc/jit/backends/backend_debug_handler.cpp", + "torch/csrc/jit/backends/backend_debug_info.cpp", "torch/csrc/jit/backends/backend_detail.cpp", "torch/csrc/jit/backends/backend_interface.cpp", "torch/csrc/jit/backends/backend_resolver.cpp", + "torch/csrc/jit/backends/generate_debug_handles.cpp", "torch/csrc/jit/codegen/fuser/codegen.cpp", "torch/csrc/jit/codegen/fuser/compiler.cpp", "torch/csrc/jit/codegen/fuser/executor.cpp", @@ -207,9 +209,11 @@ core_sources_full_mobile = [ "torch/csrc/jit/passes/lower_tuples.cpp", "torch/csrc/jit/passes/normalize_ops.cpp", "torch/csrc/jit/passes/peephole_list_idioms.cpp", + "torch/csrc/jit/passes/value_refinement_utils.cpp", "torch/csrc/jit/passes/peephole_alias_sensitive.cpp", "torch/csrc/jit/passes/pass_manager.cpp", "torch/csrc/jit/passes/peephole.cpp", + "torch/csrc/jit/passes/peephole_non_tensor.cpp", "torch/csrc/jit/passes/create_functional_graphs.cpp", "torch/csrc/jit/passes/remove_mutation.cpp", "torch/csrc/jit/passes/prepack_folding.cpp", @@ -222,6 +226,8 @@ core_sources_full_mobile = [ "torch/csrc/jit/passes/remove_dropout.cpp", "torch/csrc/jit/passes/requires_grad_analysis.cpp", "torch/csrc/jit/passes/shape_analysis.cpp", + "torch/csrc/jit/passes/integer_value_refinement.cpp", + "torch/csrc/jit/passes/symbolic_shape_analysis.cpp", "torch/csrc/jit/passes/specialize_autogradzero.cpp", "torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp", "torch/csrc/jit/passes/subgraph_rewrite.cpp", @@ -248,6 +254,7 @@ core_sources_full_mobile = [ "torch/csrc/jit/runtime/logging.cpp", "torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp", "torch/csrc/jit/runtime/profiling_record.cpp", + "torch/csrc/jit/runtime/script_profile.cpp", "torch/csrc/jit/runtime/symbolic_script.cpp", "torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp", "torch/csrc/jit/serialization/import.cpp", @@ -354,6 +361,11 @@ jit_sources_full = [ libtorch_core_jit_sources = sorted(jit_sources_full) torch_mobile_core = [ + # backend_debug_info.cpp provides + # __torch__.torch.classes.backend.BackendDebugInfo class + # This should not be needed eventually. + # TODO: Remove this dependency + "torch/csrc/jit/backends/backend_debug_info.cpp", "torch/csrc/jit/mobile/function.cpp", "torch/csrc/jit/mobile/import.cpp", "torch/csrc/jit/mobile/interpreter.cpp", @@ -782,7 +794,7 @@ aten_cpu_source_codegen_list = [ "aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp", ] -# When buliding lite interpreter in OSS, "aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp" will go through +# When building lite interpreter in OSS, "aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp" will go through # codegen process. The codegen version of this file, like Activation.cpp.DEFAULT.cpp, will be included # in ${cpu_kernel_cpp} in aten/src/ATen/CMakeLists.txt. As a result, in aten/src/ATen/CMakeLists.txt, # only aten_cpu_source_non_codegen_list need to be added to ${all_cpu_cpp}. diff --git a/tools/code_analyzer/default_op_deps.yaml b/tools/code_analyzer/default_op_deps.yaml index 3f100af1685442..9b71fbf0e650cb 100644 --- a/tools/code_analyzer/default_op_deps.yaml +++ b/tools/code_analyzer/default_op_deps.yaml @@ -7433,6 +7433,45 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::mish + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_meta + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::mish + - name: aten::to +- name: aten::mish_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mish +- name: aten::mish_backward + depends: + - name: aten::add + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_meta + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sigmoid + - name: aten::softplus + - name: aten::sub_ + - name: aten::to + - name: aten::tanh - name: aten::mkldnn_adaptive_avg_pool2d depends: - name: aten::eq diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index d66030c4c8162a..8eb2f3696f517e 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -590,7 +590,7 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: elif t.name == BaseTy.float: return 'double' elif t.name == BaseTy.str: - return 'std::string' + return 'c10::string_view' elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage, BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat, @@ -798,7 +798,7 @@ def argument_type_str_pyi(t: Type) -> str: elif t.name == BaseTy.Dimname: ret = 'Union[str, ellipsis, None]' elif t.name in [BaseTy.Tensor, BaseTy.Generator, - BaseTy.Storage, BaseTy.Stream, BaseTy.str]: + BaseTy.Storage, BaseTy.Stream]: # These python schema type names line up with their function schema names ret = t.name.name @@ -1016,7 +1016,7 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str: elif t.name == BaseTy.float: return 'toDouble' elif t.name == BaseTy.str: - return 'string' + return 'stringView' elif isinstance(t, OptionalType): if str(t.elem) == 'Tensor': diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index 07171386cd4069..850b6a50df285f 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -34,7 +34,7 @@ def __str__(self) -> str: doubleT = BaseCppType('', 'double') boolT = BaseCppType('', 'bool') voidT = BaseCppType('', 'void') -stringT = BaseCppType('std', 'string') +stringT = BaseCppType('c10', 'string_view') generatorT = BaseCppType('at', 'Generator') scalarTypeT = BaseCppType('at', 'ScalarType') tensorT = BaseCppType('at', 'Tensor') diff --git a/tools/codegen/dest/gen_external_aten_fallbacks.py b/tools/codegen/dest/gen_external_aten_fallbacks.py index 32904f17394a08..62fdd800b39e29 100644 --- a/tools/codegen/dest/gen_external_aten_fallbacks.py +++ b/tools/codegen/dest/gen_external_aten_fallbacks.py @@ -73,11 +73,11 @@ def has_autogenerated_composite_kernel(f: NativeFunction) -> bool: def requires_backend_wrapper(f: NativeFunction, backend_index: BackendIndex) -> bool: requires_lowering = not f.has_composite_kernel and not has_autogenerated_composite_kernel(f) - has_xla_lowering = backend_index.has_kernel(f) + has_backend_kernel = backend_index.has_kernel(f) in_denylist = any([re.match(frx, str(f.func.name)) for frx in _FN_DENYLIST_REGEX]) - return not in_denylist and (requires_lowering or has_xla_lowering) + return not in_denylist and (requires_lowering or has_backend_kernel) -def xla_tensor_creation_api( +def tensor_creation_api( ret_name: str, ret: Return, device_param_name: str, @@ -85,18 +85,16 @@ def xla_tensor_creation_api( cpu_result_name: str, tuple_idx: Optional[int] = None ) -> str: - if ret.type == BaseType(BaseTy.Tensor) and not ret.is_write: - # Only raw Tensor (non-reference) returns need to go through the XLA tensor creation API. - # Tensor references can be returned directly, since they've already been converted to XLA tensors. + if (ret.type == BaseType(BaseTy.Tensor) and not ret.is_write) or \ + (isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor)): + # Only raw Tensor (non-reference) returns need to be copied back from CPU to the backend device. + # Tensor references can be returned directly, since they already live on the backend device. # See Note [Tensor Copy Returns] - pass - elif isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor): - pass + return f"to_device_opt({cpu_result_name}, get_device_arg({device_param_name}))" else: - # for non tensor-types, there's no need to wrap the output in an xla bridge api. + # for non tensor-types, we don't need to convert between devices. return ret_name - return f"to_device_opt({cpu_result_name}, get_device_arg({device_param_name}))" @@ -181,16 +179,15 @@ def get_device_param(args: List[Argument]) -> str: opt_tensors = [ a for a in dispatcher_order_args if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)] - opt_tensor_args: Dict[Argument, str] = {a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)} + opt_tensor_args: Dict[Argument, str] = {a: f'external_tensors_opt[{i}]' for i, a in enumerate(opt_tensors)} tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)] - tensor_args: Dict[Argument, str] = {a: f'xlatens[{i}]' for i, a in enumerate(tensors)} + tensor_args: Dict[Argument, str] = {a: f'external_tensors[{i}]' for i, a in enumerate(tensors)} annotated_tensor_indices: List[int] = [ i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write] print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()]) - tensorlist_intermediates_str = '' if len(tensorlist_args) > 0: tensorlist_intermediates_str = '\n'.join([f' auto {updated_name} = to_cpu({arg.name});' @@ -199,14 +196,17 @@ def get_device_param(args: List[Argument]) -> str: opt_tensor_intermediates_str = '' if len(opt_tensor_args) > 0: arg_str = ", ".join([a.name for a in opt_tensor_args.keys()]) - opt_tensor_intermediates_str = f'\n std::vector> xlatens_opt_tensors = {{{arg_str}}};' - opt_tensor_intermediates_str += '\n auto xlatens_opt = to_cpu(xlatens_opt_tensors);' + opt_tensor_intermediates_str = \ + f'\n std::vector> external_tensors_opt_tensors = {{{arg_str}}};' + opt_tensor_intermediates_str += \ + '\n auto external_tensors_opt = to_cpu(external_tensors_opt_tensors);' intermediates = '' if tensorlist_intermediates_str != '': intermediates += tensorlist_intermediates_str + '\n' - intermediates += f" std::vector xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};" - intermediates += "\n auto xlatens = to_cpu(xlatens_tensors);" + intermediates += \ + f" std::vector external_tensors_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};" + intermediates += "\n auto external_tensors = to_cpu(external_tensors_tensors);" if opt_tensor_intermediates_str != '': intermediates += opt_tensor_intermediates_str @@ -238,12 +238,11 @@ def get_device_param(args: List[Argument]) -> str: update_tensors = '' if len(annotated_tensor_indices) > 0: indices_str = ", ".join([str(i) for i in annotated_tensor_indices]) - collect_mutated_tensors = f'\n std::vector xlatens_update_indices = {{{indices_str}}};' + collect_mutated_tensors = f'\n std::vector external_tensors_update_indices = {{{indices_str}}};' # TODO: uncomment the resize line below. Taken out temporarily for testing update_tensors = ''' - for (int i : xlatens_update_indices) { - // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes()); - at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]); + for (int i : external_tensors_update_indices) { + at::_copy_from_and_resize(external_tensors[i], external_tensors_tensors[i]); } ''' @@ -251,12 +250,12 @@ def get_device_param(args: List[Argument]) -> str: if f.func.returns: ret_names = cpp.return_names(f, fallback_name=cpu_result_name) if len(ret_names) == 1: - returns = xla_tensor_creation_api( + returns = tensor_creation_api( ret_names[0], f.func.returns[0], get_device_param(dispatcher_order_args), cpu_result_name=cpu_result_name) else: return_args = [ - xla_tensor_creation_api( + tensor_creation_api( ret_names[i], f.func.returns[i], get_device_param(dispatcher_order_args), cpu_result_name=f'std::get<{i}>({cpu_result_name})' ) for i in range(len(f.func.returns))] diff --git a/tools/codegen/dest/native_functions.py b/tools/codegen/dest/native_functions.py index b86a3666c6de7f..c643ea0eebe015 100644 --- a/tools/codegen/dest/native_functions.py +++ b/tools/codegen/dest/native_functions.py @@ -16,8 +16,8 @@ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional if "legacy::" in metadata.kernel: return None else: - prefix = 'static' if backend_index.external else 'TORCH_API' - return f"{prefix} {sig.decl(name=metadata.kernel)};" + prefix = '' if backend_index.external else 'TORCH_API ' + return f"{prefix}{sig.decl(name=metadata.kernel)};" @with_native_function_and_index def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]: @@ -26,9 +26,9 @@ def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List metadata = backend_index.get_kernel(g) if metadata is None: return [] - prefix = 'static' if backend_index.external else 'TORCH_API' + prefix = '' if backend_index.external else 'TORCH_API ' return [f"""\ -struct {prefix} structured_{metadata.kernel} : public at::meta::{meta_name} {{ +struct {prefix}structured_{metadata.kernel} : public at::meta::{meta_name} {{ void impl({', '.join(a.decl() for a in out_args)}); }}; """] diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py index 62c2c6bfee33c0..a2f139594dc579 100644 --- a/tools/codegen/dest/register_dispatch_key.py +++ b/tools/codegen/dest/register_dispatch_key.py @@ -9,8 +9,7 @@ from tools.codegen.model import (DispatchKey, NativeFunction, NativeFunctionsGroup, SchemaKind, TensorOptionsArguments, - DeviceCheckType, Argument, - assert_never, BaseType, BaseTy, + DeviceCheckType, Argument, assert_never, is_cuda_dispatch_key, BackendIndex, gets_generated_out_inplace_wrapper) from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType, @@ -20,7 +19,6 @@ DispatcherSignature) import tools.codegen.api.meta as meta import tools.codegen.api.cpp as cpp -import tools.codegen.api.dispatcher as dispatcher import tools.codegen.api.structured as structured from tools.codegen.api.translate import translate from tools.codegen.selective_build.selector import SelectiveBuilder @@ -111,11 +109,6 @@ def gen_out_inplace_wrapper(self, f: NativeFunction, g: Optional[NativeFunctions sig = self.wrapper_kernel_sig(f) name = sig.name() - # See Note [External Backends Follow Dispatcher convention] - jit_args = dispatcher.jit_arguments(f.func) - tensors = [a for a in jit_args if isinstance(a, Argument) and a.type == BaseType(BaseTy.Tensor)] - print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors]) - func_res = f'{name}_tmp' return_names = cpp.return_names(f) if len(return_names) > 1: @@ -132,8 +125,6 @@ def gen_out_inplace_wrapper(self, f: NativeFunction, g: Optional[NativeFunctions return f"""\ {sig.defn()} {{ - XLA_FN_TRACK(3); - TF_VLOG(3) << "XLA {name} :"{print_args_str}; auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))}); {updates} return {returns}; @@ -235,11 +226,7 @@ def generate_defn(cpp_sig: CppSignature) -> str: metadata = self.backend_index.get_kernel(f) if metadata is None: return None - # TODO: remove this difference and merge the two cases when we remove xla-specific logic - if self.backend_index.external: - impl_name = f"{self.cpp_namespace}::AtenXlaType::{metadata.kernel}" - else: - impl_name = f"{self.cpp_namespace}::{metadata.kernel}" + impl_name = f"{self.cpp_namespace}::{metadata.kernel}" args_exprs_str = ', '.join(a.name for a in args) diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index af2cab42d1f7f2..9ce4ebcafcf7de 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -1026,7 +1026,7 @@ def make_file_manager(install_dir: str) -> FileManager: 'native_function_declarations': list(concatMap( # Convert to a set first to remove duplicate kernel names. # Backends are allowed to repeat kernel names; only generate the declaration once! - lambda f: list(set(concatMap( + lambda f: list(OrderedDict.fromkeys(concatMap( lambda backend_idx: dest.compute_native_function_declaration(f, backend_idx), backend_indices.values()))), diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py index 7b65b0db64cf4e..2602570b0cd4dd 100644 --- a/tools/codegen/gen_backend_stubs.py +++ b/tools/codegen/gen_backend_stubs.py @@ -21,7 +21,8 @@ # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping) -ParsedExternalYaml = namedtuple('ParsedExternalYaml', ['backend_key', 'autograd_key', 'cpp_namespace', 'backend_indices']) +ParsedExternalYaml = namedtuple('ParsedExternalYaml', [ + 'backend_key', 'autograd_key', 'cpp_namespace', 'backend_indices']) def parse_backend_yaml( backend_yaml_path: str, grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], @@ -37,7 +38,7 @@ def parse_backend_yaml( yaml_values = yaml.load(f, Loader=Loader) assert isinstance(yaml_values, dict) - valid_keys = ['backend', 'cpp_namespace', 'supported', 'autograd'] + valid_keys = ['backend', 'cpp_namespace', 'extra_headers', 'supported', 'autograd'] backend = yaml_values.pop('backend', None) assert backend is not None, 'You must provide a value for "backend"' @@ -69,7 +70,11 @@ def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey) -> B metadata[op_name] = m # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops, # this should eventually be toggleable per-backend. - return BackendIndex(dispatch_key=dispatch_key, use_out_as_primary=False, external=True, index=metadata) + return BackendIndex( + dispatch_key=dispatch_key, + use_out_as_primary=False, + external=True, + index=metadata) backend_key: Optional[DispatchKey] = None if len(supported) > 0: @@ -156,12 +161,12 @@ def make_file_manager(install_dir: str) -> FileManager: backend_dispatch_key: DispatchKey = backend_key autograd_dispatch_key: DispatchKey = autograd_key generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!' - fm.write('aten_xla_type.h', lambda: { + fm.write_with_template(f'{backend_dispatch_key}NativeFunctions.h', 'DispatchKeyNativeFunctions.h', lambda: { 'generated_comment': generated_comment, 'cpp_namespace': cpp_namespace, # Convert to a set first to remove duplicate kernel names. # Backends are allowed to repeat kernel names; only generate the declaration once! - 'dispatch_xla_declarations': list(set(concatMap( + 'dispatch_declarations': list(set(concatMap( lambda f: dest.compute_native_function_declaration(f, backend_indices[backend_dispatch_key]), grouped_native_functions ))) + list(set(concatMap( @@ -170,19 +175,12 @@ def make_file_manager(install_dir: str) -> FileManager: ))), }) - external_backend_headers = '''\ -#include -#include -#include -#include -#include -#include ''' - for dispatch_key in [backend_dispatch_key, autograd_dispatch_key]: fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: { 'extra_cuda_headers': '', 'legacy_th_headers': '', - 'external_backend_headers': external_backend_headers, + 'external_backend_headers': f'''#include "{output_dir}/{backend_key}NativeFunctions.h" +#include ''', 'DispatchKey': dispatch_key, 'dispatch_namespace': dispatch_key.lower(), 'dispatch_namespaced_definitions': list(concatMap( diff --git a/tools/print_test_stats.py b/tools/print_test_stats.py index 7d7d27f2c4b836..7875f9646a4565 100755 --- a/tools/print_test_stats.py +++ b/tools/print_test_stats.py @@ -951,4 +951,7 @@ def reports_has_no_tests(reports: Dict[str, TestFile]) -> bool: head_json = obj if args.use_json: head_json = json.loads(Path(args.use_json).read_text()) - print_regressions(head_json, num_prev_commits=args.num_prev_commits) + try: + print_regressions(head_json, num_prev_commits=args.num_prev_commits) + except Exception as e: + print(f"ERROR ENCOUNTERED WHEN COMPARING AGAINST S3: {e}") diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 9ce09c46ce7abd..2309ad3bdc52a5 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -349,4 +349,6 @@ def build(self, my_env): build_args += ['--', '/p:CL_MPCount={}'.format(max_jobs)] else: build_args += ['--', '-j', max_jobs] + if USE_NINJA: + build_args += ['-v'] self.run(build_args, my_env) diff --git a/tools/test/test_actions_local_runner.py b/tools/test/test_actions_local_runner.py index 874e7c5bbe6a89..ba4e6fd2cdb923 100644 --- a/tools/test/test_actions_local_runner.py +++ b/tools/test/test_actions_local_runner.py @@ -24,46 +24,29 @@ def run(self, *args: List[Any], **kwargs: List[Dict[str, Any]]) -> Any: def test_step_extraction(self) -> None: fake_job = { "steps": [ - { - "name": "test1", - "run": "echo hi" - }, - { - "name": "test2", - "run": "echo hi" - }, - { - "name": "test3", - "run": "echo hi" - }, + {"name": "test1", "run": "echo hi"}, + {"name": "test2", "run": "echo hi"}, + {"name": "test3", "run": "echo hi"}, ] } actual = actions_local_runner.grab_specific_steps(["test2"], fake_job) expected = [ - { - "name": "test2", - "run": "echo hi" - }, + {"name": "test2", "run": "echo hi"}, ] self.assertEqual(actual, expected) async def test_runner(self) -> None: - fake_step = { - "name": "say hello", - "run": "echo hi" - } + fake_step = {"name": "say hello", "run": "echo hi"} f = io.StringIO() with contextlib.redirect_stdout(f): - await actions_local_runner.run_steps([fake_step], "test", None, True) + await actions_local_runner.YamlStep(fake_step, "test", True).run() result = f.getvalue() self.assertIn("say hello", result) - self.assertIn("hi", result) class TestEndToEnd(unittest.TestCase): expected = [ - "quick-checks: Extract scripts from GitHub Actions workflows", "cmakelint: Run cmakelint", "quick-checks: Ensure no direct cub include", "quick-checks: Ensure no unqualified type ignore", @@ -74,12 +57,17 @@ class TestEndToEnd(unittest.TestCase): "flake8", "quick-checks: Ensure correct trailing newlines", "quick-checks: Ensure no trailing spaces", - "quick-checks: Run ShellCheck", + "shellcheck: Regenerate workflows", + "shellcheck: Assert that regenerating the workflows didn't change them", + "shellcheck: Extract scripts from GitHub Actions workflows", + "shellcheck: Run ShellCheck", ] def test_lint(self): cmd = ["make", "lint", "-j", str(multiprocessing.cpu_count())] - proc = subprocess.run(cmd, cwd=actions_local_runner.REPO_ROOT, stdout=subprocess.PIPE) + proc = subprocess.run( + cmd, cwd=actions_local_runner.REPO_ROOT, stdout=subprocess.PIPE + ) stdout = proc.stdout.decode() for line in self.expected: @@ -89,7 +77,9 @@ def test_lint(self): def test_quicklint(self): cmd = ["make", "quicklint", "-j", str(multiprocessing.cpu_count())] - proc = subprocess.run(cmd, cwd=actions_local_runner.REPO_ROOT, stdout=subprocess.PIPE) + proc = subprocess.run( + cmd, cwd=actions_local_runner.REPO_ROOT, stdout=subprocess.PIPE + ) stdout = proc.stdout.decode() for line in self.expected: @@ -98,23 +88,29 @@ def test_quicklint(self): # TODO: See https://github.com/pytorch/pytorch/issues/57967 self.assertIn("mypy (skipped typestub generation)", stdout) - class TestQuicklint(unittest.IsolatedAsyncioTestCase): test_files = [ os.path.join("caffe2", "some_cool_file.py"), os.path.join("torch", "some_cool_file.py"), os.path.join("aten", "some_cool_file.py"), os.path.join("torch", "some_stubs.pyi"), + os.path.join("test.sh"), ] + test_py_files = [ + f for f in test_files if f.endswith(".py") or f.endswith(".pyi") + ] + test_sh_files = [f for f in test_files if f.endswith(".sh")] maxDiff = None def setUp(self, *args, **kwargs): for name in self.test_files: - bad_code = textwrap.dedent(""" + bad_code = textwrap.dedent( + """ some_variable = '2' some_variable = None some_variable = 11.2 - """).rstrip("\n") + """ + ).rstrip("\n") with open(name, "w") as f: f.write(bad_code) @@ -131,16 +127,26 @@ def test_file_selection(self): async def test_flake8(self): f = io.StringIO() with contextlib.redirect_stdout(f): - await actions_local_runner.run_flake8(self.test_files, True) + await actions_local_runner.Flake8(self.test_py_files, True).run() # Should exclude the caffe2/ file - expected = textwrap.dedent(""" + expected = textwrap.dedent( + """ x flake8 torch/some_cool_file.py:4:21: W292 no newline at end of file aten/some_cool_file.py:4:21: W292 no newline at end of file - """).lstrip("\n") + """ + ).lstrip("\n") self.assertEqual(expected, f.getvalue()) + async def test_shellcheck(self): + f = io.StringIO() + with contextlib.redirect_stdout(f): + await actions_local_runner.ShellCheck(self.test_sh_files, True).run() + + self.assertIn("SC2148: Tips depend on target shell", f.getvalue()) + self.assertIn("SC2283: Remove spaces around = to assign", f.getvalue()) + async def test_mypy(self): self.maxDiff = None f = io.StringIO() @@ -161,12 +167,12 @@ async def test_mypy(self): redirect=True, ) - await actions_local_runner.run_mypy(self.test_files, True) - + await actions_local_runner.Mypy(self.test_py_files, True).run() # Should exclude the aten/ file; also, apparently mypy # typechecks files in reverse order - expected = textwrap.dedent(""" + expected = textwrap.dedent( + """ x mypy (skipped typestub generation) torch/some_stubs.pyi:3:17: error: Incompatible types in assignment (expression has type "None", variable has type "str") [assignment] torch/some_stubs.pyi:4:17: error: Incompatible types in assignment (expression has type "float", variable has type "str") [assignment] @@ -174,9 +180,12 @@ async def test_mypy(self): torch/some_cool_file.py:4:17: error: Incompatible types in assignment (expression has type "float", variable has type "str") [assignment] caffe2/some_cool_file.py:3:17: error: Incompatible types in assignment (expression has type "None", variable has type "str") [assignment] caffe2/some_cool_file.py:4:17: error: Incompatible types in assignment (expression has type "float", variable has type "str") [assignment] - """).lstrip("\n") # noqa: B950 + """ # noqa: B950 + ).lstrip( + "\n" + ) self.assertEqual(expected, f.getvalue()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2fc064d325aea3..fe191bec9ebec6 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -79,7 +79,7 @@ class iinfo: def __init__(self, dtype: _dtype) -> None: ... class finfo: - bits: _float + bits: _int min: _float max: _float eps: _float @@ -260,6 +260,8 @@ def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Unio def _backport_for_mobile_from_buffer(buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int) -> None: ... def _backport_for_mobile_to_buffer(filename_input: Union[str, Path], to_version: _int) -> bytes:... def _backport_for_mobile_from_buffer_to_buffer(buffer: BinaryIO, to_version: _int) -> bytes:... +def _get_model_ops_and_info(filename: Union[str, Path]): ... +def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ... def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ... def _get_graph_executor_optimize() -> _bool: ... def _set_graph_executor_optimize(optimize: _bool): ... @@ -329,8 +331,9 @@ def _jit_pass_onnx_block( old_block: Block, new_block: Block, operator_export_type: _onnx.OperatorExportTypes, - env: Dict[Value, Value] -) -> None: ... + env: Dict[Value, Value], + is_sub_block: _bool +) -> Dict[Value, Value]: ... def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ... def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ... @@ -612,6 +615,10 @@ def is_inference_mode_enabled() -> _bool: ... def set_autocast_enabled(enabled: _bool) -> None: ... def is_autocast_enabled() -> _bool: ... def clear_autocast_cache() -> None: ... +def set_autocast_cpu_enabled(enabled: _bool) -> None: ... +def is_autocast_cpu_enabled() -> _bool: ... +def set_autocast_cpu_dtype(dtype: _dtype) -> None: ... +def get_autocast_cpu_dtype() -> _dtype: ... def autocast_increment_nesting() -> _int: ... def autocast_decrement_nesting() -> _int: ... def set_anomaly_enabled(enabled: _bool) -> None: ... @@ -958,7 +965,7 @@ class DictType(JitType): def getValueType(self) -> JitType: ... class TupleType(JitType): - def __init__(self, a: List[JitType]) -> None: ... + def __init__(self, a: List[Optional[JitType]]) -> None: ... def elements(self) -> List[JitType]: ... class ClassType(JitType): diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 4f98958b3f53e2..845466b8bc6136 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -346,6 +346,13 @@ class ProcessGroupGloo(ProcessGroup): def create_default_device() -> Device: ... ... +class _ProcessGroupWrapper(ProcessGroup): + def __init__( + self, + pg: ProcessGroup, + gloo_pg: ProcessGroupGloo + ): ... + class ProcessGroupNCCL(ProcessGroup): class Options: ... def __init__( diff --git a/torch/__init__.py b/torch/__init__.py index e6fc1ee76468da..837d2fed53bd49 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -376,6 +376,8 @@ def use_deterministic_algorithms(mode): tensor * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU tensor + * :func:`torch.Tensor.scatter_add_` when ``input`` dimension is one and called + on a CUDA tensor * :func:`torch.gather` when ``input`` dimension is one and called on a CUDA tensor that requires grad * :func:`torch.index_add` when called on CUDA tensor @@ -410,15 +412,16 @@ def use_deterministic_algorithms(mode): * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when ``mode='max'`` - * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor + * :func:`torch.Tensor.scatter_add_` when ``input`` dimension is larger than one + and called on a CUDA tensor + * :func:`torch.gather` when ``input`` dimension is larger than one + and called on a CUDA tensor that requires grad * :func:`torch.Tensor.put_` when ``accumulate=False`` * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor * :func:`torch.histc` when called on a CUDA tensor * :func:`torch.bincount` when called on a CUDA tensor * :func:`torch.kthvalue` with called on a CUDA tensor * :func:`torch.median` with indices output when called on a CUDA tensor - * :func:`torch.gather` when ``input`` dimension is larger than one - and called on a CUDA tensor that requires grad * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor A handful of CUDA operations are nondeterministic if the CUDA version is @@ -670,6 +673,7 @@ def _assert(condition, message): # side effect of adding to the imported module's members for other users. from torch import cuda as cuda +from torch import cpu as cpu from torch import autograd as autograd from torch.autograd import ( no_grad as no_grad, diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 748c40693419e7..94a72b5553b299 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -17,6 +17,7 @@ import builtins import io import pickle +import functools # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. # Explicitly ask to import `torch.distributed.__init__` first. # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. @@ -60,6 +61,11 @@ def parseNestedExpr(expr, module) -> Tuple[Any, int]: while i < len(expr) and expr[i] not in (',', '[', ']'): i += 1 + # Special case logic for the empty Tuple as a subscript (used + # in the type annotation `Tuple[()]`) + if expr[:i] == '()': + return (), i + base = lookupInModule(expr[:i].strip(), module) assert base is not None, f"Unresolvable type {expr[:i]}" if i == len(expr) or expr[i] != '[': @@ -971,6 +977,9 @@ def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_t self.uses_true_division = uses_true_division self.filename = filename +@functools.lru_cache(maxsize=None) +def make_source_context(*args): + return SourceContext(*args) def fake_range(): return SourceContext('', None, 0, 0).make_raw_range(0, 1) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 192dd0336d1348..1fc44d334440fb 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2942,7 +2942,7 @@ def callable(a, b) -> number add_docstr_all('repeat_interleave', r""" -repeat_interleave(repeats, dim=None) -> Tensor +repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor See :func:`torch.repeat_interleave`. """) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index eeea3e35f4a13a..8cbfcbaf84f44f 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -10404,7 +10404,7 @@ def merge_dicts(*dicts): add_docstr(torch.repeat_interleave, r""" -repeat_interleave(input, repeats, dim=None) -> Tensor +repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor Repeat elements of a tensor. @@ -10420,6 +10420,11 @@ def merge_dicts(*dicts): By default, use the flattened input array, and return a flat output array. +Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream syncronization + needed to calculate output shape of the tensor. + Returns: Tensor: Repeated tensor which has the same shape as input, except along the given axis. @@ -10438,8 +10443,12 @@ def merge_dicts(*dicts): tensor([[1, 2], [3, 4], [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) -.. function:: repeat_interleave(repeats) -> Tensor +.. function:: repeat_interleave(repeats, *, output_size=None) -> Tensor If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 7cbd5516e5634d..1cabb72b1e3878 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -97,6 +97,10 @@ class no_grad(_DecoratorContextManager): Also functions as a decorator. (Make sure to instantiate with parenthesis.) + .. note:: + No-grad is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. Example:: @@ -136,6 +140,10 @@ class enable_grad(_DecoratorContextManager): Also functions as a decorator. (Make sure to instantiate with parenthesis.) + .. note:: + enable_grad is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. Example:: @@ -178,6 +186,10 @@ class set_grad_enabled(object): (``False``). This can be used to conditionally enable gradients. + .. note:: + set_grad_enabled is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. Example:: @@ -222,6 +234,11 @@ class inference_mode(_DecoratorContextManager): Also functions as a decorator. (Make sure to instantiate with parenthesis.) + .. note:: + Inference mode is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. + Args: mode (bool): Flag whether to enable or disable inference mode diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py new file mode 100644 index 00000000000000..52d240f1d3d27b --- /dev/null +++ b/torch/cpu/__init__.py @@ -0,0 +1 @@ +from . import amp diff --git a/torch/cpu/amp/__init__.py b/torch/cpu/amp/__init__.py new file mode 100644 index 00000000000000..e4fe09f55632e4 --- /dev/null +++ b/torch/cpu/amp/__init__.py @@ -0,0 +1 @@ +from .autocast_mode import autocast diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py new file mode 100644 index 00000000000000..207f5a20a9b82e --- /dev/null +++ b/torch/cpu/amp/autocast_mode.py @@ -0,0 +1,36 @@ +import torch +import functools +import warnings + +class autocast(object): + def __init__(self, enabled=True, dtype=torch.bfloat16): + supported_dtype = [torch.bfloat16] + if dtype not in supported_dtype : + warnings.warn("In CPU autocast, but the target dtype is not supported. Disable the autocast.") + warnings.warn("CPU Autocast only support dtype of torch.bfloat16 currently.") + enabled = False + dtype = torch.bfloat16 + self._enabled = enabled + self._dtype = dtype + + def __enter__(self): + self.prev = torch.is_autocast_cpu_enabled() + self.prev_dtype = torch.get_autocast_cpu_dtype() + torch.set_autocast_cpu_enabled(self._enabled) + torch.set_autocast_cpu_dtype(self._dtype) + torch.autocast_increment_nesting() + + def __exit__(self, *args): + # Drop the cache when we exit to a nesting level that's outside any instance of autocast. + if torch.autocast_decrement_nesting() == 0: + torch.clear_autocast_cache() + torch.set_autocast_cpu_enabled(self.prev) + torch.set_autocast_cpu_dtype(self.prev_dtype) + return False + + def __call__(self, func): + @functools.wraps(func) + def decorate_autocast(*args, **kwargs): + with self: + return func(*args, **kwargs) + return decorate_autocast diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index d7ccbb1f8fe1f3..183acae80d3462 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -51,7 +51,7 @@ PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) HANDLE_TH_ERRORS static torch::PythonArgParser parser({ "Device(Device device)", - "Device(std::string type, int64_t? index=-1)" + "Device(c10::string_view type, int64_t? index=-1)" }); torch::ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h index 7e662fc83b442d..e19b0ebe34ccfc 100644 --- a/torch/csrc/api/include/torch/enum.h +++ b/torch/csrc/api/include/torch/enum.h @@ -104,6 +104,7 @@ TORCH_ENUM_DECLARE(Tanh) TORCH_ENUM_DECLARE(ReLU) TORCH_ENUM_DECLARE(GELU) TORCH_ENUM_DECLARE(SiLU) +TORCH_ENUM_DECLARE(Mish) TORCH_ENUM_DECLARE(LeakyReLU) TORCH_ENUM_DECLARE(FanIn) TORCH_ENUM_DECLARE(FanOut) @@ -147,6 +148,7 @@ struct _compute_enum_name { TORCH_ENUM_PRETTY_PRINT(ReLU) TORCH_ENUM_PRETTY_PRINT(GELU) TORCH_ENUM_PRETTY_PRINT(SiLU) + TORCH_ENUM_PRETTY_PRINT(Mish) TORCH_ENUM_PRETTY_PRINT(LeakyReLU) TORCH_ENUM_PRETTY_PRINT(FanIn) TORCH_ENUM_PRETTY_PRINT(FanOut) diff --git a/torch/csrc/api/include/torch/fft.h b/torch/csrc/api/include/torch/fft.h index a176d37d0e8be0..21f7c281b188e8 100644 --- a/torch/csrc/api/include/torch/fft.h +++ b/torch/csrc/api/include/torch/fft.h @@ -16,7 +16,7 @@ namespace fft { inline Tensor fft(const Tensor& self, c10::optional n=c10::nullopt, int64_t dim=-1, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_fft(self, n, dim, norm); } @@ -31,7 +31,7 @@ inline Tensor fft(const Tensor& self, inline Tensor ifft(const Tensor& self, c10::optional n=c10::nullopt, int64_t dim=-1, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_ifft(self, n, dim, norm); } @@ -46,7 +46,7 @@ inline Tensor ifft(const Tensor& self, inline Tensor fft2(const Tensor& self, c10::optional s=c10::nullopt, IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_fft2(self, s, dim, norm); } @@ -61,7 +61,7 @@ inline Tensor fft2(const Tensor& self, inline Tensor ifft2(const Tensor& self, c10::optional s=c10::nullopt, IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_ifft2(self, s, dim, norm); } @@ -76,7 +76,7 @@ inline Tensor ifft2(const Tensor& self, inline Tensor fftn(const Tensor& self, c10::optional s=c10::nullopt, c10::optional dim=c10::nullopt, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_fftn(self, s, dim, norm); } @@ -91,7 +91,7 @@ inline Tensor fftn(const Tensor& self, inline Tensor ifftn(const Tensor& self, c10::optional s=c10::nullopt, c10::optional dim=c10::nullopt, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_ifftn(self, s, dim, norm); } @@ -107,7 +107,7 @@ inline Tensor ifftn(const Tensor& self, inline Tensor rfft(const Tensor& self, c10::optional n=c10::nullopt, int64_t dim=-1, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_rfft(self, n, dim, norm); } @@ -125,7 +125,7 @@ inline Tensor rfft(const Tensor& self, inline Tensor irfft(const Tensor& self, c10::optional n=c10::nullopt, int64_t dim=-1, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_irfft(self, n, dim, norm); } @@ -140,7 +140,7 @@ inline Tensor irfft(const Tensor& self, inline Tensor rfft2(const Tensor& self, c10::optional s=c10::nullopt, IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_rfft2(self, s, dim, norm); } @@ -155,7 +155,7 @@ inline Tensor rfft2(const Tensor& self, inline Tensor irfft2(const Tensor& self, c10::optional s=c10::nullopt, IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_irfft2(self, s, dim, norm); } @@ -170,7 +170,7 @@ inline Tensor irfft2(const Tensor& self, inline Tensor rfftn(const Tensor& self, c10::optional s=c10::nullopt, c10::optional dim=c10::nullopt, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_rfftn(self, s, dim, norm); } @@ -185,7 +185,7 @@ inline Tensor rfftn(const Tensor& self, inline Tensor irfftn(const Tensor& self, c10::optional s=c10::nullopt, c10::optional dim=c10::nullopt, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_irfftn(self, s, dim, norm); } @@ -204,7 +204,7 @@ inline Tensor irfftn(const Tensor& self, inline Tensor hfft(const Tensor& self, c10::optional n=c10::nullopt, int64_t dim=-1, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_hfft(self, n, dim, norm); } @@ -222,7 +222,7 @@ inline Tensor hfft(const Tensor& self, inline Tensor ihfft(const Tensor& self, c10::optional n=c10::nullopt, int64_t dim=-1, - c10::optional norm=c10::nullopt) { + c10::optional norm=c10::nullopt) { return torch::fft_ihfft(self, n, dim, norm); } diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 7b7dc7a65df35d..4afea2066d0005 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -44,19 +44,19 @@ inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { return torch::linalg_eigvals_out(result, self); } -inline std::tuple eigh(const Tensor& self, std::string uplo) { +inline std::tuple eigh(const Tensor& self, c10::string_view uplo) { return torch::linalg_eigh(self, uplo); } -inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { +inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, c10::string_view uplo) { return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo); } -inline Tensor eigvalsh(const Tensor& self, std::string uplo) { +inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { return torch::linalg_eigvalsh(self, uplo); } -inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { +inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, c10::string_view uplo) { return torch::linalg_eigvalsh_out(result, self, uplo); } @@ -68,7 +68,7 @@ inline Tensor& householder_product_out(Tensor& result, const Tensor& input, cons return torch::linalg_householder_product_out(result, input, tau); } -inline std::tuple lstsq(const Tensor& self, const Tensor& b, c10::optional cond, c10::optional driver) { +inline std::tuple lstsq(const Tensor& self, const Tensor& b, c10::optional cond, c10::optional driver) { return torch::linalg_lstsq(self, b, cond, driver); } @@ -76,7 +76,7 @@ inline Tensor norm(const Tensor& self, const optional& opt_ord, optional return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } -inline Tensor norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor norm(const Tensor& self, c10::string_view ord, optional opt_dim, bool keepdim, optional opt_dtype) { return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype); } @@ -84,7 +84,7 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, const optional opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor& norm_out(Tensor& result, const Tensor& self, c10::string_view ord, optional opt_dim, bool keepdim, optional opt_dtype) { return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } @@ -260,22 +260,22 @@ inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { /// Computes eigenvalues and eigenvectors /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigh -inline std::tuple eigh(const Tensor& self, std::string uplo) { +inline std::tuple eigh(const Tensor& self, c10::string_view uplo) { return detail::eigh(self, uplo); } -inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { +inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, c10::string_view uplo) { return detail::eigh_out(eigvals, eigvecs, self, uplo); } /// Computes eigenvalues /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvalsh -inline Tensor eigvalsh(const Tensor& self, std::string uplo) { +inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { return detail::eigvalsh(self, uplo); } -inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { +inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, c10::string_view uplo) { return detail::eigvalsh_out(result, self, uplo); } @@ -290,7 +290,7 @@ inline Tensor& householder_product_out(Tensor& result, const Tensor& input, cons return detail::householder_product_out(result, input, tau); } -inline std::tuple lstsq(const Tensor& self, const Tensor& b, c10::optional cond, c10::optional driver) { +inline std::tuple lstsq(const Tensor& self, const Tensor& b, c10::optional cond, c10::optional driver) { return detail::lstsq(self, b, cond, driver); } @@ -300,7 +300,7 @@ inline Tensor linalg_norm(const Tensor& self, const optional& opt_ord, o } /// DEPRECATED -inline Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor linalg_norm(const Tensor& self, c10::string_view ord, optional opt_dim, bool keepdim, optional opt_dtype) { return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); } @@ -310,7 +310,7 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, const optiona } /// DEPRECATED -inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, c10::string_view ord, optional opt_dim, bool keepdim, optional opt_dtype) { return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 3230cacf42554e..a0487c61835efa 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -348,6 +348,12 @@ inline Tensor silu(const Tensor& input) { // ============================================================================ +inline Tensor mish(const Tensor& input) { + return torch::mish(input); +} + +// ============================================================================ + inline Tensor prelu(const Tensor& input, const Tensor& weight) { return torch::prelu(input, weight); } diff --git a/torch/csrc/api/include/torch/nn/functional/embedding.h b/torch/csrc/api/include/torch/nn/functional/embedding.h index 0c8a814cd38d9a..f391792f2a617b 100644 --- a/torch/csrc/api/include/torch/nn/functional/embedding.h +++ b/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -6,8 +6,11 @@ namespace torch { namespace nn { namespace functional { -inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) { - return torch::one_hot(tensor, num_classes); +inline Tensor one_hot( + const Tensor& tensor, + int64_t num_classes = -1, + ScalarType dtype = c10::kLong) { + return torch::one_hot(tensor, num_classes, dtype); } #ifndef DOXYGEN_SHOULD_SKIP_THIS diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h index 64b96b2bed3a55..865914ec887b25 100644 --- a/torch/csrc/api/include/torch/nn/modules/activation.h +++ b/torch/csrc/api/include/torch/nn/modules/activation.h @@ -606,6 +606,28 @@ class TORCH_API SiLUImpl : public torch::nn::Cloneable { /// module storage semantics. TORCH_MODULE(SiLU); +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies mish over a given input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.Mish to learn +/// about the exact behavior of this module. +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_API MishImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Mish` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `MishImpl`. +/// See the documentation for `MishImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Mish); + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies sigmoid over a given input. diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index 74f3719ad4b4f1..3c4d2b8c98f50b 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -308,6 +308,18 @@ void SiLUImpl::pretty_print(std::ostream& stream) const { // ============================================================================ +Tensor MishImpl::forward(const Tensor& input) { + return F::mish(input); +} + +void MishImpl::reset() {} + +void MishImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::Mish()"; +} + +// ============================================================================ + Tensor SigmoidImpl::forward(const Tensor& input) { return torch::sigmoid(input); } diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 48cefc1506a5c3..110f81e34cda22 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -351,7 +351,7 @@ Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) { return handle_r_to_c(self_st, out); } -Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional& rounding_mode) { +Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional& rounding_mode) { if (rounding_mode.has_value()) { return at::zeros_like(grad, grad.options().dtype(self_st)); } @@ -364,7 +364,7 @@ Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) { return div_tensor_self_backward(grad, other, self_st, c10::nullopt); } -Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional& rounding_mode) { +Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional& rounding_mode) { if (rounding_mode.has_value()) { return at::zeros_like(grad, grad.options().dtype(other.scalar_type())); } @@ -1140,6 +1140,15 @@ Tensor infinitely_differentiable_silu_backward( return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid)); } +Tensor infinitely_differentiable_mish_backward( + const Tensor& grad_output, + const Tensor& input) { + const Tensor sigmoid = input.sigmoid(); + const Tensor softplus = input.exp().log1p(); + const Tensor tanh_softplus = softplus.tanh(); + return grad_output * (tanh_softplus + input * sigmoid * (1.0 - tanh_softplus * tanh_softplus)); +} + Tensor infinitely_differentiable_logit_backward( const Tensor& grad, const Tensor& self, @@ -2478,7 +2487,7 @@ Tensor eigh_backward(const std::vector &grads, const } Tensor linalg_qr_backward(const std::vector &grads, const Tensor& self, - std::string mode, const Tensor& q, const Tensor& r){ + c10::string_view mode, const Tensor& q, const Tensor& r){ // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool compute_q, reduced; std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 6565c4bdfc904b..44778f4bb7ca72 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -59,8 +59,8 @@ at::Tensor angle_backward(at::Tensor grad, const at::Tensor& self); at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st); at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st); at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other); -at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional& rounding_mode); -at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional& rounding_mode); +at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional& rounding_mode); +at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional& rounding_mode); at::Tensor mvlgamma_backward(at::Tensor grad, const at::Tensor & self, int64_t p); at::Tensor permute_backwards(const at::Tensor & grad, at::IntArrayRef fwd_dims); at::Tensor rad2deg_backward(const at::Tensor& grad); @@ -110,6 +110,7 @@ at::Tensor max_pool_double_backward(const at::Tensor & grad, const at::Tensor & at::Tensor glu_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, int64_t dim); at::Tensor glu_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, int64_t dim); at::Tensor infinitely_differentiable_silu_backward(const at::Tensor& grad_output, const at::Tensor& input); +at::Tensor infinitely_differentiable_mish_backward(const at::Tensor& grad_output, const at::Tensor& input); Tensor infinitely_differentiable_logit_backward(const Tensor& grad, const Tensor& self, c10::optional eps); at::Tensor kl_div_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, bool log_target); at::Tensor binary_cross_entropy_with_logits_target_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction); @@ -161,7 +162,7 @@ std::tuple _trilinear_backward(const Tensor& grad_out, c IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, IntArrayRef sumdim, int64_t unroll_dim, std::array grad_mask); Tensor linalg_qr_backward(const std::vector &grads, const Tensor& self, - std::string mode, const Tensor& Q, const Tensor& R); + c10::string_view mode, const Tensor& Q, const Tensor& R); Tensor eig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& lambda, const Tensor& v); Tensor linalg_det_backward(const Tensor & grad, const Tensor& self, const Tensor& det); diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 5a70738cbd351d..d0d9ec9490fb03 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -963,7 +963,7 @@ void Engine::initialize_device_threads_pool() { std::call_once(start_device_threads_flag_, &Engine::start_device_threads, this); } -std::shared_ptr Engine::execute_with_graph_task( +c10::intrusive_ptr Engine::execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root, InputBuffer&& input_buffer) { diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 3fef22c6dcfd78..637192035c0ce7 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -154,7 +154,7 @@ struct GraphTask: std::enable_shared_from_this { // Future representing the completion of the graph task. Notified when all // tasks are done. - std::shared_ptr future_result_; + c10::intrusive_ptr future_result_; // Final callbacks installed during execution of this GraphTask std::vector> final_callbacks_; @@ -174,7 +174,7 @@ struct GraphTask: std::enable_shared_from_this { reentrant_depth_(reentrant_depth), exit_on_error_(exit_on_error), cpu_ready_queue_(std::move(cpu_ready_queue)), - future_result_(std::make_shared(c10::ListType::create(c10::TensorType::get()))) {} + future_result_(c10::make_intrusive(c10::ListType::create(c10::TensorType::get()))) {} private: // run GraphTask post processing void exec_post_processing(); @@ -294,7 +294,7 @@ struct TORCH_API Engine { // // NB: This API should only be used by internal autograd specific // machinery and shouldn't be exposed to users in anyway. - virtual std::shared_ptr execute_with_graph_task( + virtual c10::intrusive_ptr execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root, InputBuffer&& input_buffer); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 38f47d5d1346a6..9d3a1f87d50bd5 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include @@ -278,6 +279,57 @@ static PyObject * is_autocast_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } +static PyObject * set_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) { + HANDLE_TH_ERRORS + if (!PyBool_Check(arg)) { + throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); + } + at::autocast::set_cpu_enabled(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject * is_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) { + HANDLE_TH_ERRORS + if (at::autocast::is_cpu_enabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +static PyObject * set_autocast_cpu_dtype(PyObject* _unused, PyObject *arg) { + HANDLE_TH_ERRORS + if (!THPDtype_Check(arg)) { + throw TypeError( + "dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name); + } + at::ScalarType targetType = reinterpret_cast(arg)->scalar_type; + at::autocast::set_autocast_cpu_dtype(targetType); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static const char* scalarTypeName(const at::ScalarType type) { + switch (type) { +#define DEFINE_CASE(ctype, name) \ + case at::ScalarType::name: \ + return #ctype; + AT_FORAUTOCAST_SCALAR_TYPES(DEFINE_CASE) +#undef DEFINE_CASE + default: + throw std::runtime_error("unknown scalar type for autocast"); + } +} + +static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){ + HANDLE_TH_ERRORS + at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype(); + return THPDtype_New(current_dtype, scalarTypeName(current_dtype)); + END_HANDLE_TH_ERRORS +} + static PyObject * clear_autocast_cache(PyObject* _unused, PyObject *arg) { HANDLE_TH_ERRORS at::autocast::clear_cache(); @@ -377,6 +429,10 @@ static PyMethodDef methods[] = { // NOLINT {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr}, {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr}, {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr}, + {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr}, + {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr}, + {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr}, + {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr}, {"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr}, {"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr}, {"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr}, diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 984ca8eee3b73a..965177d449dbca 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -123,7 +123,7 @@ variable_list PythonEngine::execute( } } -std::shared_ptr PythonEngine::execute_with_graph_task( +c10::intrusive_ptr PythonEngine::execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root, InputBuffer&& input_buffer) { diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h index b968eea00a1f8b..d9bb20d5b6f4ab 100644 --- a/torch/csrc/autograd/python_engine.h +++ b/torch/csrc/autograd/python_engine.h @@ -27,7 +27,7 @@ struct PythonEngine : public Engine { bool accumulate_grad, const edge_list& outputs = {}) override; - std::shared_ptr execute_with_graph_task( + c10::intrusive_ptr execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root, InputBuffer&& input_buffer) override; diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 8e14ee76972460..520b628c056d6a 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -33,6 +33,8 @@ #include #include +#include +#include #include #include #include @@ -41,6 +43,33 @@ using namespace at; using namespace torch; using namespace torch::autograd; +namespace { + +std::string concrete_name_fn(const c10::impl::PyInterpreter* self) { + std::stringstream ss; + ss << self; + return ss.str(); +} + +class PyInterpreterHolder { + public: + PyInterpreterHolder() + : impl_(new c10::impl::PyInterpreter(&concrete_name_fn)) {} + // NB: intentionally leaks the memory + ~PyInterpreterHolder() { + impl_->disarm(); + } + c10::impl::PyInterpreter* get() const noexcept { + return impl_; + } + + private: + c10::impl::PyInterpreter* impl_; +}; +PyInterpreterHolder self_interpreter; + +} // anonymous namespace + namespace py = pybind11; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -55,49 +84,24 @@ static const char* VOLATILE_WARNING = "volatile was removed and now has no effect. Use " "`with torch.no_grad():` instead."; -#ifdef USE_DEPLOY -// used only in libtorch_deployinterpreter.so -// there are muliple copies of the python interpreter that -// can shared Tensors, so rather than use their internal pointer -// to a PyObject use a library-local map. -static std::unordered_map impl_to_pyobj; - -void set_pyobj(const Variable& self, PyObject* pyobj) { - TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor"); - void* key = self.unsafeGetTensorImpl(); - if (!pyobj) { - impl_to_pyobj.erase(key); - return; - } - impl_to_pyobj[key] = pyobj; -} - -PyObject* pyobj(const Variable& self) { - TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor"); - auto it = impl_to_pyobj.find(self.unsafeGetTensorImpl()); - return it == impl_to_pyobj.end() ? nullptr : it->second; -} -#else -void set_pyobj(const Variable& self, PyObject* pyobj) { - TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor"); - self.unsafeGetTensorImpl()->set_pyobj(pyobj); -} - -PyObject* pyobj(const Variable& self) { - TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor"); - return self.unsafeGetTensorImpl()->pyobj(); -} -#endif - -// Creates a new Python object for a Variable. The Variable must not already -// have a PyObject* associated with it. -static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var) -{ +// Creates a new Python object for a Variable. The status parameter +// specifies what the interpreter tag status on the object is; for +// example, if you ran check_pyobj, the return optional of this object +// tells you if the tensor was already tagged or not so you can pass +// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where +// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED. +// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. +static PyObject* THPVariable_NewWithVar( + PyTypeObject* type, + Variable var, + c10::impl::PyInterpreterStatus status) { PyObject* obj = type->tp_alloc(type, 0); if (obj) { auto v = (THPVariable*) obj; new (&v->cdata) Variable(std::move(var)); - set_pyobj(v->cdata, obj); + // cannot use var as it is moved out of + THPVariable_Unpack(v).unsafeGetTensorImpl()->init_pyobj( + self_interpreter.get(), obj, status); } return obj; } @@ -108,12 +112,26 @@ PyObject * THPVariable_Wrap(Variable var) Py_RETURN_NONE; } - if (auto obj = pyobj(var)) { - Py_INCREF(obj); - return obj; + c10::optional mb_obj = + var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()); + c10::impl::PyInterpreterStatus status; + if (mb_obj.has_value()) { + auto obj = *mb_obj; + if (obj) { + Py_INCREF(obj); + return obj; + } + // TODO: a better invariant is that if we tagged, we MUST have a valid + // PyObject. That's PyObject preservation + // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR + // being a thing, the PyObject field will get cleared when all references + // to the Python object are removed. + status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; + } else { + status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; } - - return THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var)); + return THPVariable_NewWithVar( + (PyTypeObject*)THPVariableClass, std::move(var), status); } static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg) @@ -165,7 +183,10 @@ static int THPVariable_clear(THPVariable *self) // objects stay live, buster! See // https://github.com/pytorch/pytorch/issues/22884 for an example of // this actually showing up. - set_pyobj(self->cdata, nullptr); + // + // [torchdeploy] Note that we DON'T clear the interpreter field. Once on an + // interpreter, always on an interpreter. + tensor.unsafeGetTensorImpl()->unchecked_clear_pyobj(self_interpreter.get()); } self->cdata.reset(); return 0; @@ -194,7 +215,10 @@ static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObje if (!PyType_Check(cls)) { throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name); } - return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + self.alias(), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); END_HANDLE_TH_ERRORS } @@ -209,7 +233,8 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P if (!PyType_Check(cls)) { throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name); } - auto data = r.tensor(1).detach(); + auto data = + r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED) // We set `data`'s `allow_tensor_metadata_change` to true here, because we want to // allow the following use case for backward compatibility: // @@ -221,7 +246,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P // ``` data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true); auto var = data.set_requires_grad(r.toBool(2)); - return THPVariable_NewWithVar((PyTypeObject*)cls, std::move(var)); + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + std::move(var), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); END_HANDLE_TH_ERRORS } @@ -951,11 +979,15 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs TORCH_CHECK(type != &THPVariableType, "Cannot directly construct _TensorBase; subclass it and then construct that"); jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR); auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs); - return THPVariable_NewWithVar(type, std::move(tensor)); + // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was + // given a raw pointer that will refcount bump + return THPVariable_NewWithVar( + type, + std::move(tensor), + c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED); END_HANDLE_TH_ERRORS } - int THPVariableMetaType_init(PyObject *cls, PyObject *args, PyObject *kwargs) { if (PyType_Type.tp_init(cls, args, kwargs) < 0) { return -1; diff --git a/torch/csrc/cuda/shared/cudart.cpp b/torch/csrc/cuda/shared/cudart.cpp index a8f80a35855d2c..30a43bed053467 100644 --- a/torch/csrc/cuda/shared/cudart.cpp +++ b/torch/csrc/cuda/shared/cudart.cpp @@ -6,6 +6,7 @@ #else #include #endif +#include namespace torch { namespace cuda { namespace shared { @@ -38,6 +39,13 @@ void initCudartBindings(PyObject* module) { #ifndef __HIP_PLATFORM_HCC__ cudart.def("cuda" "ProfilerInitialize", cudaProfilerInitialize); #endif + cudart.def("cuda" "MemGetInfo", [](int device) -> std::pair { + C10_CUDA_CHECK(cudaGetDevice(&device)); + size_t device_free; + size_t device_total; + cudaMemGetInfo(&device_free, &device_total); + return {device_free, device_total}; + }); } } // namespace shared diff --git a/torch/csrc/deploy/test_deploy.cpp b/torch/csrc/deploy/test_deploy.cpp index 5b6b75d3f852b5..d22dcc8d0f12cf 100644 --- a/torch/csrc/deploy/test_deploy.cpp +++ b/torch/csrc/deploy/test_deploy.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -74,11 +75,11 @@ TEST(TorchpyTest, MultiSerialSimpleModel) { std::vector outputs; for (size_t i = 0; i < ninterp; i++) { - outputs.push_back(model({input}).toTensor()); + outputs.push_back(model({input.alias()}).toTensor()); } // Generate reference - auto ref_output = ref_model.forward({input}).toTensor(); + auto ref_output = ref_model.forward({input.alias()}).toTensor(); // Compare all to reference for (size_t i = 0; i < ninterp; i++) { @@ -116,9 +117,9 @@ TEST(TorchpyTest, ThreadedSimpleModel) { futures.push_back(std::async(std::launch::async, [&model]() { auto input = torch::ones({10, 20}); for (int i = 0; i < 100; ++i) { - model({input}).toTensor(); + model({input.alias()}).toTensor(); } - auto result = model({input}).toTensor(); + auto result = model({input.alias()}).toTensor(); return result; })); } @@ -127,7 +128,7 @@ TEST(TorchpyTest, ThreadedSimpleModel) { } // Generate reference - auto ref_output = ref_model.forward({input}).toTensor(); + auto ref_output = ref_model.forward({input.alias()}).toTensor(); // Compare all to reference for (size_t i = 0; i < nthreads; i++) { @@ -166,3 +167,54 @@ TEST(TorchpyTest, AcquireMultipleSessionsInDifferentPackages) { path("RESNET", "torch/csrc/deploy/example/generated/resnet")); auto I1 = p1.acquire_session(); } + +TEST(TorchpyTest, TensorSharingNotAllowed) { + size_t nthreads = 2; + torch::deploy::InterpreterManager m(nthreads); + // generate a tensor from one interpreter + auto I0 = m.all_instances()[0].acquire_session(); + auto I1 = m.all_instances()[1].acquire_session(); + auto obj = I0.global("torch", "empty")({I0.from_ivalue(2)}); + auto t = obj.toIValue().toTensor(); + // try to feed it to the other interpreter, should error + ASSERT_THROW(I1.global("torch", "sigmoid")({t}), c10::Error); +} + +TEST(TorchpyTest, TaggingRace) { + // At time of writing, this takes about 7s to run on DEBUG=1. I think + // this is OK, but feel free to fiddle with the knobs here to reduce the + // runtime + constexpr int64_t trials = 4; + constexpr int64_t nthreads = 16; + torch::deploy::InterpreterManager m(nthreads); + for (int64_t n = 0; n < trials; n++) { + at::Tensor t = torch::empty(2); + std::atomic success(0); + std::atomic failed(0); + at::parallel_for(0, nthreads, 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; i++) { + auto I = m.all_instances()[i].acquire_session(); + try { + I.from_ivalue(t); + success++; + } catch (const c10::Error& e) { + failed++; + } + } + }); + ASSERT_EQ(success, 1); + ASSERT_EQ(failed, nthreads - 1); + } +} + +TEST(TorchpyTest, DisarmHook) { + at::Tensor t = torch::empty(2); + { + torch::deploy::InterpreterManager m(1); + auto I = m.acquire_one(); + I.from_ivalue(t); + } // unload the old interpreter + torch::deploy::InterpreterManager m(1); + auto I = m.acquire_one(); + ASSERT_THROW(I.from_ivalue(t), c10::Error); // NOT a segfault +} diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index a99302d171b51f..f94e9455dbab0d 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -308,6 +308,10 @@ void DistEngine::computeDependencies( } } + // Set graph task owner in a single thread since concurrent access to + // 'owner_' field is not permitted. + graphTask->owner_ = torch::autograd::CPU_DEVICE; + // Let autograd context take ownership of the GraphTask. autogradContext->setGraphTask(std::move(graphTask)); } @@ -330,7 +334,6 @@ void DistEngine::execute_graph_task_until_ready_queue_empty( cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks); torch::autograd::set_device(torch::autograd::CPU_DEVICE); - graph_task->owner_ = torch::autograd::CPU_DEVICE; while (!cpu_ready_queue->empty()) { std::shared_ptr local_graph_task; { diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 12ea0902e911e3..984ea6d1b0bfc4 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -12,6 +12,7 @@ #ifdef USE_C10D_GLOO #include +#include #endif #ifdef USE_C10D_NCCL @@ -1194,6 +1195,10 @@ that adds a prefix to each key inserted to the store. }, py::arg("timeout") = ::c10d::kUnsetTimeout, py::arg("wait_all_ranks") = false, + py::call_guard()) + .def( + "_get_backend_name", + &::c10d::ProcessGroup::getBackendName, py::call_guard()); // base ProcessGroup::Options binding @@ -1301,6 +1306,23 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("timeout") = kProcessGroupDefaultTimeout, py::call_guard()) .def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions); + + // ProcessGroupWrapper is a wrapper pg that includes a helper gloo process + // group. It can be used to validate collective calls across processes by + // checking the op type and input tensor shapes. + auto processGroupWrapper = + intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>( + module, "_ProcessGroupWrapper", processGroup) + .def( + py::init([](const c10::intrusive_ptr<::c10d::ProcessGroup>& pg, + const c10::intrusive_ptr<::c10d::ProcessGroupGloo>& + gloo_pg) { + return c10::make_intrusive<::c10d::ProcessGroupWrapper>( + pg, gloo_pg); + }), + py::arg("pg"), + py::arg("gloo_pg"), + py::call_guard()); #endif #ifdef USE_C10D_NCCL diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 2ba1915aa91d1e..720a6304827cf3 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -307,37 +307,41 @@ void PyRRef::backward(int64_t autogradContextId, bool retainGraph) { backward(autogradContextId, retainGraph, rref_); } -void PyRRef::backward( +void PyRRef::backwardOwnerRRef( int64_t autogradContextId, bool retainGraph, - const c10::intrusive_ptr& rref) { - if (rref->isOwner()) { - auto value = - c10::static_intrusive_pointer_cast(rref)->getValue(); - - // If we have a PyObj, retrieve the underlying tensor. - if (rref->isPyObj()) { - py::gil_scoped_acquire gil; - py::object obj = torch::jit::toPyObject(value); - try { - value = torch::jit::toIValue(obj, c10::TensorType::get()); - } catch (py::cast_error& e) { - throw std::runtime_error( - "RRef should contain a tensor for .backward()"); - } + IValue value) { + // If we have a PyObj, retrieve the underlying tensor. + if (value.isPyObject()) { + py::gil_scoped_acquire gil; + py::object obj = torch::jit::toPyObject(value); + try { + value = torch::jit::toIValue(obj, c10::TensorType::get()); + } catch (py::cast_error& e) { + throw std::runtime_error("RRef should contain a tensor for .backward()"); } + } - TORCH_CHECK( - value.isTensor(), "RRef should contain a tensor for .backward()"); - auto root = value.toTensor(); + TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()"); + auto root = value.toTensor(); - if (autogradContextId == -1) { - torch::autograd::backward({root}); - } else { - torch::distributed::autograd::backward( - autogradContextId, {root}, retainGraph); - } + if (autogradContextId == -1) { + torch::autograd::backward({root}); + } else { + torch::distributed::autograd::backward( + autogradContextId, {root}, retainGraph); + } +} +void PyRRef::backward( + int64_t autogradContextId, + bool retainGraph, + const c10::intrusive_ptr& rref) { + if (rref->isOwner()) { + backwardOwnerRRef( + autogradContextId, + retainGraph, + c10::static_intrusive_pointer_cast(rref)->getValue()); } else { TORCH_CHECK( autogradContextId != -1, diff --git a/torch/csrc/distributed/rpc/py_rref.h b/torch/csrc/distributed/rpc/py_rref.h index af1b8d25aabbf5..7710c6acbc728b 100644 --- a/torch/csrc/distributed/rpc/py_rref.h +++ b/torch/csrc/distributed/rpc/py_rref.h @@ -66,6 +66,12 @@ class PYBIND11_EXPORT PyRRef { bool retainGraph, const c10::intrusive_ptr& rref); + // Specialization of backward if the rref is an OwnerRRef. + static void backwardOwnerRRef( + int64_t autogradContextId, + bool retainGraph, + IValue value); + private: c10::intrusive_ptr rref_; c10::optional> profilingFuture_; diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 8feffca6562da1..c6dfcaf949d0b6 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include namespace torch { namespace distributed { @@ -85,60 +85,66 @@ std::unique_ptr deserializePythonRpcCommandReference( } } -void processPythonExecution( - const py::object& pyFn, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, - bool isAsyncExecution, - std::function&)> postProcessing) { - std::shared_ptr pyFuture; +SerializedPyObj serializePyObject(IValue value) { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); - { - py::gil_scoped_acquire acquire; - auto result = pythonRpcHandler.runPythonUdf(pyFn); - - if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) { - // Hit exception when running the user function or there is no async - // execution. Not releasing GIL before serialize to avoid an additional - // context switch. - postProcessing(result, messageId, pythonRpcHandler, responseFuture); - return; - } - - try { - pyFuture = result.cast>(); - } catch (const py::cast_error& e) { - auto type = result.get_type(); - auto errMsg = c10::str( - e.what(), - ". Functions decorated with @rpc.async_function must return a " - "torch.futures.Future object, but got ", - type.attr("__module__").cast(), - ".", - type.attr("__qualname__").cast()); - throw std::runtime_error(errMsg); - } + // Need this GIL to guard jit::toPyObj and destruct its returned + // py::object + py::gil_scoped_acquire acquire; + try { + return pythonRpcHandler.serialize(jit::toPyObject(value)); + } catch (py::error_already_set& e) { + // py::error_already_set requires GIL to destruct, take special care. + auto err = std::runtime_error(e.what()); + e.restore(); + PyErr_Clear(); + throw err; } - - pyFuture->fut->addCallback([messageId, - responseFuture, - postProcessing{std::move(postProcessing)}, - &pythonRpcHandler](JitFuture& jitFuture) { - py::gil_scoped_acquire acquire; - postProcessing( - jit::toPyObject(jitFuture.value()), - messageId, - pythonRpcHandler, - responseFuture); - }); } } // anonymous namespace +c10::intrusive_ptr RequestCallbackImpl::runPythonFunction( + const py::object& function, + bool isAsyncExecution) const { + auto& pythonRpcHandler = PythonRpcHandler::getInstance(); + py::gil_scoped_acquire acquire; + + py::object result; + try { + result = pythonRpcHandler.runPythonUdf(function); + } catch (py::error_already_set& e) { + // py::error_already_set requires GIL to destruct, take special care. + auto future = + asFuture(std::make_exception_ptr(std::runtime_error(e.what()))); + e.restore(); + PyErr_Clear(); + return future; + } catch (std::exception& e) { + return asFuture(std::current_exception()); + } + + // After sync exection or failed async execution return the value as-is. + if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) { + return asFuture( + c10::ivalue::ConcretePyObjectHolder::create(result), + at::PyObjectType::get()); + } + + try { + return result.cast().fut; + } catch (const py::cast_error& e) { + auto type = result.get_type(); + auto errMsg = c10::str( + e.what(), + ". Functions decorated with @rpc.async_function must return a " + "torch.futures.Future object, but got ", + type.attr("__module__").cast(), + ".", + type.attr("__qualname__").cast()); + return asFuture(std::make_exception_ptr(std::runtime_error(errMsg))); + } +} + std::unique_ptr RequestCallbackImpl:: deserializePythonRpcCommand( std::unique_ptr rpc, @@ -147,311 +153,87 @@ std::unique_ptr RequestCallbackImpl:: return pythonRpc ? std::move(pythonRpc) : std::move(rpc); } -void RequestCallbackImpl::processScriptCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { +c10::intrusive_ptr RequestCallbackImpl::processScriptCall( + RpcCommandBase& rpc) const { auto& scriptCall = static_cast(rpc); - auto& stack = scriptCall.stackRef(); - if (processScriptCallOp(scriptCall, markComplete, stack)) { - return; - } - // runAsync() starts in the calling thread, but may return an uncompleted - // future (though for non-async code, it will typically be completed). - // If it was async, our callback will typically be invoked by the - // continuation on an at::launch() thread. - auto jitFuture = PythonRpcHandler::getInstance() - .jitCompilationUnit() - ->get_function(scriptCall.qualifiedName()) - .runAsync(stack); - - // Fastpath: avoid callbacks if not neeeded. - if (jitFuture->completed() && !scriptCall.isAsyncExecution()) { - markComplete(std::move(ScriptResp(jitFuture->value())).toMessage()); - return; + c10::intrusive_ptr future; + if (scriptCall.hasOp()) { + future = runJitOperator(*scriptCall.op(), scriptCall.stackRef()); + } else { + future = runJitFunction( + scriptCall.qualifiedName(), + scriptCall.stackRef(), + scriptCall.isAsyncExecution()); } - jitFuture->addCallback([responseFuture, - messageId, - isAsyncExecution = scriptCall.isAsyncExecution(), - markComplete](JitFuture& jitFutureCaptured) { - try { - JitFuture& jitFuture = isAsyncExecution - ? *jitFutureCaptured.value().toFuture() - : jitFutureCaptured; - - // Setup response callback appropriately. - auto responseCb = [responseFuture, messageId](JitFuture& jitFuture) { - try { - Message m = ScriptResp(jitFuture.value()).toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - } catch (const std::exception& /* unused */) { - responseFuture->setError(std::current_exception()); - } - }; - - // Call inline if we don't have async execution. - isAsyncExecution ? jitFuture.addCallback(responseCb) - : responseCb(jitFuture); - } catch (const std::exception& /* unused */) { - responseFuture->setError(std::current_exception()); - } - }); + return future->then( + [](JitFuture& jitFuture) { + return c10::make_intrusive( + ScriptResp(jitFuture.value()).toMessage()); + }, + c10::getCustomClassType>()); } -void RequestCallbackImpl::processPythonCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { +c10::intrusive_ptr RequestCallbackImpl::processPythonCall( + RpcCommandBase& rpc) const { auto& upc = static_cast(rpc); - try { - processPythonExecution( - upc.pythonUdf(), - messageId, - responseFuture, - upc.isAsyncExecution(), - [](const py::object& result, - const int64_t messageId, - PythonRpcHandler& pythonRpcHandler, - const c10::intrusive_ptr& responseFuture) { - // Check we have GIL. - DCHECK(PyGILState_Check()); - - auto serializedPyObj = pythonRpcHandler.serialize(result); - py::gil_scoped_release release; - auto m = - std::move(PythonResp(std::move(serializedPyObj))).toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - }); - } catch (std::exception& e) { - responseFuture->markCompleted(IValue(c10::make_intrusive( - createExceptionResponse(e.what(), messageId)))); - } + auto future = runPythonFunction(upc.pythonUdf(), upc.isAsyncExecution()); + + return future->then( + [](JitFuture& future) { + return c10::make_intrusive( + PythonResp(serializePyObject(future.value())).toMessage()); + }, + c10::getCustomClassType>()); } -TypePtr RequestCallbackImpl::getScriptRemoteCallType( - ScriptRemoteCall& scriptRemoteCall) const { - TypePtr returnType; +c10::intrusive_ptr RequestCallbackImpl::processScriptRemoteCall( + RpcCommandBase& rpc) const { + auto& scriptRemoteCall = static_cast(rpc); + + c10::intrusive_ptr future; if (scriptRemoteCall.hasOp()) { - returnType = scriptRemoteCall.op()->schema().returns()[0].type(); + future = + runJitOperator(*scriptRemoteCall.op(), scriptRemoteCall.stackRef()); } else { - returnType = PythonRpcHandler::getInstance() - .jitCompilationUnit() - ->get_function(scriptRemoteCall.qualifiedName()) - .getSchema() - .returns() - .at(0) - .type(); - } - return returnType; -} - -void RequestCallbackImpl::processScriptRemoteCall( - ScriptRemoteCall& scriptRemoteCall, - const std::function& postProcessing, - std::vector& stack, - const c10::intrusive_ptr& ownerRRef) const { - if (processScriptRemoteCallOp( - scriptRemoteCall, postProcessing, stack, ownerRRef)) { - return; + future = runJitFunction( + scriptRemoteCall.qualifiedName(), + scriptRemoteCall.stackRef(), + scriptRemoteCall.isAsyncExecution()); } - auto isAsyncExecution = scriptRemoteCall.isAsyncExecution(); - auto asyncPostProcessing = [ownerRRef, postProcessing, isAsyncExecution]( - c10::ivalue::Future& jitFuture) mutable { - // The user function will return a JIT future, install - // setRRefValue and postProcessing to that valueFuture - try { - JitFuture& valueJitFuture = - isAsyncExecution ? *jitFuture.value().toFuture() : jitFuture; - - // Setup callback. - auto setRRefValue = [ownerRRef, - postProcessing](JitFuture& valueJitFuture) mutable { - try { - ownerRRef->setValue(valueJitFuture.value()); - } catch (const std::exception& e) { - ownerRRef->setError(std::current_exception()); - } - postProcessing(); - }; - - // Call inline if not async execution. - isAsyncExecution ? valueJitFuture.addCallback(setRRefValue) - : setRRefValue(valueJitFuture); - } catch (std::exception& e) { - ownerRRef->setError(std::current_exception()); - postProcessing(); - } - }; - - c10::intrusive_ptr jitFuture; - try { - jitFuture = PythonRpcHandler::getInstance() - .jitCompilationUnit() - ->get_function(scriptRemoteCall.qualifiedName()) - .runAsync(stack); - if (jitFuture->completed()) { // short-cut. - asyncPostProcessing(*jitFuture); - return; - } - } catch (const std::exception& e) { - asyncPostProcessing(*jitFuture); - return; - } - jitFuture->addCallback( - [asyncPostProcessing{std::move(asyncPostProcessing)}]( - JitFuture& jitFuture) mutable { asyncPostProcessing(jitFuture); }); + return assignOwnerRRef( + scriptRemoteCall.retRRefId(), + scriptRemoteCall.retForkId(), + std::move(future), + /*lsctx=*/nullptr); } -void RequestCallbackImpl::processPythonRemoteCall( +c10::intrusive_ptr RequestCallbackImpl::processPythonRemoteCall( RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr lsctx) const { auto& uprc = static_cast(rpc); + auto future = runPythonFunction(uprc.pythonUdf(), uprc.isAsyncExecution()); - const auto& rrefId = uprc.rrefId(); - const auto& forkId = uprc.forkId(); - auto& ctx = RRefContext::getInstance(); - - c10::intrusive_ptr ownerRRef; - if (rrefId == forkId) { - // Creating an owner RRef on self, should already exist in owners map - ownerRRef = - fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true) - ->constValue() - .toRRef()); - } else { - ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, PyObjectType::get()); - } - // NOLINTNEXTLINE(clang-diagnostic-unused-variable) - auto& pythonRpcHandler = PythonRpcHandler::getInstance(); - - if (rrefId != forkId) { - // Caller is a user and callee is the owner, add fork - // - // NB: rrefId == forkId is true if and only if calling remote to self. - // In that case both the caller and the callee will access the - // OwnerRRef. Hence, on the callee side (here), it should not call - // addForkOfOwner as it is not a fork. To allow callee to distinguish - // when this request is sent to self, the caller will set forkId using - // rrefId (OwnerRRef does not have a forkId anyway). - ctx.addForkOfOwner(rrefId, forkId); - } - - try { - processPythonExecution( - uprc.pythonUdf(), - messageId, - responseFuture, - uprc.isAsyncExecution(), - [ownerRRef, rrefId, forkId, markComplete, lsctx = std::move(lsctx)]( - const py::object& result, - const int64_t messageId, - PythonRpcHandler& /* unused */, - const c10::intrusive_ptr& responseFuture) { - // Check we have GIL. - DCHECK(PyGILState_Check()); - - IValue py_ivalue = jit::toIValue(result, PyObjectType::get()); - - py::gil_scoped_release release; - ownerRRef->recordAllStreams(lsctx); - ownerRRef->setValue(std::move(py_ivalue)); - auto m = RemoteRet(rrefId, forkId).toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - }); - } catch (py::error_already_set& e) { - // py::error_already_set requires GIL to destruct, take special care. - ownerRRef->setError(std::current_exception()); - py::gil_scoped_acquire acquire; - e.restore(); - PyErr_Clear(); - } catch (std::exception& e) { - ownerRRef->setError(std::current_exception()); - markComplete(RemoteRet(rrefId, forkId).toMessage()); - } + return assignOwnerRRef( + uprc.rrefId(), uprc.forkId(), std::move(future), std::move(lsctx)); } -void RequestCallbackImpl::processPythonRRefFetchCall( +c10::intrusive_ptr RequestCallbackImpl::processPythonRRefFetchCall( RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr lsctx) const { - // Making this lambda mutable to allow move-capture it in callbacks - auto postProcessing = [responseFuture, lsctx = std::move(lsctx)]( - const c10::intrusive_ptr& rref, - int64_t messageId) mutable { - auto whenValueSet = rref->getFuture(); - if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->exception_ptr()); - return; - } - try { - auto& pythonRpcHandler = PythonRpcHandler::getInstance(); - std::shared_ptr result; - { - // Need this GIL to guard jit::toPyObj and destruct its returned - // py::object - py::gil_scoped_acquire acquire; - result = std::make_shared( - pythonRpcHandler.serialize(jit::toPyObject(rref->getValue()))); - } - Message m = - PythonRRefFetchRet(std::move(*result).toIValues()).toMessage(); - m.setId(messageId); - rref->blockAllStreams(lsctx); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - } catch (py::error_already_set& e) { - // py::error_already_set requires GIL to destruct, take special care. - responseFuture->setError( - std::make_exception_ptr(std::runtime_error(e.what()))); - py::gil_scoped_acquire acquire; - e.restore(); - PyErr_Clear(); - } catch (const std::exception& /* unused */) { - responseFuture->setError(std::current_exception()); - } - }; - auto& prf = static_cast(rpc); - auto& ctx = RRefContext::getInstance(); - - auto futureOwner = ctx.getOwnerRRef(prf.rrefId()); - if (futureOwner->completed()) { - auto rref = fromRRefInterface(futureOwner->constValue().toRRef()); - if (rref->hasValue()) { - // optional fast-path, the OwnerRRef has been created - postProcessing(rref, messageId); - return; - } - } - futureOwner->addCallback( - [messageId, postProcessing{std::move(postProcessing)}]( - JitFuture& futureOwner) mutable { - const auto& rref = fromRRefInterface(futureOwner.constValue().toRRef()); - - // Our response is satisfied when the the rpc.remote() request - // finishes executing on the owner. - rref->getFuture()->addCallback( - [messageId, rref, postProcessing{std::move(postProcessing)}]( - JitFuture& /* unused */) mutable { - postProcessing(rref, messageId); - }); - }); + auto future = retrieveOwnerRRef(prf.rrefId(), std::move(lsctx)); + + return future->then( + [](JitFuture& future) { + SerializedPyObj result = serializePyObject(future.value()); + return c10::make_intrusive( + PythonRRefFetchRet(std::move(result).toIValues()).toMessage()); + }, + c10::getCustomClassType>()); } void RequestCallbackImpl::handleRRefDelete( @@ -462,16 +244,15 @@ void RequestCallbackImpl::handleRRefDelete( } } -void RequestCallbackImpl::processRpcWithErrors( +c10::intrusive_ptr RequestCallbackImpl::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const { try { - processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx)); + return processRpc(rpc, messageType, std::move(ctx)); } catch (py::error_already_set& e) { - responseFuture->markCompleted(handleError(e, messageType, messageId)); + // Pass a dummy message ID since it will be overwritten anyways. + auto future = asFuture(handleError(e, messageType, -1)); // There are request callback impls in Python, where Python // exceptions could be thrown. For releasing Python exception // py::objects, GIL must be held. @@ -480,8 +261,10 @@ void RequestCallbackImpl::processRpcWithErrors( // Python Error Indicator. PyErr_Clear(); // Clear the Python Error Indicator as we has // recorded the exception in the response message. + return future; } catch (std::exception& e) { - responseFuture->markCompleted(handleError(e, messageType, messageId)); + // Pass a dummy message ID since it will be overwritten anyways. + return asFuture(handleError(e, messageType, -1)); } } @@ -493,46 +276,56 @@ bool RequestCallbackImpl::cudaAvailable() const { #endif } -void RequestCallbackImpl::processRRefBackward( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { +c10::intrusive_ptr RequestCallbackImpl::processRRefBackward( + RpcCommandBase& rpc) const { auto& rrefBackwardReq = static_cast(rpc); - // Get all fields - const auto& rrefId = rrefBackwardReq.getRRefId(); - const auto& autogradContextId = rrefBackwardReq.getAutogradContextId(); - const auto& retainGraph = rrefBackwardReq.retainGraph(); - - auto futureOwner = RRefContext::getInstance().getOwnerRRef(rrefId); - futureOwner->addCallback( - [responseFuture, messageId, autogradContextId, retainGraph]( - JitFuture& futureOwner) { - const auto& rref = fromRRefInterface(futureOwner.constValue().toRRef()); - auto whenValueSet = rref->getFuture(); - - whenValueSet->addCallback( - [responseFuture, messageId, rref, autogradContextId, retainGraph]( - JitFuture& whenValueSet) { - if (whenValueSet.hasError()) { - responseFuture->setError(whenValueSet.exception_ptr()); - return; - } - - try { - // Run backward (TODO: make this async?). - PyRRef::backward(autogradContextId, retainGraph, rref); - - // Return the response. - Message m = RRefBackwardResp().toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - } catch (const std::exception& /* unused */) { - responseFuture->setError(std::current_exception()); - } - }); - }); + auto future = + retrieveOwnerRRef(rrefBackwardReq.getRRefId(), /*lsctx=*/nullptr); + + return future->then( + [autogradContextId = rrefBackwardReq.getAutogradContextId(), + retainGraph = rrefBackwardReq.retainGraph()](JitFuture& future) { + // Run backward (TODO: make this async?). + PyRRef::backwardOwnerRRef( + autogradContextId, retainGraph, future.value()); + + return c10::make_intrusive(RRefBackwardResp().toMessage()); + }, + c10::getCustomClassType>()); +} + +c10::intrusive_ptr RequestCallbackImpl::runJitFunction( + const c10::QualifiedName& name, + std::vector& stack, + bool isAsyncExecution) const { + c10::intrusive_ptr future; + try { + // runAsync() starts in the calling thread, but may return an uncompleted + // future (though for non-async code, it will typically be completed). + // If it was async, our callback will typically be invoked by the + // continuation on an at::launch() thread. + future = PythonRpcHandler::getInstance() + .jitCompilationUnit() + ->get_function(name) + .runAsync(stack); + } catch (const std::exception&) { + return asFuture(std::current_exception()); + } + + if (isAsyncExecution) { + at::TypePtr type = future->elementType(); + if (type->kind() != at::FutureType::Kind) { + return asFuture(std::make_exception_ptr(std::runtime_error(c10::str( + "Async functions must return an IValue of Future type, but got ", + type->str())))); + } + future = future->thenAsync( + [](JitFuture& future) { return future.value().toFuture(); }, + type->cast()->getElementType()); + } + + return future; } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h index 6c7c581bee17b1..e74425ddf80a27 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.h +++ b/torch/csrc/distributed/rpc/request_callback_impl.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace torch { namespace distributed { @@ -14,55 +15,45 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { std::unique_ptr rpc, const MessageType& messageType) const override; - void processPythonCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const override; + c10::intrusive_ptr processPythonCall( + RpcCommandBase& rpc) const override; - void processScriptCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const override; + c10::intrusive_ptr processScriptCall( + RpcCommandBase& rpc) const override; - TypePtr getScriptRemoteCallType( - ScriptRemoteCall& scriptRemoteCall) const override; + c10::intrusive_ptr processScriptRemoteCall( + RpcCommandBase& rpc) const override; - void processScriptRemoteCall( - ScriptRemoteCall& scriptRemoteCall, - const std::function& postProcessing, - std::vector& stack, - const c10::intrusive_ptr& ownerRRef) const override; - - void processPythonRemoteCall( + c10::intrusive_ptr processPythonRemoteCall( RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const override; - void processPythonRRefFetchCall( + c10::intrusive_ptr processPythonRRefFetchCall( RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const override; void handleRRefDelete(c10::intrusive_ptr& rref) const override; - void processRpcWithErrors( + c10::intrusive_ptr processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const override; bool cudaAvailable() const override; - void processRRefBackward( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const override; + c10::intrusive_ptr processRRefBackward( + RpcCommandBase& rpc) const override; + + // Helpers to run user-defined functions, operators and other computations. + + c10::intrusive_ptr runJitFunction( + const c10::QualifiedName& name, + std::vector& stack, + bool isAsyncExecution) const; + + c10::intrusive_ptr runPythonFunction( + const py::object& function, + bool isAsyncExecution) const; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index a1ab76e7a71938..80ee0dd3e77606 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -57,7 +57,6 @@ c10::intrusive_ptr RequestCallbackNoPython::processMessage( // RPC message: // 1) waiting for all RRefs in the arguments to become confirmed; // 2) waiting for processRpc to finish. - auto retFuture = c10::make_intrusive(at::AnyClassType::get()); auto& rrefContext = RRefContext::getInstance(); try { rrefContext.recordThreadLocalPendingRRefs(); @@ -66,14 +65,12 @@ c10::intrusive_ptr RequestCallbackNoPython::processMessage( deserializeRequest(request), request.type()); auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs(); - rrefsReadyFuture->addCallback( + auto retFuture = rrefsReadyFuture->thenAsync( [this, - retFuture, // std::function must be copyable, hence hae to cast the unique_ptr to // a shared_ptr here. rpc = (std::shared_ptr)std::move(rpc), messageType = request.type(), - id = request.id(), ctx = std::move(ctx)](JitFuture& /* unused */) mutable { c10::MultiStreamGuard guard( ctx ? ctx->getReservedStreams() : ArrayRef({})); @@ -92,8 +89,8 @@ c10::intrusive_ptr RequestCallbackNoPython::processMessage( ->config()); } - processRpcWithErrors( - *rpc, messageType, id, retFuture, std::move(ctx)); + auto retFuture = + processRpcWithErrors(*rpc, messageType, std::move(ctx)); // Response message has been sent at this moment, this post-response // work doesn't affect RPC trip time. @@ -106,125 +103,73 @@ c10::intrusive_ptr RequestCallbackNoPython::processMessage( profiler::processglobal::pushResultRecursive( serverProcessGlobalProfilerStateStackEntryPtr, event_lists); } - }); + + return retFuture; + }, + c10::getCustomClassType>()); + + auto retFutureWithMessageId = retFuture->then( + [id = request.id()](JitFuture& future) { + c10::intrusive_ptr message = + future.value().toCustomClass(); + message->setId(id); + return message; + }, + c10::getCustomClassType>()); + + return retFutureWithMessageId; } catch (std::exception& e) { - retFuture->markCompleted(handleError(e, request.type(), request.id())); rrefContext.clearRecordedPendingRRefsOnError(); + return asFuture(handleError(e, request.type(), request.id())); } - return retFuture; } -void RequestCallbackNoPython::processRpcWithErrors( +c10::intrusive_ptr RequestCallbackNoPython::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const { try { - processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx)); + return processRpc(rpc, messageType, std::move(ctx)); } catch (std::exception& e) { - responseFuture->markCompleted(handleError(e, messageType, messageId)); + // Pass a dummy message ID since it will be overwritten anyways. + return asFuture(handleError(e, messageType, -1)); } } -void RequestCallbackNoPython::processScriptCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& /* unused */) const { +c10::intrusive_ptr RequestCallbackNoPython::processScriptCall( + RpcCommandBase& rpc) const { auto& scriptCall = static_cast(rpc); - auto& stack = scriptCall.stackRef(); - TORCH_CHECK( - scriptCall.hasOp(), "Only supports the case where ScriptCall has an op"); - processScriptCallOp(scriptCall, markComplete, stack); -} - -bool RequestCallbackNoPython::processScriptCallOp( - ScriptCall& scriptCall, - const std::function& markComplete, - std::vector& stack) const { - if (scriptCall.hasOp()) { - scriptCall.op()->getOperation()(&stack); - TORCH_INTERNAL_ASSERT( - stack.size() == 1, - "Return value of a builtin operator or a " - "TorchScript function should be a single IValue, got a vector of " - "size ", - stack.size()); - markComplete(std::move(ScriptResp(std::move(stack.front()))).toMessage()); - return true; - } - return false; -} -TypePtr RequestCallbackNoPython::getScriptRemoteCallType( - ScriptRemoteCall& scriptRemoteCall) const { TORCH_CHECK( - scriptRemoteCall.hasOp(), - "Only supports the case where ScriptCall has an op"); - return scriptRemoteCall.op()->schema().returns()[0].type(); + scriptCall.hasOp(), "Only supports the case where ScriptCall has an op"); + auto future = runJitOperator(*scriptCall.op(), scriptCall.stackRef()); + + return future->then( + [](JitFuture& future) { + return c10::make_intrusive( + ScriptResp(future.value()).toMessage()); + }, + c10::getCustomClassType>()); } -void RequestCallbackNoPython::processPythonCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& /* unused */) const { +c10::intrusive_ptr RequestCallbackNoPython::processPythonCall( + RpcCommandBase& rpc) const { C10_THROW_ERROR(Error, "Python call not supported!"); } -void RequestCallbackNoPython::processPythonRemoteCall( +c10::intrusive_ptr RequestCallbackNoPython::processPythonRemoteCall( RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& /* unused */, std::shared_ptr /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } -void RequestCallbackNoPython::processScriptRemoteCall( - ScriptRemoteCall& scriptRemoteCall, - const std::function& postProcessing, - std::vector& stack, - const c10::intrusive_ptr& ownerRRef) const { - TORCH_CHECK( - scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!"); - processScriptRemoteCallOp(scriptRemoteCall, postProcessing, stack, ownerRRef); -} - -void RequestCallbackNoPython::processBaseScriptRemoteCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { - auto& scriptRemoteCall = static_cast(rpc); - auto rrefId = scriptRemoteCall.retRRefId(); - auto forkId = scriptRemoteCall.retForkId(); +c10::intrusive_ptr RequestCallbackNoPython::assignOwnerRRef( + const RRefId& rrefId, + const RRefId& forkId, + c10::intrusive_ptr valueFuture, + std::shared_ptr lsctx) const { auto& ctx = RRefContext::getInstance(); - auto postProcessing = [rrefId, forkId, messageId, responseFuture]() { - if (rrefId != forkId) { - // Caller is a user and callee is the owner, add fork - // - // NB: rrefId == forkId is true if and only if calling remote to - // self. In that case both the caller and the callee will access - // the OwnerRRef. Hence, on the callee side (here), it should not - // call addForkOfOwner as it is not a fork. To allow callee to - // distinguish when this request is sent to self, the caller will - // set forkId using rrefId (OwnerRRef does not have a forkId - // anyway). - RRefContext::getInstance().addForkOfOwner(rrefId, forkId); - } - Message m = RemoteRet(rrefId, forkId).toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - }; - - // scriptRemoteCall is only alive within this block, use reference to - // avoid copy. If the underlying code runs with a continuation, runAsync() - // below will std::move the appropriate portion of the stack. - TypePtr returnType = getScriptRemoteCallType(scriptRemoteCall); c10::intrusive_ptr ownerRRef; if (rrefId == forkId) { // Creating an owner RRef on self, should already exist in owners map @@ -233,100 +178,103 @@ void RequestCallbackNoPython::processBaseScriptRemoteCall( ->constValue() .toRRef()); } else { - ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, returnType); + ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, valueFuture->elementType()); + // Caller is a user and callee is the owner, add fork + // + // NB: rrefId == forkId is true if and only if calling remote to self. + // In that case both the caller and the callee will access the + // OwnerRRef. Hence, on the callee side (here), it should not call + // addForkOfOwner as it is not a fork. To allow callee to distinguish + // when this request is sent to self, the caller will set forkId using + // rrefId (OwnerRRef does not have a forkId anyway). + ctx.addForkOfOwner(rrefId, forkId); } - auto& stack = scriptRemoteCall.stackRef(); - processScriptRemoteCall(scriptRemoteCall, postProcessing, stack, ownerRRef); + return valueFuture->then( + [ownerRRef, rrefId, forkId, lsctx = std::move(lsctx)](JitFuture& future) { + if (future.hasError()) { + ownerRRef->setError(future.exception_ptr()); + } else { + ownerRRef->recordAllStreams(lsctx); + ownerRRef->setValue(future.value()); + } + return c10::make_intrusive( + RemoteRet(rrefId, forkId).toMessage()); + }, + c10::getCustomClassType>()); } -bool RequestCallbackNoPython::processScriptRemoteCallOp( - ScriptRemoteCall& scriptRemoteCall, - const std::function& postProcessing, - std::vector& stack, - const c10::intrusive_ptr& ownerRRef) const { - if (scriptRemoteCall.hasOp()) { - try { - scriptRemoteCall.op()->getOperation()(&stack); - } catch (const std::exception& e) { - // Don't throw in this call, but rather transfer the exception - // to the rref. - ownerRRef->setError(std::current_exception()); - postProcessing(); - return true; - } - TORCH_INTERNAL_ASSERT( - stack.size() == 1, - "Return value of a builtin operator or a " - "TorchScript function should be a single IValue, got a vector of " - "size ", - stack.size()); - ownerRRef->setValue(std::move(stack.front())); - postProcessing(); - return true; - } - return false; +c10::intrusive_ptr RequestCallbackNoPython::processScriptRemoteCall( + RpcCommandBase& rpc) const { + auto& scriptRemoteCall = static_cast(rpc); + + TORCH_CHECK( + scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!"); + auto future = + runJitOperator(*scriptRemoteCall.op(), scriptRemoteCall.stackRef()); + + return assignOwnerRRef( + scriptRemoteCall.retRRefId(), + scriptRemoteCall.retForkId(), + std::move(future), + /*lsctx=*/nullptr); } -void RequestCallbackNoPython::processScriptRRefFetchCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { - auto& srf = static_cast(rpc); +c10::intrusive_ptr RequestCallbackNoPython::retrieveOwnerRRef( + const RRefId& rrefId, + std::shared_ptr lsctx) const { auto& ctx = RRefContext::getInstance(); - auto futureOwner = ctx.getOwnerRRef(srf.rrefId()); + auto rrefFuture = ctx.getOwnerRRef(rrefId); + + at::TypePtr type = rrefFuture->elementType(); + TORCH_INTERNAL_ASSERT(type->kind() == at::RRefType::Kind); + return rrefFuture->thenAsync( + [lsctx](JitFuture& rrefFuture) { + c10::intrusive_ptr rref = + fromRRefInterface(rrefFuture.value().toRRef()); + auto valueFuture = rref->getFuture(); + // FIXME This is a temporary fix to synchronize CUDA streams. Once the + // OwnerRRef's internal Future becomes CUDA-aware this will be automatic + // and we can remove this hack. + return valueFuture->then( + [rref, lsctx](JitFuture& future) { + rref->blockAllStreams(lsctx); + return future.value(); + }, + valueFuture->elementType()); + }, + type->cast()->getElementType()); +} - if (futureOwner->completed()) { // optional fast-path - // the OwnerRRef has been created - const auto& rref = fromRRefInterface(futureOwner->constValue().toRRef()); - if (rref->hasValue()) { - markComplete(ScriptRRefFetchRet({rref->getValue()}).toMessage()); - return; - } - } +c10::intrusive_ptr RequestCallbackNoPython:: + processScriptRRefFetchCall(RpcCommandBase& rpc) const { + auto& srf = static_cast(rpc); - futureOwner->addCallback([responseFuture, messageId](JitFuture& futureOwner) { - const auto& rref = fromRRefInterface(futureOwner.constValue().toRRef()); - auto whenValueSet = rref->getFuture(); - - // Our response is satisfied when the rpc.remote() request - // finishes executing on the owner. - whenValueSet->addCallback( - [responseFuture, messageId, rref](JitFuture& whenValueSet) { - if (whenValueSet.hasError()) { - responseFuture->setError(whenValueSet.exception_ptr()); - return; - } - try { - Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - } catch (const std::exception& /* unused */) { - responseFuture->setError(std::current_exception()); - } - }); - }); + auto future = retrieveOwnerRRef(srf.rrefId(), /*lsctx=*/nullptr); + + return future->then( + [](JitFuture& future) { + return c10::make_intrusive( + ScriptRRefFetchRet({future.value()}).toMessage()); + }, + c10::getCustomClassType>()); } -void RequestCallbackNoPython::processPythonRRefFetchCall( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& /* unused */, - std::shared_ptr /* unused */) const { +c10::intrusive_ptr RequestCallbackNoPython:: + processPythonRRefFetchCall( + RpcCommandBase& rpc, + std::shared_ptr /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } -void RequestCallbackNoPython::processRRefUserDelete( - RpcCommandBase& rpc, - const std::function& markComplete) const { +c10::intrusive_ptr RequestCallbackNoPython::processRRefUserDelete( + RpcCommandBase& rpc) const { auto& rud = static_cast(rpc); auto& ctx = RRefContext::getInstance(); auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId()); handleRRefDelete(deletedRRef); - markComplete(std::move(RRefAck()).toMessage()); + return asFuture(RRefAck().toMessage()); } void RequestCallbackNoPython::handleRRefDelete( @@ -334,29 +282,26 @@ void RequestCallbackNoPython::handleRRefDelete( TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!"); } -void RequestCallbackNoPython::processRRefChildAccept( - RpcCommandBase& rpc, - const std::function& markComplete) const { +c10::intrusive_ptr RequestCallbackNoPython::processRRefChildAccept( + RpcCommandBase& rpc) const { auto& rca = static_cast(rpc); auto& ctx = RRefContext::getInstance(); ctx.delPendingChild(rca.forkId()); - markComplete(std::move(RRefAck()).toMessage()); + return asFuture(RRefAck().toMessage()); } -void RequestCallbackNoPython::processRRefForkRequest( - RpcCommandBase& rpc, - const std::function& markComplete) const { +c10::intrusive_ptr RequestCallbackNoPython::processRRefForkRequest( + RpcCommandBase& rpc) const { auto& rfr = static_cast(rpc); auto& ctx = RRefContext::getInstance(); ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId()); - markComplete(RRefAck().toMessage()); + return asFuture(RRefAck().toMessage()); } -void RequestCallbackNoPython::processForwardAutogradReq( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, - std::shared_ptr ctx) const { +c10::intrusive_ptr RequestCallbackNoPython:: + processForwardAutogradReq( + RpcCommandBase& rpc, + std::shared_ptr ctx) const { auto& rpcWithAutograd = static_cast(rpc); // Need to reverse the device map for the backward pass of distributed @@ -385,27 +330,17 @@ void RequestCallbackNoPython::processForwardAutogradReq( // Process the original RPC. auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); - // Make an overall future for the wrapped response. - auto wrappedRpcResponseFuture = - c10::make_intrusive(at::AnyClassType::get()); // Kick off processing for the nested RPC command. // wrappedRpcResponseFuture will be a Future to the result. - processRpc( - rpcWithAutograd.wrappedRpc(), - wrappedMessageType, - messageId, - wrappedRpcResponseFuture, - std::move(ctx)); + auto wrappedRpcResponseFuture = processRpc( + rpcWithAutograd.wrappedRpc(), wrappedMessageType, std::move(ctx)); auto fromWorkerId = rpcWithAutograd.fromWorkerId(); // The original future needs to be marked as completed when the wrapped // one completes, with the autograd context information wrapped. - wrappedRpcResponseFuture->addCallback( - [responseFuture, - messageId, - fromWorkerId, - ctxId = - autogradContext->contextId()](JitFuture& wrappedRpcResponseFuture) { + auto responseFuture = wrappedRpcResponseFuture->then( + [fromWorkerId, ctxId = autogradContext->contextId()]( + JitFuture& wrappedRpcResponseFuture) { // As this callback can be invoked by a different thread, we have to // make sure that the thread_local states in the previous thread is // correctly propagated. @@ -420,24 +355,23 @@ void RequestCallbackNoPython::processForwardAutogradReq( if (wrappedRpcResponseFuture.hasError()) { // Propagate error to responseFuture if we had one. - responseFuture->setError(wrappedRpcResponseFuture.exception_ptr()); + std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr()); } else { auto msg = getMessageWithAutograd( fromWorkerId, std::move( *wrappedRpcResponseFuture.value().toCustomClass()), MessageType::FORWARD_AUTOGRAD_RESP); - msg.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(msg)))); + return c10::make_intrusive(std::move(msg)); } - }); + }, + c10::getCustomClassType>()); + + return responseFuture; } -void RequestCallbackNoPython::processBackwardAutogradReq( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { +c10::intrusive_ptr RequestCallbackNoPython:: + processBackwardAutogradReq(RpcCommandBase& rpc) const { auto& gradientsCall = static_cast(rpc); const auto& autogradMetadata = gradientsCall.getAutogradMetadata(); @@ -457,21 +391,20 @@ void RequestCallbackNoPython::processBackwardAutogradReq( autogradContext, sendFunction, gradientsCall.retainGraph()); // Our response is satisfied when the rpcs come back. - execFuture->addCallback([responseFuture, messageId](JitFuture& execFuture) { - if (!execFuture.hasError()) { - Message m = std::move(PropagateGradientsResp()).toMessage(); - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - } else { - responseFuture->setError(execFuture.exception_ptr()); - } - }); + return execFuture->then( + [](JitFuture& execFuture) { + if (execFuture.hasError()) { + std::rethrow_exception(execFuture.exception_ptr()); + } else { + return c10::make_intrusive( + PropagateGradientsResp().toMessage()); + } + }, + c10::getCustomClassType>()); } -void RequestCallbackNoPython::processCleanupAutogradContextReq( - RpcCommandBase& rpc, - const std::function& markComplete) const { +c10::intrusive_ptr RequestCallbackNoPython:: + processCleanupAutogradContextReq(RpcCommandBase& rpc) const { auto& cleanupContextReq = static_cast(rpc); auto cleanupContextId = cleanupContextReq.getContextId(); // release the context if it still exists on this thread. We need to @@ -480,13 +413,11 @@ void RequestCallbackNoPython::processCleanupAutogradContextReq( // notified to clean up their context. DistAutogradContainer::getInstance().releaseContextIfPresent( cleanupContextId); - markComplete(std::move(CleanupAutogradContextResp()).toMessage()); + return asFuture(CleanupAutogradContextResp().toMessage()); } -void RequestCallbackNoPython::processRunWithProfilingReq( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const { +c10::intrusive_ptr RequestCallbackNoPython:: + processRunWithProfilingReq(RpcCommandBase& rpc) const { auto& rpcWithProfilingReq = static_cast(rpc); auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType(); auto profilingConfig = rpcWithProfilingReq.getProfilingConfig(); @@ -508,8 +439,6 @@ void RequestCallbackNoPython::processRunWithProfilingReq( this->cudaAvailable(), "Profiler state set to CUDA but CUDA not available."); const auto profilingKeyId = rpcWithProfilingReq.getProfilingId(); - auto wrappedRpcResponseFuture = - c10::make_intrusive(at::AnyClassType::get()); // Enable the profiler with the config from the sender. // When enabling on the main thread, ensure profiler states are cleaned // up, but defer consolidation of all profiled events to the continuation @@ -524,16 +453,14 @@ void RequestCallbackNoPython::processRunWithProfilingReq( "Expected profiler to be enabled!"); // Kick off processing for nested work and get Future result in // wrappedRpcResponseFuture - processRpc( + auto wrappedRpcResponseFuture = processRpc( rpcWithProfilingReq.wrappedRpc(), wrappedMsgType, - messageId, - wrappedRpcResponseFuture, {}); // TODO: https://github.com/pytorch/pytorch/issues/55757 - wrappedRpcResponseFuture->addCallback(at::wrapPropagateTLSState( - [responseFuture, profilingKeyId, profilingConfig]( - JitFuture& wrappedRpcResponseFuture) { + auto responseFuture = wrappedRpcResponseFuture->then( + at::wrapPropagateTLSState([profilingKeyId, profilingConfig]( + JitFuture& wrappedRpcResponseFuture) { std::vector profiledEvents; // Defer consolidation of profiler events until async work has // completed (such as async UDF) @@ -551,7 +478,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq( if (wrappedRpcResponseFuture.hasError()) { // Propagate error // No need to propagate remote events in the case of an error. - responseFuture->setError(wrappedRpcResponseFuture.exception_ptr()); + std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr()); } else { populateRemoteProfiledEvents( profiledEvents, profilingConfig, event_lists); @@ -561,33 +488,27 @@ void RequestCallbackNoPython::processRunWithProfilingReq( *wrappedRpcResponseFuture.value().toCustomClass()), profiledEvents, profilingKeyId); - responseFuture->markCompleted(IValue(c10::make_intrusive( - std::move(*rpcWithProfilingResp).toMessage()))); + return c10::make_intrusive( + std::move(*rpcWithProfilingResp).toMessage()); } - })); + }), + c10::getCustomClassType>()); + + return responseFuture; // Exiting the scope will disable the profiler on this thread with the // options specified above. } } -void RequestCallbackNoPython::processRRefBackward( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& /* unused */) const { +c10::intrusive_ptr RequestCallbackNoPython::processRRefBackward( + RpcCommandBase& rpc) const { C10_THROW_ERROR(Error, "Python call not supported!"); } -void RequestCallbackNoPython::processRpc( +c10::intrusive_ptr RequestCallbackNoPython::processRpc( RpcCommandBase& rpc, const MessageType& messageType, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const { - auto markComplete = [messageId, &responseFuture](Message m) { - m.setId(messageId); - responseFuture->markCompleted( - IValue(c10::make_intrusive(std::move(m)))); - }; // TODO: RpcCommandBase should have an abstract execute() method that we can // call here instead of having another switch statement here. Even better we // could have abstract classes RpcRequest and RpcResp which inherit from @@ -596,62 +517,46 @@ void RequestCallbackNoPython::processRpc( // to a python object. switch (messageType) { case MessageType::SCRIPT_CALL: { - processScriptCall(rpc, markComplete, messageId, responseFuture); - return; + return processScriptCall(rpc); } case MessageType::PYTHON_CALL: { - processPythonCall(rpc, markComplete, messageId, responseFuture); - return; + return processPythonCall(rpc); } case MessageType::SCRIPT_REMOTE_CALL: { - processBaseScriptRemoteCall(rpc, markComplete, messageId, responseFuture); - return; + return processScriptRemoteCall(rpc); } case MessageType::PYTHON_REMOTE_CALL: { - processPythonRemoteCall( - rpc, markComplete, messageId, responseFuture, std::move(ctx)); - return; + return processPythonRemoteCall(rpc, std::move(ctx)); } case MessageType::SCRIPT_RREF_FETCH_CALL: { - processScriptRRefFetchCall(rpc, markComplete, messageId, responseFuture); - return; + return processScriptRRefFetchCall(rpc); } case MessageType::PYTHON_RREF_FETCH_CALL: { - processPythonRRefFetchCall( - rpc, messageId, responseFuture, std::move(ctx)); - return; + return processPythonRRefFetchCall(rpc, std::move(ctx)); } case MessageType::RREF_USER_DELETE: { - processRRefUserDelete(rpc, markComplete); - return; + return processRRefUserDelete(rpc); } case MessageType::RREF_CHILD_ACCEPT: { - processRRefChildAccept(rpc, markComplete); - return; + return processRRefChildAccept(rpc); } case MessageType::RREF_FORK_REQUEST: { - processRRefForkRequest(rpc, markComplete); - return; + return processRRefForkRequest(rpc); } case MessageType::FORWARD_AUTOGRAD_REQ: { - processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx)); - return; + return processForwardAutogradReq(rpc, std::move(ctx)); } case MessageType::BACKWARD_AUTOGRAD_REQ: { - processBackwardAutogradReq(rpc, messageId, responseFuture); - return; + return processBackwardAutogradReq(rpc); }; case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: { - processCleanupAutogradContextReq(rpc, markComplete); - return; + return processCleanupAutogradContextReq(rpc); } case MessageType::RUN_WITH_PROFILING_REQ: { - processRunWithProfilingReq(rpc, messageId, responseFuture); - return; + return processRunWithProfilingReq(rpc); } case MessageType::RREF_BACKWARD_REQ: { - processRRefBackward(rpc, messageId, responseFuture); - return; + return processRRefBackward(rpc); } default: { TORCH_INTERNAL_ASSERT( @@ -660,7 +565,7 @@ void RequestCallbackNoPython::processRpc( } } -IValue RequestCallbackNoPython::handleError( +c10::intrusive_ptr RequestCallbackNoPython::handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const { @@ -673,8 +578,8 @@ IValue RequestCallbackNoPython::handleError( DistAutogradContainer::getInstance().getWorkerId(), ": ", e.what()); - return IValue(c10::make_intrusive( - createExceptionResponse(errorMsg, messageId))); + return c10::make_intrusive( + createExceptionResponse(errorMsg, messageId)); } bool RequestCallbackNoPython::cudaAvailable() const { @@ -685,6 +590,50 @@ bool RequestCallbackNoPython::cudaAvailable() const { #endif } +c10::intrusive_ptr RequestCallbackNoPython::runJitOperator( + const jit::Operator& op, + std::vector& stack) const { + try { + op.getOperation()(&stack); + } catch (const std::exception&) { + return asFuture(std::current_exception()); + } + TORCH_INTERNAL_ASSERT( + stack.size() == 1, + "Return value of a builtin operator or a TorchScript function should be " + "a single IValue, got a vector of size ", + stack.size()); + TypePtr type = stack.front().type(); + return asFuture(std::move(stack.front()), std::move(type)); +} + +c10::intrusive_ptr RequestCallbackNoPython::asFuture( + IValue value, + TypePtr type) const { + auto future = c10::make_intrusive(std::move(type)); + future->markCompleted(std::move(value)); + return future; +} + +c10::intrusive_ptr RequestCallbackNoPython::asFuture( + c10::intrusive_ptr message) const { + return asFuture( + std::move(message), + at::getCustomClassType>()); +} + +c10::intrusive_ptr RequestCallbackNoPython::asFuture( + Message message) const { + return asFuture(c10::make_intrusive(std::move(message))); +} + +c10::intrusive_ptr RequestCallbackNoPython::asFuture( + std::exception_ptr err) const { + auto future = c10::make_intrusive(at::NoneType::get()); + future->setError(err); + return future; +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h index e01ce9bae55314..4bf1099490d0f8 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.h +++ b/torch/csrc/distributed/rpc/request_callback_no_python.h @@ -23,122 +23,96 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { std::unique_ptr rpc, const MessageType& messageType) const; - virtual void processScriptCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; - - bool processScriptCallOp( - ScriptCall& scriptCall, - const std::function& markComplete, - std::vector& stack) const; - - virtual void processPythonCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; - - virtual TypePtr getScriptRemoteCallType( - ScriptRemoteCall& scriptRemoteCall) const; + virtual c10::intrusive_ptr processScriptCall( + RpcCommandBase& rpc) const; - virtual void processScriptRemoteCall( - ScriptRemoteCall& scriptRemoteCall, - const std::function& postProcessing, - std::vector& stack, - const c10::intrusive_ptr& ownerRRef) const; + virtual c10::intrusive_ptr processPythonCall( + RpcCommandBase& rpc) const; - void processBaseScriptRemoteCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; + c10::intrusive_ptr assignOwnerRRef( + const RRefId& rrefId, + const RRefId& forkId, + c10::intrusive_ptr valueFuture, + std::shared_ptr lsctx) const; - bool processScriptRemoteCallOp( - ScriptRemoteCall& scriptRemoteCall, - const std::function& postProcessing, - std::vector& stack, - const c10::intrusive_ptr& ownerRRef) const; + virtual c10::intrusive_ptr processScriptRemoteCall( + RpcCommandBase& rpc) const; - virtual void processPythonRemoteCall( + virtual c10::intrusive_ptr processPythonRemoteCall( RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const; - void processScriptRRefFetchCall( - RpcCommandBase& rpc, - const std::function& markComplete, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; + c10::intrusive_ptr retrieveOwnerRRef( + const RRefId& rrefId, + std::shared_ptr lsctx) const; + + c10::intrusive_ptr processScriptRRefFetchCall( + RpcCommandBase& rpc) const; - virtual void processPythonRRefFetchCall( + virtual c10::intrusive_ptr processPythonRRefFetchCall( RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const; - void processRRefUserDelete( - RpcCommandBase& rpc, - const std::function& markComplete) const; + c10::intrusive_ptr processRRefUserDelete( + RpcCommandBase& rpc) const; - void processRRefChildAccept( - RpcCommandBase& rpc, - const std::function& markComplete) const; + c10::intrusive_ptr processRRefChildAccept( + RpcCommandBase& rpc) const; - void processRRefForkRequest( - RpcCommandBase& rpc, - const std::function& markComplete) const; + c10::intrusive_ptr processRRefForkRequest( + RpcCommandBase& rpc) const; - void processForwardAutogradReq( + c10::intrusive_ptr processForwardAutogradReq( RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const; - void processBackwardAutogradReq( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; + c10::intrusive_ptr processBackwardAutogradReq( + RpcCommandBase& rpc) const; - void processCleanupAutogradContextReq( - RpcCommandBase& rpc, - const std::function& markComplete) const; + c10::intrusive_ptr processCleanupAutogradContextReq( + RpcCommandBase& rpc) const; - void processRunWithProfilingReq( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; + c10::intrusive_ptr processRunWithProfilingReq( + RpcCommandBase& rpc) const; virtual void handleRRefDelete(c10::intrusive_ptr& rref) const; - void processRpc( + c10::intrusive_ptr processRpc( RpcCommandBase& rpc, const MessageType& messageType, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const; - virtual void processRpcWithErrors( + virtual c10::intrusive_ptr processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture, std::shared_ptr ctx) const; - IValue handleError( + c10::intrusive_ptr handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const; virtual bool cudaAvailable() const; - virtual void processRRefBackward( - RpcCommandBase& rpc, - const int64_t messageId, - const c10::intrusive_ptr& responseFuture) const; + virtual c10::intrusive_ptr processRRefBackward( + RpcCommandBase& rpc) const; + + // Helpers to run user-defined functions, operators and other computations. + + c10::intrusive_ptr runJitOperator( + const jit::Operator& op, + std::vector& stack) const; + + // Helpers to convert various kinds of objects into already-completed futures. + + c10::intrusive_ptr asFuture(IValue value, TypePtr type) const; + + c10::intrusive_ptr asFuture( + c10::intrusive_ptr message) const; + + c10::intrusive_ptr asFuture(Message message) const; + + c10::intrusive_ptr asFuture(std::exception_ptr err) const; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index e574e84c4cbc84..6c637ffb3ac8d5 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -348,7 +348,7 @@ c10::intrusive_ptr RRefContext::getOrCreateOwnerRRef( // since Tensor can only get specialized with a previous run of local // JIT function, and we shouldn't preserve the specialized SubTensorType // information on other workers because it's only information only. - if (type == TensorType::get()) { + if (type->isSubtypeOf(TensorType::get())) { TORCH_INTERNAL_ASSERT( ownerRRef->type()->isSubtypeOf(TensorType::get()), "Expect OwnerRRef to be a sub-type of TensorType, but got ", diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 30af27cad296db..e7eaa6cc3c0b15 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -306,7 +306,7 @@ void OwnerRRef::recordAllStreams( } } -void OwnerRRef::blockAllStreams(std::shared_ptr& ctx) { +void OwnerRRef::blockAllStreams(const std::shared_ptr& ctx) { if (ctx) { for (c10::Event& event : events_) { event.block(ctx->getStream(event.device())); diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index 9d631b7dd67ba7..2d15dcd9266914 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -409,7 +409,7 @@ class TORCH_API OwnerRRef final : public RRef { // Blocks all streams in the context on all events previously stored in // the current OwnerRRef instance. - void blockAllStreams(std::shared_ptr& ctx); + void blockAllStreams(const std::shared_ptr& ctx); private: // a storage for device events for synchronization. diff --git a/torch/csrc/distributed/rpc/script_call.cpp b/torch/csrc/distributed/rpc/script_call.cpp index ab273f7a9ada32..65a214a1d83f93 100644 --- a/torch/csrc/distributed/rpc/script_call.cpp +++ b/torch/csrc/distributed/rpc/script_call.cpp @@ -34,7 +34,7 @@ bool ScriptCall::hasQualifiedName() const { return qualifiedName_ ? true : false; } -const c10::QualifiedName ScriptCall::qualifiedName() const { +const c10::QualifiedName& ScriptCall::qualifiedName() const { return *qualifiedName_; } diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 3427b511911ac0..581b73b38d3f7d 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -31,7 +31,7 @@ class TORCH_API ScriptCall : public RpcCommandBase { bool hasOp() const; std::shared_ptr op() const; bool hasQualifiedName() const; - const c10::QualifiedName qualifiedName() const; + const c10::QualifiedName& qualifiedName() const; // return the argument stack of this builtin operator const std::vector& stack() const; std::vector& stackRef(); diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 877750c37e5a69..03adbc7bfdb5a4 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -217,7 +217,7 @@ static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args) return THPStorage_(New)( THWStorage_(newWithDataAndAllocator)( // TODO: Maybe we should read out the scalar_t size and use it for size - THMapAllocator::makeDataPtr(WITH_FD, nullptr, fd, flags, size * sizeof(scalar_t), nullptr), + THMapAllocator::makeDataPtr(WITH_FD, "", fd, flags, size * sizeof(scalar_t), nullptr), size, /* allocator */ nullptr)); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/jit/backends/backend_debug_handler.cpp b/torch/csrc/jit/backends/backend_debug_handler.cpp index 29e7fe0aecd60f..d21e4efd5681d2 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.cpp +++ b/torch/csrc/jit/backends/backend_debug_handler.cpp @@ -3,29 +3,31 @@ namespace torch { namespace jit { +namespace { +thread_local BackendDebugInfoRecorder* debug_info_recorder_ptr{nullptr}; +} // namespace + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -std::atomic BackendDebugHandleManager::unique_debug_handle_{0}; +std::atomic BackendDebugInfoRecorder::unique_debug_handle_{0}; -int64_t BackendDebugHandleManager::getNextDebugHandleForInlinedCallStackPtr( - const Node* node) { +int64_t BackendDebugInfoRecorder::getNextDebugHandle(const Node* node) { InlinedCallStackPtr cs_ptr; if (node->callstack().has_value()) { cs_ptr = node->callstack().value(); } else { cs_ptr = c10::intrusive_ptr(); } - int64_t debug_handle = unique_debug_handle_; + DebugHandleType debug_handle = unique_debug_handle_; const SourceRange& range = node->sourceRange(); handles_to_inlined_callstack_ptrs_[debug_handle] = - std::make_pair(range, cs_ptr); + std::make_tuple(range, node->kind().toQualString(), cs_ptr); // This increment is with seq memory order. // Not trying to perf optimizing this for now. unique_debug_handle_++; return debug_handle; } -std::unordered_map BackendDebugHandleManager:: - getCallStackPtrMap() { +BackendDebugInfoMapType BackendDebugInfoRecorder::stopRecording() { // Note that this is return by copy and since // InlinedCallStackPtrs are intrusive ptr it will result in // bump of refcount. Not performant, but this is not intented @@ -34,5 +36,23 @@ std::unordered_map BackendDebugHandleManager:: return handles_to_inlined_callstack_ptrs_; } +WithBackendDebugInfoRecorder::WithBackendDebugInfoRecorder( + BackendDebugInfoRecorder* recorder) throw() { + TORCH_CHECK( + debug_info_recorder_ptr == nullptr, + "Module debug recording already in progress."); + debug_info_recorder_ptr = recorder; +} + +WithBackendDebugInfoRecorder::~WithBackendDebugInfoRecorder() { + // If due to some exception within preprocess, such as compilation failure + // we throw, then we want to make sure the exit is clean + debug_info_recorder_ptr = nullptr; +} + +BackendDebugInfoRecorder* getBackendDebugInfoRecorder() { + return debug_info_recorder_ptr; +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/backends/backend_debug_handler.h b/torch/csrc/jit/backends/backend_debug_handler.h index 004598665464c3..60727bfcc242ac 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.h +++ b/torch/csrc/jit/backends/backend_debug_handler.h @@ -1,4 +1,6 @@ #pragma once +#include + #include #include @@ -9,82 +11,72 @@ namespace jit { /* * BackendDebugHandleManager is responsible for issuing debug handles to - * backends. Debug handles are associated with nodes of an inlined graph. + * backends. Debug handles are associated with nodes of a graph. * BackendDebugHandleManager also maintains a map - * [debug-handle, {source range, inlined-callstack-ptr}] that will help - * generate a callstack for exception raised using debug handles. + * [debug-handle, DebugInfoTuple = {source range, inlined callstack ptr]} that + * will help generate a callstack for exception raised using debug handles. * Effectively debug handles are something that is given to backend and later * when an exception occurs in the backend, backend can tell, using debug - * handle, that exception occurred here. Then the runtime can generate + * handle, that an exception occurred here. Then the runtime can generate * callstack correspoding to the exception. * There are two parts to BackendDebugHandleManager: * 1. static std::atomic debug_handle - * 2. Map of [debug-handle, {source range, inlined-callstack-ptr}] - * Debug handles + * 2. Map of [debug-handle, DebugInfoTuple] * * About 1: * Why do they have to be unique. The reason is that by ensuring * uniqueness of debug handles, we remove the burden of another layer of * mapping where we need to say this set of debug handles were generated for * this lowered module or this bytecode function. This simplifies the API for - * serialization since debug handles can uniquely identify - * {source range, inlined-callstack-ptr}. Thus simplifies the runtime API for - * throwing exception. Exception throwing only needs to know debug_handle and - * not which module or method threw it. There are 2 issues to keep in mind, - * though,for static std::atomic debug_handle: A. Performance implications of - * using atomic variable. However this is only used for compilation so we - * assume to absorb some of that penalty. Plus if there is no contention then we - * should have less to worry about. B. If repeated compilation is part of a long - * running process then we may overflow int64_t. We may detect and fail on - * this. Not sure if this is good. Will seek opinions about this. + * serialization since debug handles can uniquely identify DebugInfoTuple. + * Thus simplifies the runtime API for throwing exception. Exception throwing + * only needs to know debug_handle and not which module or method threw it. + * There are 2 issues to keep in mind, though,for static std::atomic + * debug_handle: A. Performance implications of using atomic variable. However + * this is only used for compilation so we assume to absorb some of that + * penalty. Plus if there is no contention then we should have less to worry + * about. B. If repeated compilation is part of a long running process then we + * may overflow int64_t. We may detect and fail on this. For now this is not + * done. * * Now about 2: - * There are two usecases for [debug-handle, {source range, - * inlined-callstack-ptr}]: A. During bytecode generation the inlined callstack - * ptrs and source range, corresponding to the nodes of the inlined graph being - * serialized, are stored in this object and a unique debug handle is returned. - * This unique debug handle is stored in mobile_debug info for pytorch lite - * models. It will be used for raising exceptions as well as profiling. B. - * During backend lowering, each backend's preprocess/compile method can - * compile method's graph and serialize those methods. Once the method is - * lowered to backend, graph is essentially lost. Without access to graph it is - * hard to generate model level debug info. Thus the debug handles provide a way - * to map nodes of the graph to the model level debug info. Following diffs will - * provide an API to generate debug handles given inlined graph. This API will, - * given an inlined graph of a method, return a map of node*-to-debug-handles. - * Backends will serialize these debug handles and use them to raise exception, - * much like lite interpreter. Underneath the API we will utilize - * BackendDebugHandleManager, to generate map of [debug-handles, {source range, - * inlined-callstack-ptrs}]. + * There are two usecases for [debug-handle, DebugInfoTuple] + * A. During bytecode generation the DebugInfoTuple corresponding to the nodes + * of the inlined graph being serialized, are stored in this object and a + * unique debug handle is returned. This unique debug handle is stored in + * mobile_debug info for pytorch lite models. It will be used for raising + * exceptions as well as profiling. B. During backend lowering, each backend's + * preprocess/compile method can compile method's graph and serialize those + * methods. Once the method is lowered to backend, graph is essentially lost. + * Without access to graph it is hard to generate model level debug info. Thus + * the debug handles provide a way to map nodes of the graph to the model level + * debug info. * - * During byte-code model serialization, [debug-handle, {source range, - * inlined-callstack-ptrs}] is serialized. Now we know a. debug - * handles and b. how to map debug handles to model source code. Thus we can - * either do eager symbolication by converting debug handles to corresponding - * source code at runtime, or do lazy symbolicattion offline. + * During byte-code model serialization, [debug-handle, DebugInfoTuple] is + * serialized. Now we know a. debug handles and b. how to map debug handles to + * model source code. Thus we can either do eager symbolication by converting + * debug handles to corresponding source code at runtime, or do lazy + * symbolicattion offline. * - * Note that it is not necessary to serialize [debug-handle, {source range, - * inlined-callstack-ptrs}] lowered backend if the lowering - * process, that is preprocess/compile, and execution happens in the same - * session, then eager symbolication can be employed. + * Note that it is not necessary to serialize [debug-handle, DebugInfoTuple] + * corresponding to lowered backend if the lowering process, that is + * preprocess/compile, and execution happens in the same session, then eager + * symbolication can be employed. * * Now how does BackendDebugHandleManager capture all of the above? * By providing two API. - * 1. getNextDebugHandleForInlinedCallStackPtr which given a source range and - * inlined callstack ptr returns a unique debug handle, that will uniquely - * identify the tuple of source range and inlined callstack ptr. + * 1. getNextDebugHandle which given a Node* returns a unique debug handle, + * that will uniquely identify DebugInfoTuple. * and * 2. getCallStackPtrMap which returns the map - * [debug-handle, {source range, inlined-callstack-ptr}]. + * [debug-handle, DebugInfoTuple] * * 1 provides debug handles to backends and 2 provides runtime a way to map * debug handles to source level debug info. * - * So why does debug handle map to {source range and inlined cs}? - * {debug_handle, source_range_tag, serialized_callstack} - * Take this example: - * class L(nn.Module): - * def __init__(self): + * So why does debug handle map to DebugInfoTuple = {source range and inlined + * cs}? {debug_handle, source_range_tag, serialized_callstack} Take this + * example: class L(nn.Module): def __init__(self): * ... * def forward(self, x): * return x * 5 @@ -113,20 +105,51 @@ namespace jit { * mul node's inlined CS contains only information about the callsites' source * range The information about mul node's source range ('return x * 5') is not * available in its inlined CS. It is rather part of node's source range - * instead of inlined CS. Thus to get full stack: [N.forward, source range] -> + * instead of inlined CS. Thus to get full stack: [N.forward, source range] -> * [M.forward, source range] -> [aten::mul's source range] We need to track * mul's source range and inlined CS both. */ -class TORCH_API BackendDebugHandleManager { +using DebugHandleType = int64_t; + +using BackendDebugInfoMapType = + std::unordered_map; + +/* + * This class is used to generate debug info map. + * It instantiates debug_handle_manager and initialize thread local pointer to + * it. backend's preprocess will call generate_debug_handles, which uses + * debug_handle_manager to generate debug handles. When lowering process + * finishes, calling stopRecording will return debug info map from + * debug_handle_manager + */ +class TORCH_API BackendDebugInfoRecorder { public: - BackendDebugHandleManager() = default; - int64_t getNextDebugHandleForInlinedCallStackPtr(const Node* node); - std::unordered_map getCallStackPtrMap(); + BackendDebugInfoRecorder() = default; + + int64_t getNextDebugHandle(const Node* node); + // Reason this is not done as RAII is that work done in stopRecording + // can throw, and throwing with dtor will call terminate and thus voids any + // exception catching at a higher level. + BackendDebugInfoMapType stopRecording(); private: - static std::atomic unique_debug_handle_; - std::unordered_map handles_to_inlined_callstack_ptrs_; + static std::atomic unique_debug_handle_; + BackendDebugInfoMapType handles_to_inlined_callstack_ptrs_; }; +// This is a RAII class that on ctor captures pointer to +// BackendDebugInfoRecorder and initializes thread_local pointer +// debug_info_recorder to it. Upon dtor it sets debug_info_recorder +// pointer back to null. Note that this context manager always requires +// that debug_info_recorder be nullptr when initializing the context. +// This is because nested scopes with debug_info_recorder are not yet allowed. +class WithBackendDebugInfoRecorder { + public: + WithBackendDebugInfoRecorder(BackendDebugInfoRecorder* recorder) throw(); + ~WithBackendDebugInfoRecorder(); +}; + +BackendDebugInfoRecorder* getBackendDebugInfoRecorder(); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/backends/backend_debug_info.cpp b/torch/csrc/jit/backends/backend_debug_info.cpp new file mode 100644 index 00000000000000..5f6fbb6d3f316c --- /dev/null +++ b/torch/csrc/jit/backends/backend_debug_info.cpp @@ -0,0 +1,23 @@ +#include +#include + +namespace torch { +namespace jit { +namespace backend { +namespace { +#ifdef BUILD_LITE_INTERPRETER +static auto cls = torch::class_( + kBackendUtilsNamespace, + kBackendDebugInfoClass) + .def(torch::init<>()); +#else +static auto cls = torch::class_( + kBackendUtilsNamespace, + kBackendDebugInfoClass) + .def(torch::init<>()); +#endif + +} // namespace +} // namespace backend +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/backend_debug_info.h b/torch/csrc/jit/backends/backend_debug_info.h new file mode 100644 index 00000000000000..1d07beb6bdb3c0 --- /dev/null +++ b/torch/csrc/jit/backends/backend_debug_info.h @@ -0,0 +1,65 @@ +#pragma once + +#ifndef BUILD_LITE_INTERPRETER +#include +#endif +#include + +namespace torch { +namespace jit { + +constexpr static auto kBackendUtilsNamespace = "backendutils"; +constexpr static auto kBackendDebugInfoClass = "BackendDebugInfo"; + +#ifndef BUILD_LITE_INTERPRETER +/* + * Custom class for holding debug information in lowered modules, intended + * purely for keeping this information to be later serialized outside of the + * lowered module itself. + * Its usage pattern is: + * 1. LoweredModule declares an instance of this class in __backend_debug_info + * 2. During serialization, __backend_debug_info is used to obtain the debug + * information. + * 3. The contents of LoweredModule.__backend_debug_info are not serialized + * within the LoweredModule itself. + */ +class TORCH_API PyTorchBackendDebugInfo : public torch::CustomClassHolder { + public: + PyTorchBackendDebugInfo() = default; + + c10::optional& getDebugInfoMap() { + return debug_info_map_; + } + + void setDebugInfoMap(BackendDebugInfoMapType&& debug_info_map) { + debug_info_map_ = std::move(debug_info_map); + } + + private: + c10::optional debug_info_map_; +}; + +#else + +/* + * Dummy instance exists for the following reason: + * __backend_debug_info is of type BackendDebugInfo which is a torchbind' + * class backed by cpp class PyTorchBackendDebugInfo. + * PyTorchBackendDebugInfo, depends on ir.h., scope.h, source_range etc. + * We dont include this on lite interpreter side. Thus on lite interpreter side + * we cannot have valid definition of PyTorchBackendDebugInfo. However we do not + * need valid instance of __backend_debug_info in lite interpreter anyway as we + * dont serialize this info as part of LowerdModule as mentioned ealrier. + * However since LoweredModule has registered attribute of __backend_debug_info + * we still need to make sure that BackendDebugInfo is registered with + * TorchScript. However in this instance it does not have to be backed by + * PyTorchBackendDebugInfo, so we create a dummy PyTorchBackendDebugInfoDummy + * just for this purpose. + */ +class PyTorchBackendDebugInfoDummy : public torch::CustomClassHolder { + public: + PyTorchBackendDebugInfoDummy() = default; +}; +#endif +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/backend_detail.cpp b/torch/csrc/jit/backends/backend_detail.cpp index 046260d4adc20f..60cf0358ff8e68 100644 --- a/torch/csrc/jit/backends/backend_detail.cpp +++ b/torch/csrc/jit/backends/backend_detail.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -63,6 +65,12 @@ Module codegen_backend_module( std::make_shared(), /*shouldMangle=*/true); + // 1. Initialized debug info recorder. + // 2. Later call debug_info_recorder.stopRecording() to gather + // recorded debug info and save it in __backend_debug_info. + BackendDebugInfoRecorder debug_info_recorder; + WithBackendDebugInfoRecorder recorder_context(&debug_info_recorder); + // Generate attributes. // This is the preprocessed module. // For backwards compatibility, for backends that implement preprocessing in @@ -124,6 +132,43 @@ Module codegen_backend_module( )", loweredModuleResolver()); + // backend_debug_info_class is an instance of BackendDebugInfo that + // stores debug information. + // The purpose of this class is to make the debug information available + // at model saving time for serializing it outside of the lowered module, + // while still tying it to the module's lifetime (so it gets destroyed along + // with it). + // Whereas this information is not serialized as part of the lowered + // module, we still need to provide a valid instance of the + // BackendDebugInfo class when the lowered module is deserialized. + // Since the deserialized modules does not need this information, + // we create a "dummy" instance with no extra code dependencies (to avoid + // overhead) when the backend is created in __setstate__. + c10::intrusive_ptr backend_debug_info_class; + const c10::QualifiedName backend_debug_info_class_name( + {"__torch__", + "torch", + "classes", + kBackendUtilsNamespace, + kBackendDebugInfoClass}); + auto debug_info_cls = + getCustomClass(backend_debug_info_class_name.qualifiedName()); + TORCH_CHECK(debug_info_cls, "BackendDebugInfo class must be available."); + loweredModule.register_attribute( + "__backend_debug_info", + OptionalType::create(debug_info_cls), + IValue::make_capsule(backend_debug_info_class)); + static const auto create_backend_debug_info_ct = CodeTemplate(R"( + def __create_backend_debug_info(self): + self.__backend_debug_info = $backend_debug_info() + )"); + TemplateEnv create_backend_debug_info_te; + create_backend_debug_info_te.s( + "backend_debug_info", backend_debug_info_class_name.qualifiedName()); + loweredModule.define( + create_backend_debug_info_ct.format(create_backend_debug_info_te), + loweredModuleResolver()); + // getstate and setstate are for serialization/deserialization of // the LoweredModule. // setstate is in charge of initializing self.__backend by invoking @@ -148,6 +193,7 @@ Module codegen_backend_module( # state[2] indicates whether to create the backend instance. if state[2]: self.__create_backend() + self.__create_backend_debug_info() if self.__backend.is_available() : self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec) else: @@ -274,6 +320,14 @@ Module codegen_backend_module( "] is not available. Execution of this Module is still possible by " "saving and loading on a device where the backend is available."); } + + // stop debug info recording and get debug_info_map + auto debug_info_map = debug_info_recorder.stopRecording(); + loweredModule.run_method("__create_backend_debug_info"); + auto backend_debug_info = loweredModule.attr("__backend_debug_info") + .toCustomClass(); + backend_debug_info->setDebugInfoMap(std::move(debug_info_map)); + return loweredModule; } } // namespace detail diff --git a/torch/csrc/jit/backends/backend_exception.h b/torch/csrc/jit/backends/backend_exception.h new file mode 100644 index 00000000000000..bf36fbcb4f7506 --- /dev/null +++ b/torch/csrc/jit/backends/backend_exception.h @@ -0,0 +1,52 @@ +#pragma once +#include + +namespace c10 { +class TORCH_API BackendRuntimeException : public c10::Error { + public: + // Use debug_handle to throw exception + BackendRuntimeException( + SourceLocation loc, + std::string msg, + int64_t debug_handle) + : c10::Error(loc, msg) { + debug_handles.push_back(debug_handle); + } + // If rethrowing, can push another debug_handle + // This is useful in couple of scenarios. + // 1. A submodule is lowered and lite interperter has CallMethod + // to lowered module's method. In this case lowered module will throw with + // a handle, plus there will be another debug handle corresponding + // to the CallMethod node in lite interpreter. Both together give complete + // trace. This function allows lite interpreter to rethrow with debug + // handle it has for CallMethod. + // 2. Another scenarios is when lite interperter can make function calls or + // the lowered backend also has function call ability. Thus we have + // multiple function frames. Now we need a stack of handles to symbolicate + // entire stack trace. + void pushDebugHandle(int64_t debug_handle) { + debug_handles.push_back(debug_handle); + } + const std::vector& getDebugHandles() { + return debug_handles; + } + + private: + // Stores stack of debug handles. + std::vector debug_handles; +}; + +} // namespace c10 +#define TORCH_DELEGATED_BACKEND_THROW(cond, msg, debug_handle) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + throw ::c10::BackendRuntimeException( \ + {__func__, __FILE__, static_cast(__LINE__)}, \ + msg, \ + debug_handle); \ + } + +#define TORCH_DELEGATED_BACKEND_RETHROW(e, debug_handle) \ + do { \ + e.pushDebugHandle(debug_handle); \ + throw; \ + } while (false) diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp index d6f4171f2261fd..deb0b950782d25 100644 --- a/torch/csrc/jit/backends/backend_init.cpp +++ b/torch/csrc/jit/backends/backend_init.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -188,6 +189,11 @@ void initJitBackendBindings(PyObject* module) { throw py::cast_error(c10::str( "Object ", py::str(orig_module), " is not a ScriptModule")); }); + + m.def( + "_jit_backend_generate_debug_handles", [](std::shared_ptr& graph) { + return generate_debug_handles(graph); + }); } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/backends/generate_debug_handles.cpp b/torch/csrc/jit/backends/generate_debug_handles.cpp new file mode 100644 index 00000000000000..70b8384f04f39c --- /dev/null +++ b/torch/csrc/jit/backends/generate_debug_handles.cpp @@ -0,0 +1,39 @@ +#include + +#include + +namespace torch { +namespace jit { + +NodeToDebugHandle generate_debug_handles(const std::shared_ptr& graph) { + NodeToDebugHandle node_to_debug_handles; + auto* debug_info_recorder_ptr = getBackendDebugInfoRecorder(); + + // Note now we make having a valid debug_handle_manager a must. + // This avoids silently failing when say some code change results in + // to_backend not creating appropriate debug_handle_manager to + // be used with backend's preprocess function. + TORCH_CHECK( + debug_info_recorder_ptr, "Valid debug info recorder must be available."); + std::stack blocks_to_visit; + // TODO: Look into using DepthFirstGraphNodeIterator + // At the moment it takes non-const graph but maybe we can make it + // general such that it can work with both. + blocks_to_visit.push(graph->block()); + while (!blocks_to_visit.empty()) { + Block* b = blocks_to_visit.top(); + blocks_to_visit.pop(); + for (Node* n : b->nodes()) { + DebugHandleType debug_handle = + debug_info_recorder_ptr->getNextDebugHandle(n); + node_to_debug_handles.emplace(n, debug_handle); + for (Block* subblock : n->blocks()) { + blocks_to_visit.push(subblock); + } + } + } + return node_to_debug_handles; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/generate_debug_handles.h b/torch/csrc/jit/backends/generate_debug_handles.h new file mode 100644 index 00000000000000..a2205093c9f0e1 --- /dev/null +++ b/torch/csrc/jit/backends/generate_debug_handles.h @@ -0,0 +1,39 @@ +#pragma once +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +using NodeToDebugHandle = std::unordered_map; + +/* + * This is the API via which backend's preprocess function will obtain debug + * handles corresponding to the nodes of the graph for the lowered methods of + * the module. It is expected that the graphs of the methods are inlined. If + * graph is not inlined, this method will throw exception. Implementation: Given + * moudle with inlined methods: + * 1. Query if a valid debug handle manager has been initialized + * 2. If so use debug handle manager to generate debug handles, else all handles + * are -1. -1 is not quite the great constant for invalid handle, so we will + * probably fix this later. This will be used to generate debug handles and + * debug info map: + * 1. Inside to_backend, use BackendModuleDebugInfoRecorder to initialize thread + * local debug handler context. for the lowered module ptr. + * 2. Backend code for lowering module, preprocess, calls + * generate_debug_handles(graph)) which will return debug handles + * corresponding to the Node* of the said graph. + * 3. In to_backend, after lowering, call stopRecording on + * BackendModuleDebugInfoRecorder: It will extract debug map. This map gets + * stored in static instance of ModuleDebugInfoMap. Now there is a global map in + * which module's callstack ptr maps are stored and can be queried during + * serialization. + */ +NodeToDebugHandle TORCH_API +generate_debug_handles(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 27d25ab9622cc5..d2a07563d7ea3b 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -4216,20 +4216,29 @@ struct to_ir { const SliceExpr& slice = SliceExpr(subscript_exprs[0]); std::vector> tuple_args; tuple_args.reserve(3); - auto begin = - NamedValue(val_range, "begin", emitExpr(Expr(slice.startOr(0)))); - tuple_args.emplace_back(begin); + if (slice.start().present()) { + auto begin = NamedValue( + val_range, "begin", emitExpr(Expr(slice.start().get()))); + tuple_args.emplace_back(begin); + } else { + tuple_args.emplace_back(c10::nullopt); + } + if (slice.end().present()) { auto end = NamedValue(val_range, "end", emitExpr(Expr(slice.end().get()))); tuple_args.emplace_back(end); - } else { tuple_args.emplace_back(c10::nullopt); } - // pushing step_size to match the tuple_args - tuple_args.emplace_back(c10::nullopt); + if (slice.step().present()) { + auto step = + NamedValue(val_range, "step", emitExpr(Expr(slice.step().get()))); + tuple_args.emplace_back(step); + } else { + tuple_args.emplace_back(c10::nullopt); + } auto tupleSliceValue = emitTupleSlice(val_range, s_tuple_val, tuple_args); return std::make_shared(tupleSliceValue); diff --git a/torch/csrc/jit/frontend/source_ref.h b/torch/csrc/jit/frontend/source_ref.h index cb0ab86e5c467e..fd7aa0f3be7db7 100644 --- a/torch/csrc/jit/frontend/source_ref.h +++ b/torch/csrc/jit/frontend/source_ref.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -18,7 +19,7 @@ namespace jit { * support heteogeneous lookup, and also shared_ptr is an implementation detail * which should be encapsulated. */ -class TORCH_API SourceRef { +class TORCH_API SourceRef : public CustomClassHolder { public: explicit SourceRef(std::shared_ptr source) : source_(std::move(source)) {} @@ -34,6 +35,9 @@ class TORCH_API SourceRef { bool operator<(const SourceRef& other) const { return *this < *other.source_.get(); } + const Source* operator->() const { + return source_.get(); + } private: std::shared_ptr source_; diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index c835003d7834d9..efa50f93881b2c 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -630,13 +630,13 @@ void addInputs( const c10::optional& value) { detail::genericAddOptionalInput(n, name, value); } -void addInputs(Node* n, const char* name, const std::string& value) { - detail::genericAddInput(n, value); +void addInputs(Node* n, const char* name, const c10::string_view value) { + detail::genericAddInput(n, std::string(value)); } void addInputs( Node* n, const char* name, - const c10::optional& value) { + const c10::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs(Node* n, const char* name, const at::Tensor& value) { diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index 434a4418b08413..ac94e290de5d28 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -280,11 +280,14 @@ TORCH_API void addInputs( Node* n, const char* name, const c10::optional>& value); -TORCH_API void addInputs(Node* n, const char* name, const std::string& value); TORCH_API void addInputs( Node* n, const char* name, - const c10::optional& value); + const c10::string_view value); +TORCH_API void addInputs( + Node* n, + const char* name, + const c10::optional& value); TORCH_API void addInputs(Node* n, const char* name, at::Device value); TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream); TORCH_API void addInputs(Node* n, const char* name, at::Layout value); diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index b9bdb6ff5ea2c5..4aa3061926de8d 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -872,6 +872,24 @@ void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) { uses_.end()); } +void Value::replaceAllUsesDominatedByNodeWith( + const Node* node, + Value* newValue) { + std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) { + if (u.user->isDominatedBy(node)) { + u.user->inputs_[u.offset] = newValue; + newValue->uses_.push_back(u); + } + }); + + uses_.erase( + std::remove_if( + uses_.begin(), + uses_.end(), + [&node](const Use& u) { return u.user->isDominatedBy(node); }), + uses_.end()); +} + size_t findArgument( const FunctionSchema& the_schema, const std::string& unqualName) { @@ -1300,6 +1318,17 @@ Node* Node::replaceWithNewSymbol(Symbol new_symbol) { return replace_node; } +bool Node::isDominatedBy(const Node* dominator) const { + const Node* node = this; + while (node) { + if (node->owningBlock() == dominator->owningBlock()) { + return dominator->isBefore(node); + } + node = node->owningBlock()->owningNode(); + } + return false; +} + Value* Node::insertInput(size_t i, Value* value) { AT_ASSERT(graph_ == value->owningGraph()); op_ = nullptr; diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index d7831c86a0d620..f9ba8f478b23c9 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -266,8 +266,26 @@ struct Value { // %4 = g(%3) // %5 = inplace_(%3) // %6 = h(%5, %5) + // XXX: does not check scoping legality, consider using + // replaceAllUsesDominatedByNodeWith TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue); + // Replaces all uses of this value with 'newValue' that are dominated by + // 'node'. Given: + // x = op(...). + // if cond: + // z = foo(..) + // bar(x) + // else: + // print(x) + // x.replaceAllUsesDominatedByNodeWith(foo, z) would replace bar(x) + // but not print(x) because print is not dominated by foo. + // replaceAllUsesAfterNode does not check domination, so in this example + // it would produce invalid IR. + TORCH_API void replaceAllUsesDominatedByNodeWith( + const Node* node, + Value* newValue); + TORCH_API Value* copyMetadata(Value* from); TORCH_API std::shared_ptr> wrap() { @@ -444,6 +462,11 @@ struct TORCH_API Node { // but a new node symbol. does not destroy `this` Node* replaceWithNewSymbol(Symbol new_symbol); + // Checks if this node is dominated by `dominator` which means that + // `dominator` will always be executed before `this` and `dominator` + // is in scope of `this. + bool isDominatedBy(const Node* dominator) const; + // lots of things like chunk have a single input or single output, so we have // a helper to make accessing it easier Value* input() { @@ -1052,6 +1075,7 @@ struct Block { n->insertAfter(input_); return n; } + // clone all inputs, nodes, and outputs from src and append them // to the inputs, nodes, and outputs of this block // value_map is used whenever a node in src references a free variable diff --git a/torch/csrc/jit/ir/scope.cpp b/torch/csrc/jit/ir/scope.cpp index 474dc47cc9fdba..b3fd559dcea39e 100644 --- a/torch/csrc/jit/ir/scope.cpp +++ b/torch/csrc/jit/ir/scope.cpp @@ -88,7 +88,11 @@ InlinedCallStackPtr InlinedCallStack::intrusive_from_this() { } InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range) - : fn_(fn), source_range_(std::move(source_range)) {} + : fn_(fn), source_range_(std::move(source_range)) { + if (fn_) { + set_function_name(fn_->name()); + } +} InlinedCallStack::InlinedCallStack( Function* fn, @@ -96,7 +100,11 @@ InlinedCallStack::InlinedCallStack( c10::optional module_instance_info) : fn_(fn), source_range_(std::move(source_range)), - module_instance_info_(std::move(module_instance_info)) {} + module_instance_info_(std::move(module_instance_info)) { + if (fn_) { + set_function_name(fn_->name()); + } +} InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, @@ -104,7 +112,11 @@ InlinedCallStack::InlinedCallStack( SourceRange source_range) : callee_(std::move(callee)), fn_(fn), - source_range_(std::move(source_range)) {} + source_range_(std::move(source_range)) { + if (fn_) { + set_function_name(fn_->name()); + } +} InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, @@ -114,7 +126,11 @@ InlinedCallStack::InlinedCallStack( : callee_(std::move(callee)), fn_(fn), source_range_(std::move(source_range)), - module_instance_info_(std::move(module_instance_info)) {} + module_instance_info_(std::move(module_instance_info)) { + if (fn_) { + set_function_name(fn_->name()); + } +} c10::optional InlinedCallStack::callee() const { return callee_; @@ -132,6 +148,18 @@ SourceRange InlinedCallStack::source_range() const { return source_range_; } +Function* InlinedCallStack::function() const { + return fn_; +} + +void InlinedCallStack::set_function_name(std::string fn_name) { + fn_name_ = std::move(fn_name); +} + +std::string InlinedCallStack::function_name() const { + return fn_name_; +} + std::vector InlinedCallStack::vec() { std::vector r; c10::optional current = intrusive_from_this(); diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index c0155e5db94b51..83d4e8fdd13230 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -120,6 +120,15 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { private: c10::optional callee_; Function* fn_; + // Reason for fn_name_ even though we have fn_ + // Serialized callstack is used in circustmances where InlinedCallstack + // cannot be constructed during runtime, e.g. mobile runtime or + // delegated backends. + // Since in those cases we do not have Function* we store function name + // fn_name does not give you access to the same information that Function* + // does, however in mobile/delegated backend runtime we use InlindedCallStack + // for exception stack and for that purpose fn_name_ suffices. + std::string fn_name_; SourceRange source_range_; InlinedCallStackPtr intrusive_from_this(); c10::optional module_instance_info_; @@ -155,6 +164,12 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { // Returns the source range of the node SourceRange source_range() const; + Function* function() const; + + void set_function_name(std::string fn_name); + + std::string function_name() const; + // Return callstack as a vector of [Function, SourceRange] pairs. std::vector vec(); @@ -175,6 +190,13 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { } }; -using DebugInfoPair = std::pair; +// {source range, node name, InlinedCallStack} +// We store node name because same debug infor will be used for +// profiling as well, so we need to know op names as well. +using DebugInfoTuple = + std::tuple; +constexpr size_t kDebugInfoTupleSourceRangeIndex{0}; +constexpr size_t kDebugInfoTupleNodeNameIndex{1}; +constexpr size_t kDebugInfoTupleInlinedCSIndex{2}; } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/backport.h b/torch/csrc/jit/mobile/backport.h index 845bb12298b675..3e82a1e78aff9a 100644 --- a/torch/csrc/jit/mobile/backport.h +++ b/torch/csrc/jit/mobile/backport.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp index 74fc65719c8bd3..37eb4b781c38ff 100644 --- a/torch/csrc/jit/mobile/backport_manager.cpp +++ b/torch/csrc/jit/mobile/backport_manager.cpp @@ -6,8 +6,11 @@ #include #include #include +#include +#include #include #include +#include namespace torch { namespace jit { @@ -86,31 +89,50 @@ void selective_copy( } } -bool check_bytecode_version( - const std::vector& bytecode_values, - const int64_t expect_bytecode_version) { - if (bytecode_values.empty()) { - TORCH_WARN("Empty bytecode archive."); - return false; - } else if (bytecode_values[0] != expect_bytecode_version) { - TORCH_WARN( - "Expect bytecode version ", - expect_bytecode_version, - ", but it gets ", - bytecode_values[0]); - return false; - } - return true; +// Copy all content from reader to stringstream +void get_model_stream(PyTorchStreamReader& reader, std::stringstream& out) { + auto writer_func = [&](const void* buf, size_t nbytes) -> size_t { + out.write(static_cast(buf), nbytes); + return !out ? 0 : nbytes; + }; + PyTorchStreamWriter writer(writer_func); + selective_copy( + reader, + writer, + std::unordered_set({"version"}), + std::unordered_set()); } } // namespace -// To add next backport -// function, for example, backport_vn_to_vn-1, create an anonymous namespace -// with a backport_vn_to_vn-1 function + other necessary customized function. If -// a function can be reused by other backport functions, move it to the utility -// function group. It will be easier to split out backport_manager.cpp to -// smaller files when it grows too long. +/* + To add next backport function, for example, backport_vn_to_vn-1, create an + anonymous namespace with a backport_vn_to_vn-1 function + other necessary + customized function. If a function can be reused by other backport functions, + move it to the utility function group. It will be easier to split out + backport_manager.cpp to smaller files when it grows too long. + + How to add backport_v{i}_to_v{i-1} ? + There are two options: + 1) [Format change only, recommended] Constrcut a reader with the + input_model_stream, modify the file, and use PyTorchWriter to write it to + output_model_stream. See backport_v5_to_v4. + + 2) [Both format and content change] ]Use torch.jit.load() to load the stream, + and save it to output_model_stream. + + The first option is preferred, because it will be purely format change, and + the model doesn't need to go through inline again and model content will + remain the same. + + A note for manipulate stringstream, it's recommend to declare a new + stringstream, tmp_stream, and swap it with the argument output_model_stream + once it's ready, output_model_stream.swap(tmp_stream). Do not use + output_model_stream.clear(). It only clears out error state flag + (https://www.cplusplus.com/reference/ios/ios/clear/), while the content is the + same. It's cleaner to just declare a new one and swap. + +*/ // The functions needed for backport model from v5 to v4. namespace { @@ -145,15 +167,10 @@ void writeArchiveV4( writer.writeRecord(fname, data.data(), data.size()); } -bool backport_v5_to_v4( - PyTorchStreamReader& reader, - PyTorchStreamWriter& writer) { +std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) { // 1) read from archive `bytecode` archive - std::vector bytecode_values = get_bytecode_values(reader); - if (!check_bytecode_version(bytecode_values, kBytecodeVersionV5)) { - TORCH_WARN("Incorrect bytecode version for input model."); - return false; - } + PyTorchStreamReader reader(&input_model_stream); + std::vector bytecode_values = get_bytecode_ivalues(reader); std::vector constants_values = readArchive(kArchiveNameConstants, reader).toTuple()->elements(); @@ -169,6 +186,15 @@ bool backport_v5_to_v4( "constants", "bytecode", }; + + std::stringstream ouput_model_stream; + auto writer_func = [&](const void* buf, size_t nbytes) -> size_t { + ouput_model_stream.write(static_cast(buf), nbytes); + return !ouput_model_stream ? 0 : nbytes; + }; + + PyTorchStreamWriter writer(writer_func); + selective_copy(reader, writer, excluded_files, excluded_dirs); // 3) write `bytecode` archive @@ -181,8 +207,8 @@ bool backport_v5_to_v4( // write `constants` archive auto constants_tuple = c10::ivalue::Tuple::create(std::move(constants_values)); - writeArchiveV4(writer, kArchiveNameConstants, bytecode_tuple); - return true; + writeArchiveV4(writer, kArchiveNameConstants, constants_tuple); + return ouput_model_stream; } } // namespace @@ -192,9 +218,8 @@ bool backport_v5_to_v4( // * PyTorchStreamReader has access to the input model from N bytecode version. // * PyTorchStreamWriter has access to the output model backported to the // previous N-1 bytecode version. Returns true if successful, false otherwise. -using BytecodeBackportFunction = std::function; +using BytecodeBackportFunction = + std::function; BackportManager::BackportManager() { registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4); @@ -202,15 +227,11 @@ BackportManager::BackportManager() { std::unordered_map< int64_t, - std::function>& + std::function>& BackportManager::bytecodeBackportFunctions() const { static std::unordered_map< int64_t, - std::function> + std::function> backport_functions; return backport_functions; } @@ -240,6 +261,8 @@ bool BackportManager::backport( PyTorchStreamWriter& final_writer, int64_t from_version, int64_t to_version) const { + PyTorchStreamReader start_reader(istream_adapter); + if (from_version <= to_version) { TORCH_WARN( "backport donesn't support backporting model to new version. It's trying to backport from version ", @@ -249,44 +272,74 @@ bool BackportManager::backport( return false; } int64_t bytecode_version = from_version; - std::ostringstream out; - auto writer_func = [&](const void* buf, size_t nbytes) -> size_t { - out.write(static_cast(buf), nbytes); - return !out ? 0 : nbytes; - }; - - std::shared_ptr intermediate_istream_adapter = - istream_adapter; - std::ostringstream oss; bool backport_success = true; + // 1) Given an istream_adapter (an adapter with access to the input model, the + // model can be from istream, file and etc), copy all model content to + // stringstream + std::stringstream oss; + get_model_stream(start_reader, oss); + std::stringstream input_model_stream(oss.str()); + std::stringstream output_model_stream; + + // 2) backport model, backport_v{i}_to_v{i-1} function's argurment is + // (input_model_stream and output_model_stream) while (bytecode_version > to_version) { - // Read from intermediate writer result if ostream is not empty, otherwise - // it means that it's the first time to backport and read from the source. - if (!out.str().empty()) { - std::istringstream iss(out.str()); - intermediate_istream_adapter = - std::make_shared(&iss); + // Swap input and output if it's not the first time and output_model_stream + // has value. + if (!output_model_stream.str().empty()) { + input_model_stream.swap(output_model_stream); + // reset output_model_stream + output_model_stream.str(""); } - out.clear(); - - PyTorchStreamReader intermediate_reader(intermediate_istream_adapter); - PyTorchStreamWriter intermediate_writer(writer_func); if (!hasBytecodeBackportFunction(bytecode_version)) { return false; } - // When it's the last backport process, write to the final destination - // otherwise, export to the intermediate ostream. - if (bytecode_version - 1 == to_version) { - backport_success &= bytecodeBackportFunctions()[bytecode_version--]( - intermediate_reader, final_writer); - } else { - backport_success &= bytecodeBackportFunctions()[bytecode_version--]( - intermediate_reader, intermediate_writer); + auto input_model_stream_version = + _get_model_bytecode_version(input_model_stream); + + if (input_model_stream_version != bytecode_version) { + TORCH_WARN( + "The bytecode version of input model stream is supposed to be ", + bytecode_version, + ", but it gets ", + input_model_stream_version); + return false; + } + + // Keep backporting till request version + std::stringstream backport_model_stream = + bytecodeBackportFunctions()[bytecode_version--](input_model_stream); + + output_model_stream.swap(backport_model_stream); + auto output_model_stream_version = + _get_model_bytecode_version(output_model_stream); + + if (output_model_stream_version != bytecode_version) { + TORCH_WARN( + "The bytecode version of output model stream is supposed to be ", + bytecode_version, + ", but it gets ", + output_model_stream_version); + return false; } } + + // 3) Write the final output_model_stream to final_writer, final_writer has + // access to the final model destination (file, ostream and etc) + if (output_model_stream.str().empty()) { + TORCH_WARN("No output model from backport."); + return false; + } + PyTorchStreamReader last_model_reader(&output_model_stream); + selective_copy( + last_model_reader, + final_writer, + std::unordered_set({"version"}), + std::unordered_set()); + return backport_success; } diff --git a/torch/csrc/jit/mobile/backport_manager.h b/torch/csrc/jit/mobile/backport_manager.h index 3343a047d5d2b7..6a518b391d08e9 100644 --- a/torch/csrc/jit/mobile/backport_manager.h +++ b/torch/csrc/jit/mobile/backport_manager.h @@ -30,9 +30,7 @@ class BackportManager final { std::unordered_map< int64_t, - std::function>& + std::function>& bytecodeBackportFunctions() const; bool backport( @@ -49,9 +47,8 @@ class BackportManager final { // Registry of backport functions. void registerBytecodeBackportFunction( const int64_t from_version, - const std::function& backport_function); + const std::function& + backport_function); }; } // namespace jit diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index ebb46126d96c81..c38deda6eba251 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -13,33 +13,26 @@ namespace jit { namespace { -// This function construct stacktrace with module hierarchy -// Module hierarchy will contain information about where in the -// module hierarchy this source is. For example if conv2d op -// exist in hierarcy A->B->C->Conv2d with type annotations of -// A -> TopM, B->MyModule, C->SomeModule, then module hierarchy -// will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv) -// Source level stack information will be from model source code. -std::pair getStackTraceWithModuleHierarchy( - const DebugInfoPair& source_callstack, - const std::string& root_scope_string, - const std::string& top_module_type_name) { +std::pair, std::string> getStackTraceWithModuleHierarchy( + const DebugInfoTuple& source_callstack, + const std::string& caller_name) { constexpr size_t kSourceRange = 1; constexpr size_t kModuleInstanceInfo = 2; std::vector entries; - const SourceRange& range = source_callstack.first; - InlinedCallStackPtr callstack_ptr = source_callstack.second; - std::string module_info = - root_scope_string + "(" + top_module_type_name + ")"; - std::ostringstream ss; + const SourceRange& range = + std::get(source_callstack); + InlinedCallStackPtr callstack_ptr = + std::get(source_callstack); + std::string prev_function_name = caller_name; + std::string module_info; if (!callstack_ptr) { // If not cs then top level node - entries.emplace_back(StackEntry{"FunctionName_UNKNOWN", range}); + entries.emplace_back(StackEntry{prev_function_name, range}); + return {std::move(entries), std::move(module_info)}; } else { - for (const auto& element : callstack_ptr->vec()) { - const auto& opt_module_instance_info = - std::get(element); + while (callstack_ptr) { + const auto& opt_module_instance_info = callstack_ptr->module_instance(); if (opt_module_instance_info.has_value()) { const auto& module_instance_info = opt_module_instance_info.value(); if (module_instance_info.class_type()) { @@ -52,8 +45,11 @@ std::pair getStackTraceWithModuleHierarchy( .append("(") .append(type_name) .append(")"); + } else if (!module_instance_info.instance_name().empty()) { + module_info += "." + module_instance_info.instance_name(); } else { - module_info += ".(UNKNOWN_INSTANCE(UNKNOWN_TYPE)"; + const auto& instance_name = module_instance_info.instance_name(); + module_info += "." + instance_name + "(UNKNOWN_TYPE)"; } } else { module_info += ".(UNKNOWN_INSTANCE(UNKNOWN_TYPE)"; @@ -62,12 +58,55 @@ std::pair getStackTraceWithModuleHierarchy( // When we serialize function names, those can be added here. // TODO: Add function name separately entries.emplace_back( - StackEntry{"FunctionName_UNKNOWN", std::get(element)}); + StackEntry{prev_function_name, callstack_ptr->source_range()}); + if (callstack_ptr->function()) { + prev_function_name = callstack_ptr->function()->name(); + } else { + prev_function_name = callstack_ptr->function_name(); + } + + if (callstack_ptr->callee()) { + callstack_ptr = callstack_ptr->callee().value(); + } else { + callstack_ptr = c10::intrusive_ptr(); + } } - entries.emplace_back(StackEntry{"FunctionName_UNKNOWN", range}); + entries.emplace_back(StackEntry{prev_function_name, range}); + return {std::move(entries), std::move(module_info)}; + } +} + +// This function construct stacktrace with module hierarchy +// Module hierarchy will contain information about where in the +// module hierarchy this source is. For example if conv2d op +// exist in hierarcy A->B->C->Conv2d with type annotations of +// A -> TopM, B->MyModule, C->SomeModule, then module hierarchy +// will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv) +// Source level stack information will be from model source code. +std::pair getStackTraceWithModuleHierarchy( + const std::vector& source_callstacks, + const std::string& root_scope_string, + const std::string& top_module_type_name) { + std::vector stack_entries; + std::string module_info = + root_scope_string + "(" + top_module_type_name + ")"; + std::string caller_fn_name = "FunctionName_UNKNOWN"; + for (const auto& debug_info : source_callstacks) { + auto debug_info_pair = + getStackTraceWithModuleHierarchy(debug_info, caller_fn_name); + auto entries = std::move(debug_info_pair.first); + stack_entries.insert(stack_entries.end(), entries.begin(), entries.end()); + module_info += debug_info_pair.second; } + // Only last entry in the callstack will have a node name of interest. + // Rest are likely CallMethod/CallFunction nodes + auto last_entry = source_callstacks.back(); + const std::string& node_name = + std::get(last_entry); + module_info += "." + node_name; + std::ostringstream ss; ss << "Module hierarchy:" << module_info << "\n"; - format_stack_trace(ss, entries); + format_stack_trace(ss, stack_entries); return {ss.str(), std::move(module_info)}; } @@ -121,10 +160,18 @@ std::string MobileDebugTable::getModuleHierarchyInfo( const std::string& top_module_type_name) const { const auto it = callstack_ptr_map_.find(debug_handle); if (it == callstack_ptr_map_.end()) { - return "debug_handle:" + std::to_string(debug_handle); + return "Module info for handle, " + std::to_string(debug_handle) + + ", not found."; } return (getStackTraceWithModuleHierarchy( - it->second, "top", top_module_type_name)) + {it->second}, "top", top_module_type_name)) + .second; +} + +std::string MobileDebugTable::getModuleHierarchyInfo( + const std::vector& debug_handles, + const std::string& top_module_type_name) const { + return getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name) .second; } @@ -133,12 +180,49 @@ std::string MobileDebugTable::getSourceDebugString( const std::string& top_module_type_name) const { const auto it = callstack_ptr_map_.find(debug_handle); if (it == callstack_ptr_map_.end()) { - return "debug_handle:" + std::to_string(debug_handle); + return "Debug info for handle, " + std::to_string(debug_handle) + + ", not found."; } return (getStackTraceWithModuleHierarchy( - it->second, "top", top_module_type_name)) + {it->second}, "top", top_module_type_name)) .first; } +std::string MobileDebugTable::getSourceDebugString( + const std::vector& debug_handles, + const std::string& top_module_type_name) const { + return getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name) + .first; +} + +std::pair MobileDebugTable:: + getSourceDebugModuleHierarchyInfo( + const std::vector& debug_handles, + const std::string& top_module_type_name) const { + std::vector debug_infos; + bool debug_handle_not_found{false}; + for (auto it = debug_handles.rbegin(); it != debug_handles.rend(); ++it) { + auto debug_handle = *it; + const auto cs_it = callstack_ptr_map_.find(debug_handle); + if (cs_it == callstack_ptr_map_.end()) { + debug_handle_not_found = true; + break; + } + debug_infos.emplace_back(cs_it->second); + } + if (debug_handle_not_found) { + std::string debug_handles_string = "debug_handles:{"; + for (const auto debug_handle : debug_handles) { + debug_handles_string += std::to_string(debug_handle); + } + debug_handles_string += "}"; + debug_handles_string = + "Debug info for handles: " + debug_handles_string + ", was not found."; + return {debug_handles_string, debug_handles_string}; + } + return (getStackTraceWithModuleHierarchy( + debug_infos, "top", top_module_type_name)); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h index e7c7f1d45e2a67..444573ccd013ee 100644 --- a/torch/csrc/jit/mobile/debug_info.h +++ b/torch/csrc/jit/mobile/debug_info.h @@ -8,14 +8,14 @@ namespace torch { namespace jit { /* * MobileDebugTable: - * Deserializes debug_pkl records from PT model's zip archive and - * stores them in a map of debug handles to source range. - * Debug handles are unique per model and runtime, be in lite interpreter - * or delegate, raises exception using debug handles. + * Deserializes debug_pkl and callstack_map records from PT model's zip archive + * and stores them in a map of debug handles to DebugInfoPair. Debug handles are + * unique per model and runtime, be in lite interpreter or delegate, an + * exception of BackendRuntimeException should raised using debug handles. * getSourceDebugString method is responsible for translating debug * handles to correspond debug information. - * At the moment this only contains information about model source. - * But later diffs will include entire stack corresponding to debug handle. + * This debug informatin includes stack trace of model level source code and + * module hierarchy where the exception occurred. */ class MobileDebugTable { public: @@ -26,12 +26,21 @@ class MobileDebugTable { std::string getSourceDebugString( const int64_t debug_handle, const std::string& top_module_type_name = "ModuleTypeUnknown") const; + std::string getSourceDebugString( + const std::vector& debug_handles, + const std::string& top_module_type_name = "ModuleTypeUnknown") const; std::string getModuleHierarchyInfo( const int64_t debug_handle, const std::string& top_module_type_name = "ModuleTypeUnknown") const; + std::string getModuleHierarchyInfo( + const std::vector& debug_handles, + const std::string& top_module_type_name = "ModuleTypeUnknown") const; private: - ska::flat_hash_map callstack_ptr_map_; + std::pair getSourceDebugModuleHierarchyInfo( + const std::vector& debug_handles, + const std::string& top_module_type_name = "ModuleTypeUnknown") const; + ska::flat_hash_map callstack_ptr_map_; }; } // namespace jit diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 63ee5a7dd2ebe7..78a77c2685ad5e 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -10,6 +10,8 @@ #include #include +#include + namespace torch { namespace jit { char const* toString(OpCode op); @@ -230,6 +232,10 @@ bool InterpreterState::run(Stack& stack) { default: AT_ERROR(toString(inst.op), " is invalid."); } + // This exception must be caught first as it derived from c10::Error + } catch (c10::BackendRuntimeException& e) { + exception_pc_ = pc; + TORCH_RETHROW(e); } catch (c10::Error& error) { // Reason for catching and rethrowing the error is so that we can // set the exception pc that is queried later diff --git a/torch/csrc/jit/mobile/model_compatibility.cpp b/torch/csrc/jit/mobile/model_compatibility.cpp index 9cbc650547837a..b74793f3c0dca7 100644 --- a/torch/csrc/jit/mobile/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/model_compatibility.cpp @@ -49,12 +49,14 @@ c10::IValue readArchive( return ivalues; } -std::vector get_bytecode_values(PyTorchStreamReader& reader) { +std::vector get_bytecode_ivalues(PyTorchStreamReader& reader) { std::vector bytecode_values; bytecode_values = readArchive("bytecode", reader).toTuple()->elements(); return bytecode_values; } +/********************** Bytecode **********************/ + // Forward declare int64_t _get_model_bytecode_version( const std::vector& bytecode_ivalues); @@ -76,7 +78,7 @@ int64_t _get_model_bytecode_version(std::shared_ptr rai) { return -1; } PyTorchStreamReader reader(std::move(rai)); - auto bytecode_values = get_bytecode_values(reader); + auto bytecode_values = get_bytecode_ivalues(reader); return _get_model_bytecode_version(bytecode_values); } @@ -90,5 +92,87 @@ int64_t _get_model_bytecode_version( return -1; } +/********************** Operators and Info **********************/ + +// Forward declare +std::unordered_map _get_model_ops_and_info( + std::vector bytecode_ivalues); + +std::unordered_map _get_model_ops_and_info( + std::istream& in) { + std::unique_ptr rai = std::make_unique(&in); + return _get_model_ops_and_info(std::move(rai)); +} + +std::unordered_map _get_model_ops_and_info( + const std::string& filename) { + std::unique_ptr rai = std::make_unique(filename); + return _get_model_ops_and_info(std::move(rai)); +} + +std::unordered_map _get_model_ops_and_info( + std::shared_ptr rai) { + if (!check_zip_file(rai)) { + TORCH_WARN("Failed to open zip file for model ops."); + return std::unordered_map{}; + } + PyTorchStreamReader reader(std::move(rai)); + auto bytecode_values = get_bytecode_ivalues(reader); + return _get_model_ops_and_info(bytecode_values); +} + +/* A function to retrieve the root (top level) operators of a model and their + * corresponding compatibility info. These root operators can call other + * operators within them (traced ops), and a root op can call many different + * traced ops depending on internal code paths in the root op. These traced ops + * are not returned by this function. Those operators are abstracted into the + * runtime as an implementation detail (and the traced ops themselves can also + * call other operators) making retrieving them difficult and their value from + * this api negligible since they will differ between which runtime version the + * model is run on. Because of this, there is a false positive this api can't + * prevent in a compatibility usecase. All the root ops of a model are present + * in a target runtime, but not all the traced ops are which prevents a model + * from being able to run. + **/ +std::unordered_map _get_model_ops_and_info( + std::vector bytecode_ivalues) { + constexpr uint64_t min_version_with_schema = 6; + if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) { + TORCH_WARN( + "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it"); + } + std::unordered_map result; + if (bytecode_ivalues.empty()) { + TORCH_WARN("Failed to get model ops and info."); + return result; + } + // loop over all the functions in the bytecode + for (int i = 1; i < bytecode_ivalues.size(); i++) { + // descend to the operators list + auto method_tuple = bytecode_ivalues.at(i).toTuple()->elements(); + auto operators_tuple = method_tuple.at(1).toTuple()->elements()[1]; + auto operators = operators_tuple.toTuple()->elements()[1]; + for (auto& op_tuple : operators.toTuple()->elements()) { + auto op = op_tuple.toTuple()->elements(); + + // grab name + std::string op_name = op.at(0).toStringRef(); + std::string op_overload_name = op.at(1).toStringRef(); + if (op_overload_name != "") { + op_name.append("."); + op_name.append(op_overload_name); + } + + // grab schema size + if (op.size() > 2) { + result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()}); + } else { // no schema information use default + result.emplace(op_name, OperatorInfo{}); + } + } + } + return result; +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/model_compatibility.h b/torch/csrc/jit/mobile/model_compatibility.h index 1bbaebfe21ff9d..dc5808a0d84eef 100644 --- a/torch/csrc/jit/mobile/model_compatibility.h +++ b/torch/csrc/jit/mobile/model_compatibility.h @@ -1,7 +1,11 @@ #pragma once +#include +#include + #include #include +#include namespace caffe2 { namespace serialize { @@ -24,7 +28,7 @@ TORCH_API int64_t _get_model_bytecode_version( int64_t _get_model_bytecode_version( const std::vector& bytecode_ivalues); -std::vector get_bytecode_values( +std::vector get_bytecode_ivalues( caffe2::serialize::PyTorchStreamReader& reader); c10::IValue readArchive( @@ -34,5 +38,15 @@ c10::IValue readArchive( bool check_zip_file( std::shared_ptr rai); +// The family of methods below to get the root ops and information from a model +TORCH_API std::unordered_map _get_model_ops_and_info( + std::istream& in); + +TORCH_API std::unordered_map _get_model_ops_and_info( + const std::string& filename); + +TORCH_API std::unordered_map _get_model_ops_and_info( + std::shared_ptr rai); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 7b9d240a72ad25..9814aa197156e5 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -184,6 +185,18 @@ void Method::run(Stack& stack) const { if (observer) { observer->onExitRunMethod(instance_key); } + // This exception must be caught first as it derived from c10::Error + } catch (c10::BackendRuntimeException& e) { +#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) + e.pushDebugHandle(function_->getExceptionDebugHandle()); + // symbolicate all handles + e.add_context(owner_->getDebugTable().getSourceDebugString( + e.getDebugHandles(), getTopModuleTypeName(*owner_))); +#endif + if (observer) { + observer->onFailRunMethod(instance_key, e.what()); + } + TORCH_RETHROW(e); } catch (c10::Error& error) { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) auto debug_string = owner_->getDebugTable().getSourceDebugString( diff --git a/torch/csrc/jit/mobile/runtime_compatibility.h b/torch/csrc/jit/mobile/runtime_compatibility.h index d7f476eace8935..58e45b4a7fa64b 100644 --- a/torch/csrc/jit/mobile/runtime_compatibility.h +++ b/torch/csrc/jit/mobile/runtime_compatibility.h @@ -1,12 +1,17 @@ #pragma once +#include +#include + #include #include namespace torch { namespace jit { +// Struct storing metadata of an operator that can be useful for versioning struct OperatorInfo { + // The number of arguments within the schema of the op c10::optional num_schema_args; }; diff --git a/torch/csrc/jit/passes/fuse_linear.cpp b/torch/csrc/jit/passes/fuse_linear.cpp index e3dddb1db6857a..0f1978e8a9445d 100644 --- a/torch/csrc/jit/passes/fuse_linear.cpp +++ b/torch/csrc/jit/passes/fuse_linear.cpp @@ -33,7 +33,10 @@ void FuseLinear(std::shared_ptr& graph) { // replace addmm pattern to linear SubgraphRewriter addmm_to_linear; - addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear_addmm); + std::vector> value_mappings( + {{"weight", "res"}, {"res", "res"}}); + addmm_to_linear.RegisterRewritePattern( + addmm_pattern, fused_linear_addmm, value_mappings); addmm_to_linear.runOnGraph( graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed}); @@ -47,10 +50,11 @@ void FuseLinear(std::shared_ptr& graph) { %weight = aten::t(%weight_t) %res = aten::linear(%input, %weight, %bias) return (%res))IR"; + value_mappings = {{"weight", "output"}, {"res", "output"}}; // replace matmul + add pattern to linear SubgraphRewriter matmuladd_to_linear; matmuladd_to_linear.RegisterRewritePattern( - matmul_add_pattern, fused_linear_matmul); + matmul_add_pattern, fused_linear_matmul, value_mappings); matmuladd_to_linear.runOnGraph( graph, {aten_add_alpha_is_one, weight_transposed}); @@ -68,7 +72,7 @@ void FuseLinear(std::shared_ptr& graph) { // replace matmul with bias=None pattern to linear SubgraphRewriter matmul_to_linear; matmul_to_linear.RegisterRewritePattern( - matmul_pattern, fused_linear_bias_none); + matmul_pattern, fused_linear_bias_none, value_mappings); matmul_to_linear.runOnGraph(graph, weight_transposed); // clean up extra transpose for the weight of aten::linear @@ -84,9 +88,12 @@ void FuseLinear(std::shared_ptr& graph) { %res = aten::linear(%input, %weight, %bias) return (%res))IR"; + value_mappings = {{"res", "res"}}; SubgraphRewriter cleanup; cleanup.RegisterRewritePattern( - linear_weight_extra_transpose, linear_weight_no_transpose); + linear_weight_extra_transpose, + linear_weight_no_transpose, + value_mappings); cleanup.runOnGraph(graph); } } // namespace jit diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index c24e89ff41934f..de8b8ede88ceb9 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -117,16 +117,6 @@ struct GuardElimination { } } - bool isDominatedBy(Node* node, Node* dominator) { - while (node) { - if (node->owningBlock() == dominator->owningBlock()) { - return dominator->isBefore(node); - } - node = node->owningBlock()->owningNode(); - } - return false; - } - void removeDominatedGuards(Block* b) { // If a Node guards a value which isn't mutated, then that node // can replace all other guards of the value which it dominates @@ -150,7 +140,7 @@ struct GuardElimination { continue; } - if (!isDominatedBy(use.user, n)) { + if (!use.user->isDominatedBy(n)) { continue; } diff --git a/torch/csrc/jit/passes/integer_value_refinement.cpp b/torch/csrc/jit/passes/integer_value_refinement.cpp new file mode 100644 index 00000000000000..ddb0dc6beae46e --- /dev/null +++ b/torch/csrc/jit/passes/integer_value_refinement.cpp @@ -0,0 +1,228 @@ +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using IntegerRefinement = std::unordered_map; + +// see [value refinement algorithm] for full explanation. +// When a comparison like `cond = x == 4` or `cond = x != 4` is made, +// `cond` value carries information (refinements) about the value of `x`. +// in an example like: +// if x == 1: +// ... +// we can substitute all uses of x dominated by the true block +// with 1. + +struct IntegerValueRefiner { + IntegerValueRefiner(std::shared_ptr graph) + : graph_(std::move(graph)) {} + + bool run() { + if (!blockHasIntComparisons(graph_->block())) { + return false; + } + IntegerRefinement refinements; + RefineIntegerValues(graph_->block(), refinements); + return changed_; + } + + bool blockHasIntComparisons(Block* b) { + for (Node* n : b->nodes()) { + if (n->matches("aten::eq(int a, int b) -> bool") || + n->matches("aten::ne(int a, int b) -> bool")) { + for (size_t const_index : {0, 1}) { + auto non_const_index = 1 - const_index; + if (n->inputs().at(const_index)->node()->kind() == prim::Constant && + n->inputs().at(non_const_index)->uses().size() > 1) { + return true; + } + } + } + for (Block* block : n->blocks()) { + if (blockHasIntComparisons(block)) { + return true; + } + } + } + return false; + } + + void removeIfNodeOutputsWithRefinements( + Node* if_node, + IntegerRefinement& true_block_refinements, + IntegerRefinement& false_block_refinements) { + // we are looking for cases where we can replace both block outputs with the + // same value, which opens up further optimization opportunities. The pass + // will already handle if both outputs are refined to the same constant. + // Here, we look for cases where one block output has been refined in the + // other block to be equal to the same constant value as the other other + // block output: + // graph(%y.1 : int): + // %one_constant : int = prim::Constant[value=1]() + // %3 : bool = aten::eq(%y.1, %one_constant) + // %15 : int = prim::If(%3) + // block0(): + // -> (%one_constant) + // block1(): + // -> (%y.1) + // return (%15) + // %15 can always be safely replaced with %y.1 + // this is an important case for symbolic shape analysis + for (size_t block_index : {0, 1}) { + Block* if_block = if_node->blocks().at(block_index); + Block* other_if_block = if_node->blocks().at(1 - block_index); + for (size_t i = 0; i < if_node->outputs().size(); ++i) { + Value* block_output = if_block->outputs().at(i); + if (!block_output->type()->cast()) { + continue; + } + // Value must be in scope for both blocks + // in example above, %y.1 cannot be defined in block1 + if (!if_node->isDominatedBy(block_output->node())) { + continue; + } + // one constant value one not - we are looking for the pattern + // where y.1 is refined to the existing block output %one_constant + auto other_const_value = + constant_as(other_if_block->outputs().at(i)); + if (!other_const_value || + block_output->node()->kind() == prim::Constant) { + continue; + } + // here, we are looking in refinements in the other block of our + // current output. in the example, we are looking for refinements of + // %y.1 in `block0`, and we are checking that %y.1 is refined + // to the constant value of %one_constant + const auto& other_block_refinements = + block_index == 0 ? false_block_refinements : true_block_refinements; + if (!other_block_refinements.count(block_output)) { + continue; + } + if (other_block_refinements.at(block_output) == *other_const_value) { + if_node->outputs().at(i)->replaceAllUsesWith(block_output); + changed_ = true; + } + } + } + } + + // iteratively look through the block `b` for refinements or Value uses that + // can be refined, `block_refinements` are the refinements present starting at + // this block (and for all blocks dominated by this block). + IntegerRefinement RefineIntegerValues( + Block* b, + IntegerRefinement block_refinements) { + active_refinements_.push_back(&block_refinements); + for (Node* n : b->nodes()) { + if (n->matches("aten::eq(int a, int b) -> bool") || + n->matches("aten::ne(int a, int b) -> bool")) { + for (size_t const_index : {0, 1}) { + if (auto ival = constant_as(n->inputs().at(const_index))) { + IntegerRefinement refine; + refine[n->inputs().at(1 - const_index)] = *ival; + info_[n->output()] = n->kind() == aten::eq + ? BooleanRefinementMapping::TrueRefinements(std::move(refine)) + : BooleanRefinementMapping::FalseRefinements(std::move(refine)); + } + } + } + for (size_t input = 0; input < n->inputs().size(); ++input) { + Value* input_v = n->inputs().at(input); + if (!input_v->type()->cast()) { + continue; + } + + if (auto refine = tryFindRefinement(input_v)) { + WithInsertPoint guard(n); + auto refine_constant = + graph_->insertConstant(static_cast(*refine)); + n->replaceInputWith(input_v, refine_constant); + changed_ = true; + } + } + + if (n->kind() == prim::If) { + IfView if_n(n); + bool has_cond_ref = info_.count(if_n.cond()) != 0; + IntegerRefinement empty; + auto true_block_refinements = RefineIntegerValues( + if_n.thenBlock(), + has_cond_ref ? info_[if_n.cond()].true_refine() : empty); + auto false_block_refinements = RefineIntegerValues( + if_n.elseBlock(), + has_cond_ref ? info_[if_n.cond()].false_refine() : empty); + + removeIfNodeOutputsWithRefinements( + n, true_block_refinements, false_block_refinements); + + joinIfRefinements( + n, + throwing_blocks_, + block_refinements, + true_block_refinements, + false_block_refinements, + info_); + } else { + handleCommonRefinentOperators(n, throwing_blocks_, info_); + } + } + + // iterating over all nodes in the block will not iterate over + // block outputs, so we need to add handling of them. + // %3 : int = prim::Constant[value=3]() + // %4 : bool = aten::eq(%y.1, %3) + // %a : int = prim::If(%4) + // block0(): + // -> (%y.1) + // Here, we can replace y.1 with 3 + + for (size_t i = 0; i < b->outputs().size(); ++i) { + Value* output_v = b->outputs().at(i); + if (!output_v->type()->cast()) { + continue; + } + + if (auto refine = tryFindRefinement(output_v)) { + WithInsertPoint guard(b); + auto refine_constant = + graph_->insertConstant(static_cast(*refine)); + b->replaceOutput(i, refine_constant); + changed_ = true; + } + } + + active_refinements_.pop_back(); + return block_refinements; + }; + + c10::optional tryFindRefinement(Value* v) { + for (const auto& ref : active_refinements_) { + auto maybe_refinement = ref->find(v); + if (maybe_refinement != ref->end()) { + return maybe_refinement->second; + } + } + return c10::nullopt; + } + + std::shared_ptr graph_; + // A stack of active refinements, one for each block + std::vector active_refinements_; + // A map from Boolean Value * -> associated refinements + std::unordered_map info_; + std::unordered_set throwing_blocks_; + bool changed_ = false; +}; + +bool RefineIntegerValues(const std::shared_ptr& graph) { + return IntegerValueRefiner(graph).run(); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/integer_value_refinement.h b/torch/csrc/jit/passes/integer_value_refinement.h new file mode 100644 index 00000000000000..5614e96c141f4b --- /dev/null +++ b/torch/csrc/jit/passes/integer_value_refinement.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// return true if graph is modified +TORCH_API bool RefineIntegerValues(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/loop_unrolling.cpp b/torch/csrc/jit/passes/loop_unrolling.cpp index 6a182973d3ccee..92b37936a54666 100644 --- a/torch/csrc/jit/passes/loop_unrolling.cpp +++ b/torch/csrc/jit/passes/loop_unrolling.cpp @@ -162,11 +162,9 @@ void replaceLoopCounter(Node* loop) { body->insertOutput(1, result); } -bool unroll(Node* loop) { +void unroll(Node* loop) { Graph* graph = loop->owningGraph(); Block* body = loop->blocks().at(0); - if (!isSmallBlock(body)) - return false; // We will be using a "mutable" counter outside of the loop instead of the // default one, because this will allow us to share it between the unrolled @@ -184,7 +182,7 @@ bool unroll(Node* loop) { repeatBody(body, *const_len, dest); loop->eraseBlock(0); inlineBody(loop); - return true; + return; } WithInsertPoint insert_point_guard{loop}; @@ -212,11 +210,9 @@ bool unroll(Node* loop) { aten::sub, {iter_count, graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})})); - - return true; } -bool UnrollLoops(Block* block) { +bool UnrollLoops(Block* block, bool constant_only) { bool changed = false; for (auto it = block->nodes().begin(); it != block->nodes().end();) { // XXX: unroll might destroy the current node, so we need to pre-increment @@ -224,11 +220,21 @@ bool UnrollLoops(Block* block) { Node* node = *it; ++it; for (Block* subblock : node->blocks()) { - changed |= UnrollLoops(subblock); + changed |= UnrollLoops(subblock, constant_only); } - if (isForLoop(node)) { - changed |= unroll(node); + if (!isForLoop(node)) { + continue; } + if (constant_only) { + if (node->inputs().at(0)->node()->kind() != prim::Constant) { + continue; + } + } else if (!isSmallBlock(node->blocks().at(0))) { + continue; + } + + unroll(node); + changed = true; } return changed; } @@ -366,7 +372,15 @@ Node* PeelLoop(Node* n, size_t times) { } bool UnrollLoops(std::shared_ptr& graph) { - bool changed = UnrollLoops(graph->block()); + bool changed = UnrollLoops(graph->block(), false); + if (changed) { + EliminateDeadCode(graph); + } + return changed; +} + +bool UnrollConstantLoops(std::shared_ptr& graph) { + bool changed = UnrollLoops(graph->block(), true); if (changed) { EliminateDeadCode(graph); } diff --git a/torch/csrc/jit/passes/loop_unrolling.h b/torch/csrc/jit/passes/loop_unrolling.h index b823f88b1f0f8a..5895f2fcee7462 100644 --- a/torch/csrc/jit/passes/loop_unrolling.h +++ b/torch/csrc/jit/passes/loop_unrolling.h @@ -8,6 +8,9 @@ namespace jit { // return true if graph is modified TORCH_API bool UnrollLoops(std::shared_ptr& graph); +// Only unrolls constant loops. Will unroll them regardless of loop block size +TORCH_API bool UnrollConstantLoops(std::shared_ptr& graph); + TORCH_API Node* PeelLoop(Node* n, size_t times); // return true if graph is modified diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index b2b8b6df1d6011..ff32e80315e70a 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -172,11 +172,17 @@ std::shared_ptr ToONNX( return new_graph; } -void BlockToONNX( +// BlockToONNX. +// is_sub_block = true means the old_block (aten graph) is in the sub block +// (e.g., if sub block), and we want to convert it into its parent block in onnx +// graph. In this case, we don't register the input/output or eliminate the dead +// code. +std::unordered_map BlockToONNX( Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, - std::unordered_map env) { + std::unordered_map& env, + bool is_sub_block) { torch::autograd::SymbolicContext ctx{}; ctx.block = new_block; @@ -185,24 +191,32 @@ void BlockToONNX( old_block->owningGraph()->toString()); // Initialize context and environment - for (auto input : old_block->inputs()) { - auto n = ctx.block->addInput()->copyMetadata(input); - env[input] = n; + if (!is_sub_block) { + for (auto input : old_block->inputs()) { + auto n = ctx.block->addInput()->copyMetadata(input); + env[input] = n; + } } // Finally, visit all nodes in the graph for (auto node : old_block->nodes()) { NodeToONNX(node, ctx.block, operator_export_type, env); } + + if (is_sub_block) { + return env; + } + for (auto output : old_block->outputs()) { ctx.block->registerOutput(env.at(output)); } - // Run dce to clean-up unused functional and inplace ops. EliminateDeadCode( ctx.block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + + return {}; } void NodeToONNX( @@ -240,6 +254,10 @@ void NodeToONNX( ss << num_old_outputs << ", but got " << outputs.size() << ")"; throw std::runtime_error(ss.str()); } + // For const node, it does not need params_dict info, so set it to {}. + const ParamMap empty_params_dict = {}; + auto opset_version = + py::cast(onnx_symbolic.attr("_export_onnx_opset_version")); for (size_t i = 0; i < num_old_outputs; ++i) { auto old = old_outputs[i]; if (outputs[i]) { @@ -249,13 +267,34 @@ void NodeToONNX( // // If onnx shape inference is turned on, the new outputs will have // types inferred, and they will be merged with the old types. - outputs[i]->setType(MergeInferredType(old->type(), outputs[i]->type())); - - // Copy over source location and scope information to all nodes - // created by the symbolic - outputs[i]->node()->setSourceRange(node->sourceRange()); - outputs[i]->node()->setScope(node->scope()); - env[old] = outputs[i]; + if (outputs[i]->node()->kind() != c10::onnx::Constant && + ConstantValueMap::HasValue(outputs[i]->debugName())) { + // Create a const node if the node output value is in + // ConstantValueMap. + auto value = + ConstantValueMap::GetValue(outputs[i]->debugName()).value(); + Node* const_node = + new_block->owningGraph()->create(c10::onnx::Constant); + const_node->t_(attr::value, value); + const_node->output()->setType(TensorType::create(value)); + + // Copy over source location and scope information to all nodes + // created by the symbolic + const_node->output()->node()->setSourceRange(node->sourceRange()); + const_node->output()->node()->setScope(node->scope()); + new_block->appendNode(const_node); + ONNXShapeTypeInference(const_node, empty_params_dict, opset_version); + env[old] = const_node->output(); + } else { + outputs[i]->setType( + MergeInferredType(old->type(), outputs[i]->type())); + + // Copy over source location and scope information to all nodes + // created by the symbolic + outputs[i]->node()->setSourceRange(node->sourceRange()); + outputs[i]->node()->setScope(node->scope()); + env[old] = outputs[i]; + } } else { // Null output means that the ONNX op doesn't have outputs corresponding // to certain PyTorch outputs diff --git a/torch/csrc/jit/passes/onnx.h b/torch/csrc/jit/passes/onnx.h index f7106bc33378a3..e3c6cd23ecc3e7 100644 --- a/torch/csrc/jit/passes/onnx.h +++ b/torch/csrc/jit/passes/onnx.h @@ -2,7 +2,9 @@ #include #include +#include #include +#include namespace torch { namespace jit { @@ -10,11 +12,12 @@ namespace jit { TORCH_API std::shared_ptr ToONNX( std::shared_ptr& state, ::torch::onnx::OperatorExportTypes operator_export_type); -TORCH_API void BlockToONNX( +TORCH_API std::unordered_map BlockToONNX( Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, - std::unordered_map env); + std::unordered_map& env, + bool is_sub_block = false); TORCH_API void NodeToONNX( Node* old_node, Block* new_block, diff --git a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp index 0c124fd328f29a..7aa0329af9653c 100644 --- a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp +++ b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp @@ -28,6 +28,7 @@ void CastAllConstantToFloating(Block* block) { if (node->kind() == onnx::Constant) { auto val = node->t(attr::value); at::ScalarType dtype = val.scalar_type(); + auto val_type = TensorType::create(val); if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float && dtype != at::ScalarType::Half) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -55,6 +56,7 @@ void CastAllConstantToFloating(Block* block) { node->t_(attr::value, val); Node* cast_node = graph->create(onnx::Cast, 1); cast_node->i_(attr::to, to_type); + cast_node->output()->setType(val_type); cast_node->insertAfter(node); // get input from cast node node->outputs().at(0)->replaceAllUsesWith(cast_node->outputs().at(0)); diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 226cc8c0ff05ef..364327cc414897 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -278,7 +278,13 @@ c10::optional runTorchBackendForOnnx( updated_val = at::sqrt(inputTensorValues[0]); return c10::optional(updated_val); } else if (node->kind() == onnx::Div) { + // One example shows at::div(CPULongType, CPULongType) = CPUFloatType, + // So we add a cast below. updated_val = at::div(inputTensorValues[0], inputTensorValues[1]); + if (inputTensorValues[0].scalar_type() == + inputTensorValues[1].scalar_type()) { + updated_val = updated_val.to(inputTensorValues[0].scalar_type()); + } return c10::optional(updated_val); } else if (node->kind() == onnx::Mul) { updated_val = at::mul(inputTensorValues[0], inputTensorValues[1]); @@ -303,7 +309,11 @@ c10::optional runTorchBackendForOnnx( std::vector axes; for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) { // ONNX unsqueeze accepts negative axes - axes_a[i] += axes_a[i] < 0 ? inputTensorValues[0].sizes().size() : 0; + // From https://pytorch.org/docs/stable/generated/torch.unsqueeze.html + // Negative dim will correspond to unsqueeze() applied at dim = dim + + // input.dim() + 1. + axes_a[i] += + axes_a[i] < 0 ? inputTensorValues[0].sizes().size() + 1 : 0; axes.push_back(axes_a[i]); } std::sort(axes.begin(), axes.end()); @@ -471,6 +481,23 @@ c10::optional runTorchBackendForOnnx( } else if (node->kind() == onnx::Neg) { updated_val = at::neg(inputTensorValues[0]); return c10::optional(updated_val); + } else if (node->kind() == onnx::Not) { + auto ones = + at::ones(inputTensorValues[0].sizes(), inputTensorValues[0].dtype()); + updated_val = at::ne(inputTensorValues[0], ones); + return c10::optional(updated_val); + } else if (node->kind() == onnx::Size) { + int64_t total_size = 1; + for (auto size : inputTensorValues[0].sizes()) { + total_size *= size; + } + auto options = c10::TensorOptions().dtype(at::kLong).device(at::kCPU); + std::vector size_data = {total_size}; + auto f = at::from_blob(size_data.data(), {1}, at::kLong).to(at::kCPU); + // Need copy here + at::Tensor f_copy = at::empty({1}, options); + f_copy.copy_(f); + return c10::optional(at::squeeze(f_copy, 0)); } else { return c10::nullopt; } @@ -591,7 +618,6 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) { std::make_pair(newSourceNodeOutput->debugName(), updatedVal)}); newSourceNodeOutput->inferTypeFrom(updatedVal); node->outputs().at(0)->replaceAllUsesWith(newSourceNodeOutput); - // Next we remove the current node that has been replaced by // an initializer. But before we start de-wiring this node, // we check if any parents of this nodes were onnx::Constant diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 7f0d0b44cf0101..d03dc617353fa1 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -818,10 +818,6 @@ static void fuseLogSoftmaxNllLoss(Block* b) { // %26 : Long(3, 1, 2) = onnx::Reshape(%target.1, %25) // %30 : Float() = onnx::NegativeLogLikelihoodLoss[reduction="sum"](%22, // %26) return (%30) - TORCH_INTERNAL_ASSERT( - prev->input(1)->node()->input(0)->node()->kind() == onnx::Gather); - TORCH_INTERNAL_ASSERT( - prev->input(1)->node()->input(1)->node()->kind() == onnx::Gather); origLogSoftmaxNode = prev->input(0)->node()->input(0)->node(); auto transpose = origLogSoftmaxNode->input(0)->node(); TORCH_INTERNAL_ASSERT(transpose->kind() == onnx::Transpose); diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 67c1dfd61f4272..bd8566fa38f8ad 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -1078,11 +1078,21 @@ bool IsListConstructIntType(const Value* v) { return false; } +bool AllGraphInputsStatic(const Graph* g) { + for (auto n : g->inputs()) { + if (!n->isCompleteTensor()) { + return false; + } + } + return true; +} + void ProcessConstantValueMap(Node* n, int opset_version) { // Update ConstantValueMap on node outputs from onnx shape inference // For outputs, only update static shapes. For input, we update symbolic // shapes also. ONNX If can have different types on different branches, skip // here. + auto static_input_shape = AllGraphInputsStatic(n->owningGraph()); for (auto i = 0; i < n->outputs().size(); i++) { if (TensorTypePtr output_type = n->output(i)->type()->cast()) { if (output_type->dim().has_value()) { @@ -1103,9 +1113,17 @@ void ProcessConstantValueMap(Node* n, int opset_version) { if (input_type->dim().has_value()) { size_t rank = static_cast(input_type->dim().value()); ConstantValueMap::SetRank(n->input(i)->debugName(), rank); - auto shape = input_type->symbolic_sizes(); - if (!ConstantValueMap::HasShape(n->input(i)->debugName())) { - UpdateShape(n->input(i), shape); + // Only update shape if the input is onnx node. + // If it is aten operators, for example, + // Float(20, 20, strides=[1, 0], requires_grad=0, device=cpu), + // %399 : Float(20, 20, strides=[0, 1], requires_grad=0, device=cpu) + // = prim::ListUnpack(%397) + // The tracer shape may not be correct when dynamic_axes is enabled. + if (n->input(i)->node()->kind().is_onnx() || static_input_shape) { + auto shape = input_type->symbolic_sizes(); + if (!ConstantValueMap::HasShape(n->input(i)->debugName())) { + UpdateShape(n->input(i), shape); + } } } } else if (IsListConstructIntType(n->input(i))) { diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 42bd043f4b97a0..b189f205139d22 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -31,6 +32,7 @@ struct PeepholeOptimizeImpl { bool changed = optimizeBlock(graph_->block()); changed |= PeepholeOptimizeListIdioms(graph_); changed |= PeepholeOptimizeAliasSensitive(graph_); + changed |= PeepholeOptimizeNonTensor(graph_); return changed; } @@ -52,18 +54,6 @@ struct PeepholeOptimizeImpl { changed |= optimizeBlock(sub_block); } - if (node->kind() != prim::Constant) { - WithInsertPoint guard(node); - // Any Value whose type is None should be replaced with a Constant - // This can occur if a module has an optional attribute, and it is - // initialized as None. - for (Value* output : node->outputs()) { - if (output->type()->cast()) { - output->replaceAllUsesWith(graph_->insertConstant(IValue())); - changed = true; - } - } - } // XXX: remember that if you want to simplify an expression by combining // multiple nodes into a different one, then you need to check that they // all belong to the given block @@ -214,79 +204,6 @@ struct PeepholeOptimizeImpl { auto new_constant = node->owningGraph()->insertConstant(ival); node->output()->replaceAllUsesWith(new_constant); } - } else if (node->kind() == prim::If) { - IfView n(node); - // this handles redundant short circuits like "x and True" or "x or - // False" - for (size_t i = 0; i < n.outputs().size(); ++i) { - if (n.outputs().at(i)->type() != BoolType::get()) { - continue; - } - bool true_val = - constant_as(n.thenOutputs().at(i)).value_or(false); - bool false_val = - constant_as(n.elseOutputs().at(i)).value_or(true); - // if an if node's output equals its condition replace output with - // condition - if (true_val && !false_val) { - GRAPH_UPDATE( - "Replacing ", - n.outputs().at(i)->debugName(), - " (True or False) with ", - n.cond()->debugName()); - n.outputs().at(i)->replaceAllUsesWith(n.cond()); - changed = true; - } - } - } else if ( - node->kind() == aten::__is__ || node->kind() == aten::__isnot__) { - // if we are comparing a None value with a value that can't be None - // replace the output with true if node is __isnot__ or false if node is - // __is__ - AT_ASSERT(node->inputs().size() == 2); - for (size_t check_none_index : {0, 1}) { - bool input_must_be_none = - node->inputs().at(check_none_index)->mustBeNone(); - bool other_must_not_be_none = - node->inputs().at(1 - check_none_index)->mustNotBeNone(); - if (input_must_be_none && other_must_not_be_none) { - WithInsertPoint guard(node); - auto output = node->owningGraph()->insertConstant( - node->kind() == aten::__isnot__); - GRAPH_UPDATE( - "Folding ", getHeader(node), " to ", output->debugName()); - node->output()->replaceAllUsesWith(output); - changed = true; - } - } - } else if ( - node->kind() == prim::unchecked_unwrap_optional || - node->kind() == aten::_unwrap_optional) { - // we are unwrapping an input that can't be None, remove the unwrap - auto input = node->input(); - if (input->mustNotBeNone()) { - GRAPH_UPDATE( - "Unwrapping ", - getHeader(node), - " as ", - node->input(), - " can't be optional"); - node->output()->replaceAllUsesWith(node->input()); - changed = true; - } - } else if (node->kind() == prim::unchecked_cast) { - // unchecked_cast is not generated for tensor properties, so we are not - // losing anything by calling unshapedType here - auto input_type = unshapedType(node->input()->type()); - auto output_type = unshapedType(node->output()->type()); - if (input_type->isSubtypeOf(output_type)) { - GRAPH_UPDATE( - "Removing ", - getHeader(node), - " as input type subtypes output type"); - node->output()->replaceAllUsesWith(node->input()); - changed = true; - } } else if ( node->matches("prim::dtype(Tensor a) -> int") && shape_peepholes_) { auto ptt = node->input()->type()->expect(); diff --git a/torch/csrc/jit/passes/peephole_list_idioms.cpp b/torch/csrc/jit/passes/peephole_list_idioms.cpp index 27f1446bceef40..f0e4ef5f14337e 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_list_idioms.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include #include @@ -21,17 +23,151 @@ c10::optional normalizeIndex(int64_t index, size_t len) { } } +// see [value refinement algorithm] + +struct ListLenRefiner { + ListLenRefiner( + std::shared_ptr graph, + std::unordered_set& mutated_lists) + : graph_(std::move(graph)), mutated_lists_(mutated_lists) {} + + bool run() { + std::unordered_set li_with_len_use; + collectListsToRefine(graph_->block(), li_with_len_use); + if (lists_to_refine_.size() == 0) { + return false; + } + ListRefinement refinements; + RefineListLens(graph_->block(), refinements); + return changed_; + } + + // we only need to analyze lists that have multiple uses of len(), and we can + // only analyze lists that are not mutated + void collectListsToRefine( + Block* b, + std::unordered_set& li_with_len_use) { + for (Node* n : b->nodes()) { + for (Block* block : n->blocks()) { + collectListsToRefine(block, li_with_len_use); + } + + if (n->kind() != aten::len) { + continue; + } + + auto first_input = n->input(0); + if (first_input->type()->cast() && + !mutated_lists_.count(first_input)) { + if (!li_with_len_use.count(first_input)) { + li_with_len_use.insert(first_input); + } else { + lists_to_refine_.insert(first_input); + } + } + } + } + + ListRefinement RefineListLens(Block* b, ListRefinement block_refinements) { + active_refinements_.push_back(&block_refinements); + for (Node* n : b->nodes()) { + if (n->matches("aten::eq(int a, int b) -> bool") || + n->matches("aten::ne(int a, int b) -> bool")) { + // check for one input constant and the other coming from len(li) + for (size_t const_index : {0, 1}) { + auto ival = constant_as(n->input(const_index)); + if (!ival) { + continue; + } + auto li_len = n->input(const_index - 1); + if (!li_len->node()->matches("aten::len.t(t[] a) -> int") || + !lists_to_refine_.count(li_len->node()->input())) { + continue; + } + ListRefinement refine; + refine[li_len->node()->input()] = *ival; + boolean_value_refinements_[n->output()] = n->kind() == aten::eq + ? BooleanRefinementMapping::TrueRefinements(std::move(refine)) + : BooleanRefinementMapping::FalseRefinements(std::move(refine)); + } + } else if (n->kind() == aten::len) { + if (auto maybe_len = tryFindRefinement(n->input(0))) { + changed_ = true; + WithInsertPoint guard(n); + n->output()->replaceAllUsesWith( + graph_->insertConstant(static_cast(*maybe_len))); + } + } else if (n->kind() == prim::If) { + IfView if_n(n); + bool has_cond_ref = boolean_value_refinements_.count(if_n.cond()) != 0; + ListRefinement empty; + auto true_block_refinements = RefineListLens( + if_n.thenBlock(), + has_cond_ref ? boolean_value_refinements_[if_n.cond()].true_refine() + : empty); + auto false_block_refinements = RefineListLens( + if_n.elseBlock(), + has_cond_ref + ? boolean_value_refinements_[if_n.cond()].false_refine() + : empty); + + joinIfRefinements( + n, + throwing_blocks_, + block_refinements, + true_block_refinements, + false_block_refinements, + boolean_value_refinements_); + } else { + handleCommonRefinentOperators( + n, throwing_blocks_, boolean_value_refinements_); + } + } + active_refinements_.pop_back(); + return block_refinements; + }; + + c10::optional tryFindRefinement(Value* v) { + for (const auto& ref : active_refinements_) { + auto maybe_refinement = ref->find(v); + if (maybe_refinement != ref->end()) { + return maybe_refinement->second; + } + } + return c10::nullopt; + } + + std::shared_ptr graph_; + std::unordered_set mutated_lists_; + // candidate lists for optimizations + std::unordered_set lists_to_refine_; + // A stack of active refinements, one for each block + std::vector active_refinements_; + // A map from Boolean Value * -> associated refinements + std::unordered_map + boolean_value_refinements_; + std::unordered_set throwing_blocks_; + bool changed_ = false; +}; + // This pass only does optimizations on lists which aren't mutated, // so we first use the Alias Db to collect the set of list values // which we shouldn't optimize. struct PeepholeOptimizeListIdiomsImpl { - PeepholeOptimizeListIdiomsImpl(std::shared_ptr graph) + PeepholeOptimizeListIdiomsImpl( + std::shared_ptr graph, + bool refine_list_len) : graph_(std::move(graph)), - aliasDb_(torch::make_unique(graph_)) {} + aliasDb_(torch::make_unique(graph_)), + refine_list_len_(refine_list_len) {} bool run() { collectMutatedLists(graph_->block()); - return runBlock(graph_->block()); + bool changed = runBlock(graph_->block()); + if (refine_list_len_) { + changed |= ListLenRefiner(graph_, mutated_lists_).run(); + } + return changed; } private: @@ -64,11 +200,11 @@ struct PeepholeOptimizeListIdiomsImpl { // only optimizing list ops if (node->inputs().size() == 0 || - !node->inputs().at(0)->type()->cast()) { + !node->input(0)->type()->cast()) { continue; } - auto first_input = node->inputs().at(0); + auto first_input = node->input(0); // only optimizing ops with unmutated lists if (mutated_lists_.count(first_input)) { @@ -85,11 +221,11 @@ struct PeepholeOptimizeListIdiomsImpl { } else if (node->kind() == aten::__getitem__) { auto list_creation_node = first_input->node(); if (list_creation_node->kind() == prim::ListConstruct) { - if (auto index = toIValue(node->inputs().at(1))) { + if (auto index = toIValue(node->input(1))) { size_t list_size = list_creation_node->inputs().size(); if (auto norm_index = normalizeIndex(index->toInt(), list_size)) { node->output()->replaceAllUsesWith( - list_creation_node->inputs().at(*norm_index)); + list_creation_node->input(*norm_index)); changed = true; } } @@ -102,8 +238,7 @@ struct PeepholeOptimizeListIdiomsImpl { continue; } for (size_t i = 0; i < node->outputs().size(); ++i) { - node->output(i)->replaceAllUsesWith( - list_creation_node->inputs().at(i)); + node->output(i)->replaceAllUsesWith(list_creation_node->input(i)); changed = true; } } @@ -111,7 +246,7 @@ struct PeepholeOptimizeListIdiomsImpl { if (node->inputs().size() != 2) { continue; } - auto second_input = node->inputs().at(1); + auto second_input = node->input(1); // already checked first, need to check second if (mutated_lists_.count(second_input)) { continue; @@ -143,10 +278,13 @@ struct PeepholeOptimizeListIdiomsImpl { std::unordered_set mutated_lists_; std::shared_ptr graph_; std::unique_ptr aliasDb_; + bool refine_list_len_; }; -bool PeepholeOptimizeListIdioms(const std::shared_ptr& graph) { - PeepholeOptimizeListIdiomsImpl opt(graph); +bool PeepholeOptimizeListIdioms( + const std::shared_ptr& graph, + bool refine_list_len) { + PeepholeOptimizeListIdiomsImpl opt(graph, refine_list_len); return opt.run(); } diff --git a/torch/csrc/jit/passes/peephole_list_idioms.h b/torch/csrc/jit/passes/peephole_list_idioms.h index 1aafe6f518e6ba..33e44da087effc 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.h +++ b/torch/csrc/jit/passes/peephole_list_idioms.h @@ -7,8 +7,13 @@ namespace jit { // Peephole Optimizes List Ops such as len(li) and li[1]. // Currently this is invoked as part of PeepholeOptimize -// return true if graph is modified -TORCH_API bool PeepholeOptimizeListIdioms(const std::shared_ptr& graph); +// return true if graph is modified. +// If `refine_list_len` is true will attempt to refine the len of lists through +// len comparisons and assertions. This does not generally optimize pytorch +// programs so it is not called by default in PeepholeOptimize. +TORCH_API bool PeepholeOptimizeListIdioms( + const std::shared_ptr& graph, + bool refine_list_len = false); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp new file mode 100644 index 00000000000000..795aca7ac1c290 --- /dev/null +++ b/torch/csrc/jit/passes/peephole_non_tensor.cpp @@ -0,0 +1,154 @@ +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +struct PeepholeOptimizeNonTensorImpl { + PeepholeOptimizeNonTensorImpl(const std::shared_ptr& graph) + : graph_(graph) {} + + bool run() { + return optimizeBlock(graph_->block()); + } + + bool optimizeBlock(Block* block) { + bool changed = false; + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { + auto* node = *it; + + for (Block* sub_block : node->blocks()) { + changed |= optimizeBlock(sub_block); + } + + if (node->kind() != prim::Constant) { + WithInsertPoint guard(node); + // Any Value whose type is None should be replaced with a Constant + // This can occur if a module has an optional attribute, and it is + // initialized as None. + for (Value* output : node->outputs()) { + if (output->type()->cast()) { + output->replaceAllUsesWith(graph_->insertConstant(IValue())); + changed = true; + } + } + } + // XXX: remember that if you want to simplify an expression by combining + // multiple nodes into a different one, then you need to check that they + // all belong to the given block + // TODO: this doesn't work with Scalar-Tensor ops! We should + // canonicalize those + if (node->kind() == prim::If) { + IfView n(node); + // this handles redundant short circuits like "x and True" or "x or + // False" + for (size_t i = 0; i < n.outputs().size(); ++i) { + if (n.outputs().at(i)->type() != BoolType::get()) { + continue; + } + bool true_val = + constant_as(n.thenOutputs().at(i)).value_or(false); + bool false_val = + constant_as(n.elseOutputs().at(i)).value_or(true); + // if an if node's output equals its condition replace output with + // condition + if (true_val && !false_val) { + GRAPH_UPDATE( + "Replacing ", + n.outputs().at(i)->debugName(), + " (True or False) with ", + n.cond()->debugName()); + n.outputs().at(i)->replaceAllUsesWith(n.cond()); + changed = true; + } + } + } else if ( + node->kind() == aten::__is__ || node->kind() == aten::__isnot__) { + // if we are comparing a None value with a value that can't be None + // replace the output with true if node is __isnot__ or false if node is + // __is__ + AT_ASSERT(node->inputs().size() == 2); + for (size_t check_none_index : {0, 1}) { + bool input_must_be_none = + node->inputs().at(check_none_index)->mustBeNone(); + bool other_must_not_be_none = + node->inputs().at(1 - check_none_index)->mustNotBeNone(); + if (input_must_be_none && other_must_not_be_none) { + WithInsertPoint guard(node); + auto output = node->owningGraph()->insertConstant( + node->kind() == aten::__isnot__); + GRAPH_UPDATE( + "Folding ", getHeader(node), " to ", output->debugName()); + node->output()->replaceAllUsesWith(output); + changed = true; + } + } + } else if ( + node->kind() == prim::unchecked_unwrap_optional || + node->kind() == aten::_unwrap_optional) { + // we are unwrapping an input that can't be None, remove the unwrap + auto input = node->input(); + if (input->mustNotBeNone()) { + GRAPH_UPDATE( + "Unwrapping ", + getHeader(node), + " as ", + node->input(), + " can't be optional"); + node->output()->replaceAllUsesWith(node->input()); + changed = true; + } + } else if (node->kind() == prim::unchecked_cast) { + // unchecked_cast is not generated for tensor properties, so we are not + // losing anything by calling unshapedType here + auto input_type = unshapedType(node->input()->type()); + auto output_type = unshapedType(node->output()->type()); + if (input_type->isSubtypeOf(output_type)) { + GRAPH_UPDATE( + "Removing ", + getHeader(node), + " as input type subtypes output type"); + node->output()->replaceAllUsesWith(node->input()); + changed = true; + } + } else if (node->kind() == aten::ne || node->kind() == aten::eq) { + if (node->inputs().size() != 2 || + node->inputs().at(0) != node->inputs().at(1)) { + continue; + } + auto inp_kind = node->inputs().at(0)->type()->kind(); + // only handling common immutable types here because other types like + // Tensor or list of Tensor might throw on aten::eq + switch (inp_kind) { + case TypeKind::BoolType: + case TypeKind::IntType: + case TypeKind::FloatType: { + WithInsertPoint guard(node); + node->output()->replaceAllUsesWith( + graph_->insertConstant(node->kind() == aten::eq)); + changed = true; + } + default: + break; + } + } + } + return changed; + } + + private: + std::shared_ptr graph_; +}; + +bool PeepholeOptimizeNonTensor(const std::shared_ptr& graph) { + PeepholeOptimizeNonTensorImpl peephole(graph); + bool changed = peephole.run(); + GRAPH_DUMP("After PeepholeOptimize: ", graph); + return changed; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/peephole_non_tensor.h b/torch/csrc/jit/passes/peephole_non_tensor.h new file mode 100644 index 00000000000000..1e4daebd060cc9 --- /dev/null +++ b/torch/csrc/jit/passes/peephole_non_tensor.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// return true if graph is modified +// Optimizing General Graph Patterns that +// are not covered in peephole.cpp and peephole_list_idioms +TORCH_API bool PeepholeOptimizeNonTensor(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/subgraph_rewrite.cpp b/torch/csrc/jit/passes/subgraph_rewrite.cpp index 45d07fe75fed01..85c4d7cb1aaaac 100644 --- a/torch/csrc/jit/passes/subgraph_rewrite.cpp +++ b/torch/csrc/jit/passes/subgraph_rewrite.cpp @@ -6,6 +6,31 @@ namespace torch { namespace jit { +namespace { +void update_source_range_and_cs_ptr( + const std::set& input_nodes, + const Match& m, + std::unordered_map& pattern_node_map) { + // pattern_node_map, maps nodes of the replacement graph + // to the nodes of the pattern graph. + // Now we iterate over each node of the replacement graph + // and find the corresponding pattern node in the match. + // The matched's node's source range and callstack is then + // used to update replacement node's source range and callstack + for (auto& it : pattern_node_map) { + Node* replacement_node = it.first; + Node* pattern_node = it.second; + if (!input_nodes.count(pattern_node)) { + Node* orig_node = m.nodes_map.at(pattern_node); + replacement_node->setSourceRange(orig_node->sourceRange()); + if (orig_node->callstack()) { + replacement_node->setCallStack(orig_node->callstack().value()); + } + } + } +} +} // namespace + void SubgraphRewriter::RegisterDefaultPatterns() { // TODO: Add actual patterns (like Conv-Relu). RegisterRewritePattern( @@ -17,13 +42,17 @@ graph(%x, %w, %b): R"IR( graph(%x, %w, %b): %r = aten::convrelu(%x, %w, %b) - return (%r))IR"); + return (%r))IR", + {{"r", "c"}}); } void SubgraphRewriter::RegisterRewritePattern( const std::string& pattern, - const std::string& replacement) { - RewritePatternDescr d = {pattern, replacement}; + const std::string& replacement, + const std::vector>& value_name_pairs) { + std::unordered_map value_name_map( + value_name_pairs.begin(), value_name_pairs.end()); + RewritePatternDescr d = {pattern, replacement, value_name_map}; patterns_.push_back(d); } @@ -56,7 +85,28 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph( parseIR(pattern.pattern, &pattern_graph, vmap); Graph replacement_graph; - parseIR(pattern.replacement, &replacement_graph); + std::unordered_map vmap_replacement; + parseIR(pattern.replacement, &replacement_graph, vmap_replacement); + + // First construct map of Node*-to-Node* + // This maps Nodes in replacement graph to nodes in pattern graph + // given the value_name_map, which maps value names from repalcement + // pattern to value name in pattern + std::unordered_map pattern_node_map; + std::set pattern_input_nodes; + for (auto& it : vmap_replacement) { + const auto& replacement_value_name = it.first; + Node* replacement_value_node = it.second->node(); + if (pattern.value_name_map.count(replacement_value_name)) { + const auto& pattern_value_name = + pattern.value_name_map.at(replacement_value_name); + TORCH_CHECK( + vmap.count(pattern_value_name), + "Value must be found in the replacement graph."); + Node* pattern_value_node = vmap.at(pattern_value_name)->node(); + pattern_node_map.emplace(replacement_value_node, pattern_value_node); + } + } const auto& matches = findPatternMatches(pattern_graph, *graph); for (const Match& match : matches) { @@ -105,6 +155,11 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph( continue; } + // Before rewriting the graph, update source range and callstack + // info of the replacement pattern graph so that the rewritten graph + // has the updated info + update_source_range_and_cs_ptr( + pattern_input_nodes, match, pattern_node_map); // Insert a clone of replacement subgraph. // `inputs` vector holds values that we would use as incoming values to the // new subgraph, and we will get `new_outputs` vector containing values diff --git a/torch/csrc/jit/passes/subgraph_rewrite.h b/torch/csrc/jit/passes/subgraph_rewrite.h index ee84f58ad087f9..d932c0c1f74fa7 100644 --- a/torch/csrc/jit/passes/subgraph_rewrite.h +++ b/torch/csrc/jit/passes/subgraph_rewrite.h @@ -78,12 +78,17 @@ class TORCH_API SubgraphRewriter { * The method takes two parameters specifying the pattern: * \p PATTERN - IR string representing the pattern subgraph. * \p REPLACEMENT - IR string representing the replacement subgraph. + * \p value name map - vector of pairs mapping values in the replacement graph + * to the values in the pattern graph. Used for preserving source range info + * across graph rewrite. * * See examples of pattern registering in `RegisterDefaultPatterns`. */ void RegisterRewritePattern( const std::string& pattern, - const std::string& replacement); + const std::string& replacement, + const std::vector>& value_name_pair = + {}); private: std::vector patterns_; @@ -105,6 +110,7 @@ class TORCH_API SubgraphRewriter { struct RewritePatternDescr { std::string pattern; std::string replacement; + std::unordered_map value_name_map; }; } // namespace jit diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp new file mode 100644 index 00000000000000..89794a90025ea2 --- /dev/null +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp @@ -0,0 +1,296 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* +XXX: this is still in prototype phase and has much work left to do, including +but not limited to: +- Bind shape functions for operators in C+ +- Make classes of operators share the same shape function (e.g. pointwise, +broadcast two inputs) +- Refactor APIs +- Only iteratively optimize shape function while a change has been made +- Add decent coverage of common ops +- Add shape analysis pass on Graph that handles Ifs and Loops +- Allow concurrent reads to the operator map +- Successive applications of same inputs to same shape function (e.g. series of +pointwise ops) +- Better support for Symbolic Shapes (additional optimizations, etc) +- Supporting returning partially evaluated shape compute graph +*/ + +namespace torch { +namespace jit { + +// TODO: better registration mechanism +std::mutex lock; +std::unordered_map> operator_functions; + +c10::optional normIndex(int64_t index, size_t len) { + if (index < 0) { + index = index + len; + } + if (index >= 0 && index < static_cast(len)) { + return index; + } else { + return c10::nullopt; + } +} + +void replaceWithIValue(Value* v, IValue val) { + WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin()); + v->replaceAllUsesWith(v->owningGraph()->insertConstant(val)); +} + +// Symbolic Shape Analysis works through iteratively partially evaluating +// a TorchScript shape compute graph by inputting properties from input +// Tensors. We can substitute in properties like `len(x)` and `x[1]` +// if they are statically on the input Tensors. We can also use +// assertions like `assert len(x) == 4` in order to refine the input +// length and unroll loops over its elements. We iteratively optimize and +// substitute in properties until we are unable to make any further +// optimizations. Finally, we try to extract Tensor properties from the output. +// For instance `return [1, 2, inp[2] + 1, inp[3]]` we know that the ouptut +// will be length 4 with first two dimensions equal to 1 and 2. +// It is not implemented yet but in the future we will also be able to +// infer that the 4th dimension will have the same symbolic shape as inp[3] + +struct SymbolicShapeAnalyzer { + SymbolicShapeAnalyzer(Node* n, std::shared_ptr shape_compute_graph) + : graph_(shape_compute_graph->copy()), node_(n) { + for (size_t i = 0; i < node_->inputs().size(); i++) { + auto type = node_->input(i)->type(); + if (auto tt = type->castRaw()) { + c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes(); + if (symbolic_shapes.isComplete()) { + replaceWithIValue( + graph_->inputs().at(i), *tt->sizes().concrete_sizes()); + continue; + } + // we can't optimize a tensor without fixed rank + if (symbolic_shapes.rank()) { + node_input_tensor_indices.push_back(i); + } + } else if ( + type->cast() && + type->cast()->getElementType()->cast()) { + TORCH_INTERNAL_ASSERT(false); // not handled yet + } else { + if (auto ival = toIValue(node_->input(i))) { + replaceWithIValue(graph_->inputs().at(i), *ival); + } + } + } + } + + c10::SymbolicShape run() { + // TODO: only run while the last iteration has made a change + size_t num_optimization_iters = 6; + for (size_t i = 0; i < num_optimization_iters; i++) { + // XXX: we cannot substitute symbolic dims before passes like constant + // propagation, or we might inadvertently use them in arithmetic or + // other operators + substituteInputTensorProperties(/*substitute_symbolic_dims*/ false); + LowerSimpleTuples(graph_); + RemoveListMutation(graph_); + UnrollConstantLoops(graph_); + ConstantPropagation(graph_); + PeepholeOptimizeNonTensor(graph_); + PeepholeOptimizeListIdioms(graph_, /*refine_list_len*/ true); + RefineIntegerValues(graph_); + ConstantPropagation(graph_); + EliminateCommonSubexpression(graph_); + } + substituteInputTensorProperties(/*substitute_symbolic_dims*/ true); + // XXX: do not run any passes after we have substituted in symbolic + // dimension value, we do it so they can be easily extracted into the output + // shape + return extractOutputShape(); + } + + private: + void substituteInputTensorProperties(bool substitute_symbolic_dims) { + // here we iteratively substitute properties of the node's input tensors + // into the shape compute graph. in addition to direct constants we can + // substitute, like len(inp) or inp[0] if the tensor has fixed length + // or first dimension, we also try to resolve symbolic shapes of the same + // symbolic value to the same Value * in the shape compute graph. + // for the shape logic: + // dim1 = inp1[0]; + // dim2 = inp2[0]; + // return dim1 if dim2 == 1 else dim2; + // if we see that inp1[0] and inp2[0] both have the same symbolic shape + // value, then it is a valid transformation to replace dim2 with dim1 or + // vice versa. to do this we collect all Value * for a particular symbolic + // dimension value and then Value * with their dominator of the same + // symbolic dimension value in the example above, this allows us to infer + // that the output will be the symbolic dimension value of dim1 + // if `substitute_symbolic_dims` is true, then we insert list accesses + // which resolve to symbolic dimension values as constants in the graph + // because symbolic dimensions are represented as negative numbers and + // are not real values, this is only safe to do if you are not running + // any further optimizations. representing them as constants in the graph + // makes extracting output shapes with symbolic dimensions possible. + + std::unordered_map> symbolic_shape_map; + + for (auto tensor_index : node_input_tensor_indices) { + auto tensor_value = node_->inputs().at(tensor_index); + auto tensor_shape = + tensor_value->type()->expect()->symbolic_sizes(); + TORCH_INTERNAL_ASSERT(tensor_shape.rank().has_value()); + + for (const auto& use : graph_->inputs().at(tensor_index)->uses()) { + // TODO: either decompose composite ops like slice or add handling here + switch (use.user->kind()) { + case aten::len: { + size_t len = tensor_shape.rank().value(); + replaceWithIValue(use.user->output(), static_cast(len)); + } break; + case aten::__getitem__: { + auto index = constant_as(use.user->inputs().at(1)); + if (!index) { + continue; + } + auto norm_index = normIndex(*index, *tensor_shape.rank()); + if (!norm_index) { + continue; + } + if (tensor_shape[*norm_index].is_static() || + substitute_symbolic_dims) { + replaceWithIValue( + use.user->output(), tensor_shape[*norm_index].value()); + } else { + int64_t symbolic_index = tensor_shape[*norm_index].value(); + symbolic_shape_map[symbolic_index].push_back(use.user->output()); + } + } + } + } + + for (const auto& symbolic_set : symbolic_shape_map) { + mergeSymbolicShapeSets(symbolic_set.second); + } + } + } + + void mergeSymbolicShapeSets(const std::vector& symbolic_set) { + // `symbolic_set` represents a set of Value * which are all equal + // to each other. Here, we optimize the graph by replacing values + // in the set with other dominating values. + // in the following example, where a, b and c are all in the same + // symbolic set: + // if cond: + // a = li[0] + // b = li[1] + // return [a, b] + // else: + // c = li[0] + // return [c, c] + // we can replace `b` with `a` because it is dominated by `a`, + // but we cannot replace `c` with another dominating value + + // there are ways to compute this more efficiently but typically number of + // Values for each symbolic set is low and this is cheap to run + for (size_t i = 0; i < symbolic_set.size(); ++i) { + Value* v = symbolic_set[i]; + Value* dominating_value = v; + for (size_t j = 0; j < symbolic_set.size(); ++j) { + if (dominating_value->node()->isDominatedBy(symbolic_set[j]->node())) { + dominating_value = symbolic_set[j]; + } + } + if (dominating_value != v) { + v->replaceAllUsesWith(dominating_value); + } + } + } + + c10::SymbolicShape extractOutputShape() { + TORCH_INTERNAL_ASSERT(graph_->outputs().size() == 1); + auto output = graph_->outputs().at(0); + TORCH_INTERNAL_ASSERT( + output->type()->cast() && + output->type()->cast()->getElementType()->cast()); + if (output->node()->kind() == prim::Constant) { + auto int_list = toIValue(output)->toIntVector(); + return c10::SymbolicShape(int_list); + } + // If it is not a single list construct or constant, bail, + // otherwise we cannot analyze its output and it might be modified + if (output->node()->kind() != prim::ListConstruct || + output->uses().size() != 1) { + return c10::SymbolicShape(); + } + Node* list_construct = output->node(); + std::vector> output_shape; + for (Value* input : list_construct->inputs()) { + output_shape.push_back(constant_as(input)); + } + return c10::SymbolicShape(output_shape); + } + + // node input indices that are TensorType and we need to iteratively + // substitute properties of. We only substitute properties + // of TensorTypes with a fixed dimension but not a complete shape, + // because a complete shape we can completely replace with a constant + // and non-fixed dimensions we cannot reason about at all + // TODO: might be cleaner to store as a pair of index -> symbolic shape + // but there were weird lifetime issues + std::vector node_input_tensor_indices; + std::shared_ptr graph_; + Node* node_; +}; + +void PropagateShapesWithShapeFunction( + Node* n, + std::shared_ptr& shape_compute_graph) { + c10::SymbolicShape out = SymbolicShapeAnalyzer(n, shape_compute_graph).run(); + n->output()->setType( + n->output()->type()->expect()->withSymbolicShapes(out)); +} + +void RegisterOperatorShapeFunction(Node* n, std::shared_ptr& graph) { + std::lock_guard guard(lock); + if (!n->maybeSchema()) { + return; + } + if (operator_functions.count(toString(n->schema()))) { + return; + } + operator_functions[toString(n->schema())] = graph; +} + +void PropagateShapesOnGraph(std::shared_ptr& graph) { + std::lock_guard guard(lock); + for (Node* n : graph->nodes()) { + if (n->maybeSchema()) { + if (operator_functions.count(toString(n->schema()))) { + PropagateShapesWithShapeFunction( + n, operator_functions[toString(n->schema())]); + } + } + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.h b/torch/csrc/jit/passes/symbolic_shape_analysis.h new file mode 100644 index 00000000000000..9a61af28fbc606 --- /dev/null +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { + +// CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE + +TORCH_API void RegisterOperatorShapeFunction( + Node* n, + std::shared_ptr& graph); + +TORCH_API void PropagateShapesOnGraph(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 6c4a8bba882a79..e48968e8fa05b1 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -139,6 +139,7 @@ static const OperatorSet& supported_eltwise_set() { "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", + "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "aten::gelu(Tensor self) -> Tensor", "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "aten::neg(Tensor self) -> Tensor", @@ -850,12 +851,18 @@ class TensorExprFuser { } return true; } + bool allShapesAreKnown(Node* node) { // TODO: Relax the checks to support dynamic shapes for (Value* input : node->inputs()) { if (!shapeIsKnown(input)) { return false; } + if (input->node()->kind() == prim::ListConstruct) { + if (!allShapesAreKnown(input->node())) { + return false; + } + } } for (Value* output : node->outputs()) { if (!shapeIsKnown(output)) { diff --git a/torch/csrc/jit/passes/value_refinement_utils.cpp b/torch/csrc/jit/passes/value_refinement_utils.cpp new file mode 100644 index 00000000000000..2c450e56506177 --- /dev/null +++ b/torch/csrc/jit/passes/value_refinement_utils.cpp @@ -0,0 +1,241 @@ +#include + +namespace torch { +namespace jit { + +// [value refinement algorithm] + +// When a comparison like `cond = len(x) == 4` or `cond = len(x) != 4` is made, +// `cond` value carries information (refinements) about the len of `x`. +// When `cond` is used as the conditional of an if statement, the information +// it carries for its true value can be inserted into the true block +// and the same for its false value. +// For something like `y = len(x) if len(x) == 1 else 1`, in the true branch +// we can replace len(x) with 1 because the true refinements from `len(x) == 1` +// will be present in the true block. +// Additionally, we can optimize something like: +// if len(x) != 4: +// raise Exception(...) +// return len(x) +// Because the true block always throws, whatever refinements exist in the false +// block become present in the owning block of the if node. We can also merge +// refinements carried by two different booleans across an if node join by +// taking the intersections of their refinements. +// if cond: +// z = len(x) == 4 and len(y) == 5 +// else: +// z = len(x) == 4 +// Here, z's true value will refine the len(x) to 4, but not len(y). +// If the code was written as: +// if cond: +// z = len(x) == 4 and len(y) == 5 +// else: +// z = False +// +// Then z's true value would refine x and y, because if z is true it had to have +// come from the true block. Code that is written with `and` or `or` will +// desugar to something similar. Additionally, any True refinements that were +// present on `cond` can also be associated with the if node True output value. + +// The intersection of the refinements is the Value* which are in both +// refinements and are refined to the same length +// in an example like: +// if cond: +// x = len(a) == 4 and len(b) == 5 +// else: +// x = len(a) == 4 +// For the x output of the node we take the intersection between +// the refinements stored on each block output, which will result +// in only the refinement of len(a) == 4 +ListRefinement intersectRefinements( + const ListRefinement& ref1, + const ListRefinement& ref2) { + ListRefinement out; + for (const auto& pair : ref1) { + auto val2 = ref2.find(pair.first); + if (val2 != ref2.end() && val2->second == pair.second) { + out[pair.first] = pair.second; + } + } + return out; +} + +// To union, just take all refinements from both inputs. We do not need to worry +// about len refinements disagreeing because a path like `if len(x) == 4 and +// len(x) == 5` will never be taken +// in an example like: +// if len(a) == 5: +// x = len(b) == 4 +// else: +// x = False +// For the output x Value, if is true then the refinements present in the true +// block must also be true, so we take the union of `len(a) == 5` and len(b) == +// 4` and assign them to true refinements of the output x value. This is a very +// common pattern in desugaring of `and` or `or` boolean expressions +ListRefinement unionRefinements( + const ListRefinement& ref1, + const ListRefinement& ref2) { + ListRefinement out = ref1; + out.insert(ref2.begin(), ref2.end()); + return out; +} + +void joinIfRefinements( + Node* if_node, + std::unordered_set& throwing_blocks, + ListRefinement& curr_block_refinements, + ListRefinement& true_block_refinements, + ListRefinement& false_block_refinements, + std::unordered_map& + boolean_value_refinements) { + IfView if_n(if_node); + Block* b = if_node->owningBlock(); + + bool true_block_throws = throwing_blocks.count(if_n.thenBlock()); + bool false_block_throws = throwing_blocks.count(if_n.elseBlock()); + + // if one block throws, the refinements for the other block + // become present in the current block, and all bool outputs + // of the if node take their refinements from non throwing block + // output + + if (true_block_throws || false_block_throws) { + if (true_block_throws && false_block_throws) { + throwing_blocks.insert(b); + return; + } + if (true_block_throws) { + curr_block_refinements.insert( + false_block_refinements.begin(), false_block_refinements.end()); + } else { + curr_block_refinements.insert( + true_block_refinements.begin(), true_block_refinements.end()); + } + Block* non_throwing_block = + true_block_throws ? if_node->blocks().at(1) : if_node->blocks().at(0); + for (size_t i = 0; i < if_n.outputs().size(); ++i) { + if (boolean_value_refinements.count( + non_throwing_block->outputs().at(i))) { + boolean_value_refinements[if_node->outputs().at(i)] = + boolean_value_refinements[non_throwing_block->outputs().at(i)]; + } + } + return; + } + + for (size_t i = 0; i < if_n.outputs().size(); ++i) { + if (!(if_n.outputs().at(i)->type() == BoolType::get())) { + return; + } + Value* true_v = if_n.thenOutputs().at(i); + Value* false_v = if_n.elseOutputs().at(i); + + if (!boolean_value_refinements.count(true_v) && + !boolean_value_refinements.count(false_v) && + !constant_as(true_v) && !constant_as(false_v)) { + return; + } + + // if either block has a constant bool output, e.g. `true` on the + // true block, then for the `false` value we can take the false + // refinements present on the false block and from the other block + // output value bc if the output is false it had to have come from the + // false block. if len(a) == 5: + // x = len(b) == 4 + // else: + // x = False + // if x is true, then we know both len(a) == 5 and len(b) == 4 + // + // if neither block has a constant bool value, we just take the + // intersection of the refinements from boolean outputs. + // if cond: + // x = len(a) == 4 and len(b) == 5 + // else: + // x = len(a) == 4 + // here, we know if x is true, then len(a) == 4, but not len(b) + // == 5, because that refinement is not present in the true block. + // TODO: could also take intersection of refinements present in + // both blocks, but it's not a real use case. + + // boolean_value_refinements[value] is safe to access because + // BooleanRefinementMapping has a default constructor + + BooleanRefinementMapping out; + if (auto maybe_bool = constant_as(true_v)) { + if (*maybe_bool) { + out = BooleanRefinementMapping::FalseRefinements(unionRefinements( + boolean_value_refinements[false_v].false_refine(), + false_block_refinements)); + } else { + out = BooleanRefinementMapping::TrueRefinements(unionRefinements( + boolean_value_refinements[false_v].true_refine(), + false_block_refinements)); + } + } else if (auto maybe_bool = constant_as(false_v)) { + if (*maybe_bool) { + out = BooleanRefinementMapping::FalseRefinements(unionRefinements( + boolean_value_refinements[true_v].false_refine(), + true_block_refinements)); + } else { + out = BooleanRefinementMapping::TrueRefinements(unionRefinements( + boolean_value_refinements[true_v].true_refine(), + true_block_refinements)); + } + } else if ( + boolean_value_refinements.count(true_v) && + boolean_value_refinements.count(false_v)) { + out = boolean_value_refinements[true_v].intersectBooleanRefinementMapping( + boolean_value_refinements[false_v]); + } + boolean_value_refinements[if_n.outputs().at(i)] = out; + } +} + +bool handleCommonRefinentOperators( + Node* n, + std::unordered_set& throwing_blocks, + std::unordered_map& info) { + if (n->kind() == prim::RaiseException) { + throwing_blocks.insert(n->owningBlock()); + return true; + } + if (n->kind() == aten::__not__ && + n->inputs().at(0)->type()->cast()) { + // __not__(inp) -> reverse refinements + if (info.count(n->input())) { + auto& input_ref = info[n->input()]; + info[n->output()] = BooleanRefinementMapping( + input_ref.false_refine(), input_ref.true_refine()); + } + return true; + } + if (n->matches("aten::eq(bool a, bool b) -> bool") || + (n->matches("aten::ne(bool a, bool b) -> bool"))) { + for (size_t const_index : {0, 1}) { + if (n->input(const_index)->node()->kind() != prim::Constant) { + continue; + } + auto const_input = constant_as(n->input(const_index)).value(); + auto non_const_input = n->input(1 - const_index); + if (!info.count(non_const_input)) { + continue; + } + // value == False / value != True -> equivalent to __not__ value + // value == True / value != False -> equivalent to value + auto& input_ref = info[non_const_input]; + if ((!const_input && n->kind() == aten::eq) || + (const_input && n->kind() == aten::ne)) { + info[n->output()] = BooleanRefinementMapping( + input_ref.false_refine(), input_ref.true_refine()); + } else { + info[n->output()] = BooleanRefinementMapping( + input_ref.true_refine(), input_ref.false_refine()); + } + } + return true; + } + return false; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/value_refinement_utils.h b/torch/csrc/jit/passes/value_refinement_utils.h new file mode 100644 index 00000000000000..3d6f705d3d6103 --- /dev/null +++ b/torch/csrc/jit/passes/value_refinement_utils.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +// Refine from Value of type List -> len of list +// If a refinement mapping of List Value * -> len is present in a block +// the list is guaranteed to be that length +// TODO: vector may be faster +using ListRefinement = std::unordered_map; + +TORCH_API ListRefinement +intersectRefinements(const ListRefinement& ref1, const ListRefinement& ref2); + +TORCH_API ListRefinement +unionRefinements(const ListRefinement& ref1, const ListRefinement& ref2); + +// Represents the refinement information that can be carried on a boolean +struct BooleanRefinementMapping { + BooleanRefinementMapping( + ListRefinement true_refine, + ListRefinement false_refine) + : true_refine_(std::move(true_refine)), + false_refine_(std::move(false_refine)){}; + BooleanRefinementMapping() = default; // empty + + static BooleanRefinementMapping FalseRefinements( + ListRefinement false_refine) { + return BooleanRefinementMapping({}, std::move(false_refine)); + } + + static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) { + return BooleanRefinementMapping(std::move(true_refine), {}); + } + + BooleanRefinementMapping intersectBooleanRefinementMapping( + BooleanRefinementMapping& other) { + return BooleanRefinementMapping( + intersectRefinements(true_refine_, other.true_refine()), + intersectRefinements(false_refine_, other.false_refine())); + } + + ListRefinement& true_refine() { + return true_refine_; + } + + ListRefinement& false_refine() { + return false_refine_; + } + + private: + ListRefinement true_refine_; + ListRefinement false_refine_; +}; + +TORCH_API void joinIfRefinements( + Node* if_node, + std::unordered_set& throwing_blocks, + ListRefinement& curr_block_refinements, + ListRefinement& true_block_refinements, + ListRefinement& false_block_refinements, + std::unordered_map& info); + +// handles adding blocks to throwing blocks and propagating refinements via +// boolean comparisons +TORCH_API bool handleCommonRefinentOperators( + Node* n, + std::unordered_set& throwing_blocks, + std::unordered_map& info); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 37b73970cd80af..bb0c03bd512b35 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -68,6 +70,7 @@ #include #include #include +#include #include #include #include @@ -161,7 +164,15 @@ void initJITBindings(PyObject* module) { ONNXAssignOutputShape(graph, tensors, desc, onnx_shape_inference); }) .def("_jit_pass_lower_all_tuples", LowerAllTuples) + .def( + "_new_symbolic_shape_symbol", + []() { return c10::ShapeSymbol::newSymbol().value(); }) + .def( + "_jit_register_operator_shape_function", + RegisterOperatorShapeFunction) + .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph) .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution) + .def("_jit_pass_integer_value_refinement", RefineIntegerValues) .def( "_jit_pass_onnx_fold_if", [](std::shared_ptr& graph) { @@ -370,11 +381,19 @@ void initJITBindings(PyObject* module) { "_jit_pass_custom_pattern_based_rewrite_graph", [](const std::string& pattern, const std::string& fused_node_name, - std::shared_ptr g) { + std::shared_ptr g, + const std::vector>& + value_name_pairs) { SubgraphRewriter subgraph_rewriter; - subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name); + subgraph_rewriter.RegisterRewritePattern( + pattern, fused_node_name, value_name_pairs); subgraph_rewriter.runOnGraph(g); - }) + }, + py::arg("pattern"), + py::arg("fused_node_name"), + py::arg("g"), + py::arg("value_name_pairs") = + std::vector>()) .def( "_jit_pass_remove_inplace_ops", [](const std::shared_ptr& g) { return RemoveInplaceOps(g); }) @@ -398,6 +417,16 @@ void initJITBindings(PyObject* module) { }, py::arg("graph"), py::arg("addmm_fusion_enabled") = false) + .def( + "_jit_pass_peephole_list_idioms", + [](const std::shared_ptr& g, bool refine_list_len) { + return PeepholeOptimizeListIdioms(g, refine_list_len); + }, + py::arg("graph"), + py::arg("refine_list_len") = false) + .def( + "_jit_pass_refine_integer_values", + [](std::shared_ptr& g) { return RefineIntegerValues(g); }) .def( "_jit_pass_fuse_addmm", [](std::shared_ptr& g) { return FuseAddMM(g); }) diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 02af155fece738..f01400d32a22b3 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -740,6 +740,31 @@ void initPythonIRBindings(PyObject* module_) { } return py::none(); }) + .def( + "symbolic_sizes", + [](Type& t) -> py::object { + if (auto ptt = t.expect()) { + auto ss = ptt->symbolic_sizes(); + if (!ss.rank().has_value()) { + return py::none(); + } + + std::vector ss_vals; + for (size_t i = 0; i < *ss.rank(); ++i) { + ss_vals.push_back(ss.at(i).value()); + } + return py::cast(ss_vals); + } + return py::none(); + }) + .def( + "with_sizes", + [](Type& t, std::vector> sizes) -> py::object { + if (auto ptt = t.expect()) { + return py::cast(ptt->withSymbolicShapes(sizes)); + } + return py::none(); + }) .def( "varyingSizes", [](Type& t) -> py::object { diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 75e159b4b1a1d1..fbb5e232aeb4fd 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1102,6 +1102,9 @@ std::shared_ptr toSugaredValue( } } + bool isRpcAvailable = py::cast( + py::module::import("torch.distributed.rpc").attr("is_available")()); + if (auto callee = as_function(obj)) { return std::make_shared(callee->function_); } else if (py::isinstance(obj)) { @@ -1125,17 +1128,20 @@ std::shared_ptr toSugaredValue( #ifdef USE_RPC // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. } else if ( + isRpcAvailable && obj.ptr() == - py::module::import("torch.distributed.rpc").attr("rpc_async").ptr()) { + py::module::import("torch.distributed.rpc").attr("rpc_async").ptr()) { return SpecialFormValue::create(prim::rpc_async); } else if ( + isRpcAvailable && obj.ptr() == - py::module::import("torch.distributed.rpc").attr("rpc_sync").ptr()) { + py::module::import("torch.distributed.rpc").attr("rpc_sync").ptr()) { return SpecialFormValue::create(prim::rpc_sync); } else if ( + isRpcAvailable && // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. obj.ptr() == - py::module::import("torch.distributed.rpc").attr("remote").ptr()) { + py::module::import("torch.distributed.rpc").attr("remote").ptr()) { return SpecialFormValue::create(prim::rpc_remote); #endif } else if (auto callee = as_module(obj)) { diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 2f2f9f8785b216..02ec8aa7b43412 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -128,7 +129,6 @@ struct PythonResolver : public Resolver { if (classType_ && name == classname_) { return classType_; } - pybind11::gil_scoped_acquire ag; py::object obj = rcb_(name); if (obj.is(py::none())) { @@ -1770,6 +1770,15 @@ void initJitScriptBindings(PyObject* module) { std::istringstream in(buffer); return _get_model_bytecode_version(in); }); + py::class_(m, "OperatorInfo") + .def_readonly("num_schema_args", &OperatorInfo::num_schema_args); + m.def("_get_model_ops_and_info", [](const std::string& filename) { + return _get_model_ops_and_info(filename); + }); + m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) { + std::istringstream in(buffer); + return _get_model_ops_and_info(in); + }); m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) { return debugMakeSet(torch::jit::mobile::_export_operator_list(sm)); }); diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index cfdb641fc46bf3..77a2042dc5067f 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #ifdef USE_RPC @@ -229,6 +230,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // std::cout << "RUNNING "; // frames.back().function->dump(std::cout, frame.pc); Instruction inst = frame.function->instructions_[frame.pc]; + profiling::InstructionSpan instSpan{ + *frame.function->instructions_source()[frame.pc]}; switch (inst.op) { case ENTER: { const auto& obj = peek(stack, 0, 1); diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp index 1df4d3bcefc5d7..e57a612c8b02ac 100644 --- a/torch/csrc/jit/runtime/register_c10_ops.cpp +++ b/torch/csrc/jit/runtime/register_c10_ops.cpp @@ -63,7 +63,7 @@ Operator createOperatorFromC10_withTracingHandledHere( tracer::addInputs(node, args[i].name().c_str(), iter->toBool()); } else if (type->kind() == TypeKind::StringType) { AT_ASSERT(iter->isString()); - tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef()); + tracer::addInputs(node, args[i].name().c_str(), iter->toStringView()); } else if (type->kind() == TypeKind::NumberType) { tracer::addInputs(node, args[i].name().c_str(), iter->toScalar()); } else if (type->kind() == TypeKind::ListType) { diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index 800686dc74c0d9..4b6b182777e8d4 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -1,5 +1,6 @@ #include #include +#include #include @@ -443,9 +444,19 @@ void listMulIntRight(Stack* stack) { } void listSlice(Stack* stack) { - int64_t step = pop(stack).to(); - int64_t end = pop(stack).to(); - int64_t start = pop(stack).to(); + auto step_val = pop(stack); + auto end_val = pop(stack); + auto start_val = pop(stack); + + // In the future, start and end of list slice will be marked + // as None. To prepare for this change, we make listSlice method + // to be able to handle future models that are scripted with newer + // runtime. + int64_t step = step_val.isInt() ? step_val.to() : 1; + int64_t end = end_val.isInt() ? end_val.to() + : std::numeric_limits::max(); + int64_t start = start_val.isInt() ? start_val.to() : 0; + c10::List list = pop(stack).to>(); const int64_t list_size = list.size(); diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 6b3614f95143e8..4bafbe5bff972e 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -2049,7 +2049,7 @@ TORCH_LIBRARY_IMPL(aten, CatchAll, m) { RegisterOperators reg1( {OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::rangelist(int n) -> int[]"), + TORCH_SELECTIVE_SCHEMA("prim::rangelist(int n) -> int[]"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t n; @@ -2065,7 +2065,7 @@ RegisterOperators reg1( // note: this op needs to share a name with the Scalar -> Tensor conversion // because all _to_tensor conversion have to have the same operator namet OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::NumToTensor.bool(bool a) -> Tensor"), + TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.bool(bool a) -> Tensor"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool b; @@ -2074,20 +2074,20 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::device(str a) -> Device"), + TORCH_SELECTIVE_SCHEMA("aten::device(str a) -> Device"), [](Stack* stack) { push(stack, c10::Device(pop(stack).toStringRef())); }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::percentFormat(str self, ...) -> str"), + TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"), [](Stack* stack) { size_t num_inputs = pop(stack).toInt(); percentFormat(*stack, num_inputs); }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack* stack) { at::Tensor self; @@ -2104,7 +2104,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::requires_grad(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::requires_grad(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2112,7 +2112,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::grad(Tensor a) -> Tensor(*)"), + TORCH_SELECTIVE_SCHEMA("prim::grad(Tensor a) -> Tensor(*)"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2120,7 +2120,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_sparse(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_sparse(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2128,7 +2128,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_sparse_csr(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_sparse_csr(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2136,7 +2136,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_mkldnn(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_mkldnn(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2144,7 +2144,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_mlc(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_mlc(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2152,7 +2152,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_vulkan(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_vulkan(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2160,7 +2160,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_quantized(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_quantized(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2168,7 +2168,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::is_meta(Tensor a) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::is_meta(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2176,7 +2176,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::name(Tensor a) -> str?"), + TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2188,7 +2188,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::index(Device self) -> int?"), + TORCH_SELECTIVE_SCHEMA("prim::index(Device self) -> int?"), [](Stack* stack) { auto d = pop(stack).toDevice(); if (d.has_index()) { @@ -2201,11 +2201,11 @@ RegisterOperators reg1( OperatorGenerator( // TODO return generator object when torchscript supports RNG // first-class - TORCH_SELECTIVE_NAME("aten::manual_seed(int seed) -> ()"), + TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"), [](Stack* stack) { at::manual_seed(pop(stack).toInt()); }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::cuda(Tensor(a) self) -> Tensor(a|b)"), + TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -2213,11 +2213,11 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::AutogradZero() -> Tensor"), + TORCH_SELECTIVE_SCHEMA("prim::AutogradZero() -> Tensor"), [](Stack* stack) { stack->emplace_back(at::Tensor()); }, aliasAnalysisSpecialCase()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "prim::ReductionSizes(int[] size, int[] red_axes, bool keepdim = False) -> int[]"), [](Stack* stack) { bool keepdim = pop(stack).toBool(); @@ -2246,7 +2246,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::BroadcastSizes(...) -> int[]"), + TORCH_SELECTIVE_SCHEMA("prim::BroadcastSizes(...) -> int[]"), [](Stack* stack) { auto num_inputs = pop(stack).toInt(); std::vector size; @@ -2260,7 +2260,7 @@ RegisterOperators reg1( }, aliasAnalysisSpecialCase()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::warn(str message, int stacklevel=2) -> ()"), [](Stack* stack) { TORCH_CHECK( @@ -2269,7 +2269,7 @@ RegisterOperators reg1( aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "onnx::Reshape(Tensor input, Tensor shape) -> Tensor"), [](Stack* stack) { at::Tensor input, shape; @@ -2281,7 +2281,7 @@ RegisterOperators reg1( }, aliasAnalysisSpecialCase()), OperatorGenerator( - TORCH_SELECTIVE_NAME("onnx::Shape(Tensor t) -> Tensor"), + TORCH_SELECTIVE_SCHEMA("onnx::Shape(Tensor t) -> Tensor"), [](Stack* stack) { auto t = pop(stack).toTensor(); at::IntArrayRef sizes = t.sizes(); @@ -2295,7 +2295,7 @@ RegisterOperators reg1( }, aliasAnalysisSpecialCase()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::AutogradAnyNonZero(...) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::AutogradAnyNonZero(...) -> bool"), [](Stack* stack) { auto num_inputs = pop(stack).toInt(); bool result = false; @@ -2323,7 +2323,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::AutogradAllZero(...) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::AutogradAllZero(...) -> bool"), [](Stack* stack) { auto num_inputs = pop(stack).toInt(); bool result = true; @@ -2339,7 +2339,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::AutogradAllNonZero(...) -> bool"), + TORCH_SELECTIVE_SCHEMA("prim::AutogradAllNonZero(...) -> bool"), [](Stack* stack) { auto num_inputs = pop(stack).toInt(); bool result = true; @@ -2355,7 +2355,7 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::AutogradAdd(Any a, Any b) -> Any"), + TORCH_SELECTIVE_SCHEMA("prim::AutogradAdd(Any a, Any b) -> Any"), [](Stack* stack) { at::Tensor a, b; pop(stack, a, b); @@ -2373,7 +2373,7 @@ RegisterOperators reg1( }, aliasAnalysisSpecialCase()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::_size_if_not_equal(int[] self_size, int[] other_size) -> int[]?"), [](Stack* stack) { IValue self_size, other_size; @@ -2388,7 +2388,8 @@ RegisterOperators reg1( }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::_unwrap_optional(t(a)? optional) -> t(a)"), + TORCH_SELECTIVE_SCHEMA( + "aten::_unwrap_optional(t(a)? optional) -> t(a)"), [](Stack* stack) { auto val = pop(stack); TORCH_CHECK(!val.isNone(), "Unwrapping null optional"); @@ -2405,28 +2406,28 @@ void hashValue(Stack* stack) { RegisterOperators reg2({ // registered as Any[] so that heterogenous tuples can be called with len() OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::len.any(Any[] a) -> int"), + TORCH_SELECTIVE_SCHEMA("aten::len.any(Any[] a) -> int"), listLen, aliasAnalysisFromSchema()), // these ops have a specialized implementation for the list element type #define CREATE_SPECIALIZED_LIST_OPS(decl_type, value_type) \ OperatorGenerator( \ - TORCH_SELECTIVE_NAME( \ + TORCH_SELECTIVE_SCHEMA( \ "aten::remove." decl_type "(" decl_type \ "[](a!) self, \ " decl_type " el) -> ()"), \ listRemove, \ aliasAnalysisFromSchema()), \ OperatorGenerator( \ - TORCH_SELECTIVE_NAME( \ + TORCH_SELECTIVE_SCHEMA( \ "aten::index.list_" decl_type "(" decl_type \ "[] self, \ " decl_type " el) -> int"), \ listIndex, \ aliasAnalysisFromSchema()), \ OperatorGenerator( \ - TORCH_SELECTIVE_NAME( \ + TORCH_SELECTIVE_SCHEMA( \ "aten::count." decl_type "(" decl_type \ "[] self, \ " decl_type " el) -> int"), \ @@ -2445,97 +2446,100 @@ RegisterOperators reg2({ // `listContains` is not implemented for non-primitive types // TODO: Add List[bool] once .to> doesn't throw an error OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::__contains__.float_list(float[] l, float item) -> bool"), listContains, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sort.int(int[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sort.float(float[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sort.Tensor(Tensor[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sort.bool(bool[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sort.str(str[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sorted.int(int[](a) input) -> (int[])"), + TORCH_SELECTIVE_SCHEMA("aten::sorted.int(int[](a) input) -> (int[])"), listCopyAndSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sorted.float(float[](a) input) -> (float[])"), listCopyAndSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::sorted.Tensor(Tensor[](a) input) -> (Tensor[])"), listCopyAndSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sorted.bool(bool[](a) input) -> (bool[])"), + TORCH_SELECTIVE_SCHEMA( + "aten::sorted.bool(bool[](a) input) -> (bool[])"), listCopyAndSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sorted.str(str[](a) input) -> (str[])"), + TORCH_SELECTIVE_SCHEMA("aten::sorted.str(str[](a) input) -> (str[])"), listCopyAndSort, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::eq.float_list(float[] a, float[] b) -> bool"), listEq, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::eq.Tensor_list(Tensor[] a, Tensor[] b) -> bool"), listEq, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::eq.bool_list(bool[] a, bool[] b) -> bool"), + TORCH_SELECTIVE_SCHEMA( + "aten::eq.bool_list(bool[] a, bool[] b) -> bool"), listEq, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::eq.str_list(str[] a, str[] b) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::eq.str_list(str[] a, str[] b) -> bool"), listEq, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::ne.float_list(float[] a, float[] b) -> bool"), listNe, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::ne.Tensor_list(Tensor[] a, Tensor[] b) -> bool"), listNe, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::ne.bool_list(bool[] a, bool[] b) -> bool"), + TORCH_SELECTIVE_SCHEMA( + "aten::ne.bool_list(bool[] a, bool[] b) -> bool"), listNe, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::ne.str_list(str[] a, str[] b) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::ne.str_list(str[] a, str[] b) -> bool"), listNe, aliasAnalysisFromSchema()), #define DEFINE_CONVERT_BASE_OP(op_name, prefix, char_op) \ OperatorGenerator( \ - TORCH_SELECTIVE_NAME(#op_name "(int i) -> str"), \ + TORCH_SELECTIVE_SCHEMA(#op_name "(int i) -> str"), \ [](Stack* stack) { \ auto i = pop(stack).toInt(); \ std::stringstream ss; \ @@ -2552,7 +2556,7 @@ RegisterOperators reg2({ DEFINE_CONVERT_BASE_OP(aten::oct, "o", std::oct), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::bin(int i) -> str"), + TORCH_SELECTIVE_SCHEMA("aten::bin(int i) -> str"), [](Stack* stack) { auto i = pop(stack).toInt(); std::stringstream ss; @@ -2572,7 +2576,8 @@ RegisterOperators reg2({ aliasAnalysisFromSchema()), // TODO: deprecate this in favor of aten::getelem OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::StringIndex(str string, int index) -> str"), + TORCH_SELECTIVE_SCHEMA( + "prim::StringIndex(str string, int index) -> str"), [](Stack* stack) { auto index = pop(stack).toInt(); auto string = pop(stack).toStringRef(); @@ -2582,7 +2587,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::chr(int i) -> str"), + TORCH_SELECTIVE_SCHEMA("aten::chr(int i) -> str"), [](Stack* stack) { auto i = pop(stack).toInt(); std::stringstream ss; @@ -2600,7 +2605,7 @@ RegisterOperators reg2({ DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::modf(float a) -> (float, float)"), + TORCH_SELECTIVE_SCHEMA("aten::modf(float a) -> (float, float)"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; @@ -2612,7 +2617,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::frexp(float a) -> (float, int)"), + TORCH_SELECTIVE_SCHEMA("aten::frexp(float a) -> (float, int)"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; @@ -2626,7 +2631,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::ldexp(float x, int i) -> float"), + TORCH_SELECTIVE_SCHEMA("aten::ldexp(float x, int i) -> float"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; @@ -2726,7 +2731,7 @@ RegisterOperators reg2({ float, float), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::abs(Tensor x) -> Tensor"), + TORCH_SELECTIVE_SCHEMA("prim::abs(Tensor x) -> Tensor"), [](Stack* stack) { at::Tensor x; pop(stack, x); @@ -2749,7 +2754,7 @@ RegisterOperators reg2({ std::copysign(a, b), float), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::_tensor_to_list(Tensor self) -> int[]"), + TORCH_SELECTIVE_SCHEMA("aten::_tensor_to_list(Tensor self) -> int[]"), [](Stack* stack) { at::Tensor t; pop(stack, t); @@ -2762,7 +2767,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::_list_to_tensor(int[] self) -> Tensor"), + TORCH_SELECTIVE_SCHEMA("aten::_list_to_tensor(int[] self) -> Tensor"), [](Stack* stack) { c10::List l = pop(stack).toIntList(); auto t = torch::empty( @@ -2774,7 +2779,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sum.int(int[] self) -> int"), + TORCH_SELECTIVE_SCHEMA("aten::sum.int(int[] self) -> int"), [](Stack* stack) { c10::List l = pop(stack).toIntList(); auto sum = 0; @@ -2785,7 +2790,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sum.float(float[] self) -> float"), + TORCH_SELECTIVE_SCHEMA("aten::sum.float(float[] self) -> float"), [](Stack* stack) { c10::List l = pop(stack).toDoubleList(); auto sum = 0.0; @@ -2796,7 +2801,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sum.complex(complex[] self) -> complex"), + TORCH_SELECTIVE_SCHEMA("aten::sum.complex(complex[] self) -> complex"), [](Stack* stack) { c10::List> l = pop(stack).toComplexDoubleList(); c10::complex sum = 0.0; @@ -2807,7 +2812,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::sum.bool(bool[] self) -> int"), + TORCH_SELECTIVE_SCHEMA("aten::sum.bool(bool[] self) -> int"), [](Stack* stack) { c10::List l = pop(stack).toBoolList(); auto sum = 0; @@ -2820,7 +2825,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::any.str(str[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::any.str(str[] self) -> bool"), [](Stack* stack) { auto l = pop(stack).toList(); for (const auto& elem : l) { @@ -2833,7 +2838,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::any.int(int[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::any.int(int[] self) -> bool"), [](Stack* stack) { c10::List l = pop(stack).toIntList(); for (const auto& elem : l) { @@ -2846,7 +2851,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::any.float(float[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::any.float(float[] self) -> bool"), [](Stack* stack) { c10::List l = pop(stack).toDoubleList(); for (const auto& elem : l) { @@ -2859,7 +2864,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::any.bool(bool[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::any.bool(bool[] self) -> bool"), [](Stack* stack) { c10::List l = pop(stack).toBoolList(); for (const auto& elem : l) { @@ -2872,7 +2877,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::all.int(int[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::all.int(int[] self) -> bool"), [](Stack* stack) { c10::List l = pop(stack).toIntList(); for (const auto& elem : l) { @@ -2885,7 +2890,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::all.float(float[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::all.float(float[] self) -> bool"), [](Stack* stack) { c10::List l = pop(stack).toDoubleList(); for (const auto& elem : l) { @@ -2898,7 +2903,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::all.bool(bool[] self) -> bool"), + TORCH_SELECTIVE_SCHEMA("aten::all.bool(bool[] self) -> bool"), [](Stack* stack) { c10::List l = pop(stack).toBoolList(); for (const auto& elem : l) { @@ -2911,7 +2916,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::divmod.int(int x, int y) -> (int, int)"), + TORCH_SELECTIVE_SCHEMA("aten::divmod.int(int x, int y) -> (int, int)"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t a, b; @@ -2933,7 +2938,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME( + TORCH_SELECTIVE_SCHEMA( "aten::divmod.float(float x, float y) -> (float, float)"), [](Stack* stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -2951,7 +2956,7 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), OperatorGenerator( - TORCH_SELECTIVE_NAME("prim::id(AnyClassType? x) -> int"), + TORCH_SELECTIVE_SCHEMA("prim::id(AnyClassType? x) -> int"), [](Stack* stack) { IValue a; pop(stack, a); @@ -2963,21 +2968,21 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), -#define DEFINE_DIVMOD_MIXED_OP(type_a, type_b) \ - OperatorGenerator( \ - TORCH_SELECTIVE_NAME("aten::divmod." #type_a "_" #type_b "(" #type_a \ - " x," #type_b " y) -> (float, float)"), \ - [](Stack* stack) { \ - type_a a; \ - type_b b; \ - pop(stack, a, b); \ - if (b == 0) { \ - throw std::runtime_error("ZeroDivisionError: float divmod()"); \ - } \ - double quot = floor(a / b); \ - double rem = a - (quot * b); \ - push(stack, quot, rem); \ - }, \ +#define DEFINE_DIVMOD_MIXED_OP(type_a, type_b) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::divmod." #type_a "_" #type_b "(" #type_a \ + " x," #type_b " y) -> (float, float)"), \ + [](Stack* stack) { \ + type_a a; \ + type_b b; \ + pop(stack, a, b); \ + if (b == 0) { \ + throw std::runtime_error("ZeroDivisionError: float divmod()"); \ + } \ + double quot = floor(a / b); \ + double rem = a - (quot * b); \ + push(stack, quot, rem); \ + }, \ aliasAnalysisFromSchema()) DEFINE_DIVMOD_MIXED_OP(int, float), @@ -2985,46 +2990,46 @@ RegisterOperators reg2({ #undef DEFINE_DIVMOD_MIXED_OP OperatorGenerator( - TORCH_SELECTIVE_NAME("aten::hash.generic(t value) -> int"), + TORCH_SELECTIVE_SCHEMA("aten::hash.generic(t value) -> int"), hashValue, aliasAnalysisFromSchema()), -#define DEFINE_COMPLEX_OP(type_a, type_b, actual_type_a, actual_type_b) \ - OperatorGenerator( \ - TORCH_SELECTIVE_NAME("aten::Complex." #type_a "_" #type_b "(" #type_a \ - " x," #type_b " y) -> complex"), \ - [](Stack* stack) { \ - actual_type_a a; \ - actual_type_b b; \ - pop(stack, a, b); \ - auto comp = c10::complex(a, b); \ - push(stack, comp); \ - }, \ +#define DEFINE_COMPLEX_OP(type_a, type_b, actual_type_a, actual_type_b) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::Complex." #type_a "_" #type_b "(" #type_a \ + " x," #type_b " y) -> complex"), \ + [](Stack* stack) { \ + actual_type_a a; \ + actual_type_b b; \ + pop(stack, a, b); \ + auto comp = c10::complex(a, b); \ + push(stack, comp); \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_COMPLEX_OP_WITH_TENSOR_ARG( \ - type_a, type_b, actual_type_a, actual_type_b) \ - OperatorGenerator( \ - TORCH_SELECTIVE_NAME("aten::Complex." #type_a "_" #type_b "(" #type_a \ - " x," #type_b " y) -> complex"), \ - [](Stack* stack) { \ - actual_type_a a; \ - actual_type_b b; \ - pop(stack, a, b); \ - auto comp = c10::complex(a.item(), b); \ - push(stack, comp); \ - }, \ - aliasAnalysisFromSchema()), \ - OperatorGenerator( \ - TORCH_SELECTIVE_NAME("aten::Complex." #type_b "_" #type_a \ - "(" #type_b " x," #type_a " y) -> complex"), \ - [](Stack* stack) { \ - actual_type_b a; \ - actual_type_a b; \ - pop(stack, a, b); \ - auto comp = c10::complex(a, b.item()); \ - push(stack, comp); \ - }, \ +#define DEFINE_COMPLEX_OP_WITH_TENSOR_ARG( \ + type_a, type_b, actual_type_a, actual_type_b) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::Complex." #type_a "_" #type_b "(" #type_a \ + " x," #type_b " y) -> complex"), \ + [](Stack* stack) { \ + actual_type_a a; \ + actual_type_b b; \ + pop(stack, a, b); \ + auto comp = c10::complex(a.item(), b); \ + push(stack, comp); \ + }, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::Complex." #type_b "_" #type_a \ + "(" #type_b " x," #type_a " y) -> complex"), \ + [](Stack* stack) { \ + actual_type_b a; \ + actual_type_a b; \ + pop(stack, a, b); \ + auto comp = c10::complex(a, b.item()); \ + push(stack, comp); \ + }, \ aliasAnalysisFromSchema()) DEFINE_COMPLEX_OP(int, bool, int, bool), diff --git a/torch/csrc/jit/runtime/script_profile.cpp b/torch/csrc/jit/runtime/script_profile.cpp new file mode 100644 index 00000000000000..3f43de788ad730 --- /dev/null +++ b/torch/csrc/jit/runtime/script_profile.cpp @@ -0,0 +1,177 @@ +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +namespace { + +class ProfilesRegistry { + public: + bool empty() { + return empty_.load(std::memory_order_relaxed); + } + + void addProfile(ScriptProfile& p) { + std::lock_guard g(mutex_); + enabledProfiles_.emplace(&p); + empty_.store(false, std::memory_order_relaxed); + } + + void removeProfile(ScriptProfile& p) { + std::lock_guard g(mutex_); + enabledProfiles_.erase(&p); + if (enabledProfiles_.empty()) { + empty_.store(true, std::memory_order_relaxed); + } + } + + void send(std::unique_ptr datapoint) { + auto shared = std::shared_ptr(std::move(datapoint)); + std::lock_guard g(mutex_); + for (auto* p : enabledProfiles_) { + p->addDatapoint(shared); + } + } + + private: + std::atomic empty_{true}; + std::mutex mutex_; + std::unordered_set enabledProfiles_; +}; + +ProfilesRegistry& getProfilesRegistry() { + static auto registry = std::ref(*new ProfilesRegistry{}); + return registry; +} + +auto initBindings() { + torch::class_("profiling", "SourceRef") + .def( + "starting_lineno", + [](const c10::intrusive_ptr& self) { + return static_cast((*self)->starting_line_no()); + }) + .def("text", [](const c10::intrusive_ptr& self) { + return (*self)->text(); + }); + + torch::class_("profiling", "InstructionStats") + .def( + "count", + [](const c10::intrusive_ptr& self) { + return self->count; + }) + .def("duration_ns", [](const c10::intrusive_ptr& self) { + return static_cast(self->duration.count()); + }); + + torch::class_("profiling", "SourceStats") + .def( + "source", + [](const c10::intrusive_ptr& self) { + return c10::make_intrusive(self->getSourceRef()); + }) + .def("line_map", &SourceStats::getLineMap); + + torch::class_("profiling", "_ScriptProfile") + .def(torch::init<>()) + .def("enable", &ScriptProfile::enable) + .def("disable", &ScriptProfile::disable) + .def("_dump_stats", [](const c10::intrusive_ptr& self) { + const auto& stats = self->dumpStats(); + c10::List> ret; + for (const auto& source : stats) { + SourceStats::LineMap lineMap; + for (const auto& line : source.second) { + lineMap.insert( + line.first, c10::make_intrusive(line.second)); + } + ret.push_back(c10::make_intrusive( + source.first, std::move(lineMap))); + } + return ret; + }); + return nullptr; +} + +const auto torchBindInitializer = initBindings(); + +} // namespace + +namespace profiling { + +InstructionSpan::InstructionSpan(Node& node) { + if (getProfilesRegistry().empty()) { + return; + } + + datapoint_ = std::make_unique(node.sourceRange()); +} + +InstructionSpan::~InstructionSpan() { + if (!datapoint_) { + return; + } + + datapoint_->end = std::chrono::steady_clock::now(); + getProfilesRegistry().send(std::move(datapoint_)); +} + +} // namespace profiling + +void ScriptProfile::enable() { + if (!std::exchange(enabled_, true)) { + getProfilesRegistry().addProfile(*this); + } +} + +void ScriptProfile::disable() { + if (std::exchange(enabled_, false)) { + getProfilesRegistry().removeProfile(*this); + } +} + +void ScriptProfile::addDatapoint( + std::shared_ptr datapoint) { + TORCH_CHECK(enabled_, "Cannot only add datapoint to disabled profilers."); + datapoints_.push_back(std::move(datapoint)); +} + +const ScriptProfile::SourceMap& ScriptProfile::dumpStats() { + TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats."); + + for (const auto& datapoint : datapoints_) { + if (const auto& source = datapoint->sourceRange.source()) { + if (auto fileLineCol = datapoint->sourceRange.file_line_col()) { + auto it = sourceMap_.find(*source.get()); + if (it == sourceMap_.end()) { + it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first; + } + auto& stats = it->second[std::get<1>(*fileLineCol)]; + stats.count++; + stats.duration += datapoint->end - datapoint->start; + } + } + } + datapoints_.clear(); + + return sourceMap_; +} + +ScriptProfile::~ScriptProfile() { + if (enabled_) { + getProfilesRegistry().removeProfile(*this); + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/script_profile.h b/torch/csrc/jit/runtime/script_profile.h new file mode 100644 index 00000000000000..d9b801e08d8692 --- /dev/null +++ b/torch/csrc/jit/runtime/script_profile.h @@ -0,0 +1,101 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace profiling { + +struct Datapoint { + using Timepoint = std::chrono::time_point; + SourceRange sourceRange; + Timepoint start; + Timepoint end; + + explicit Datapoint(SourceRange sr) + : sourceRange(std::move(sr)), start(std::chrono::steady_clock::now()) {} +}; + +class TORCH_API InstructionSpan { + public: + explicit InstructionSpan(Node&); + ~InstructionSpan(); + InstructionSpan(InstructionSpan&&) = delete; + InstructionSpan& operator=(InstructionSpan&&) = delete; + + private: + std::unique_ptr datapoint_; +}; + +} // namespace profiling + +struct TORCH_API InstructionStats : public CustomClassHolder { + int64_t count{0}; + std::chrono::nanoseconds duration{0}; +}; + +class TORCH_API SourceStats : public CustomClassHolder { + public: + using LineMap = c10::Dict>; + + SourceStats(SourceRef source, LineMap lineMap) + : source_(std::move(source)), lineMap_(std::move(lineMap)) {} + + const SourceRef& getSourceRef() const { + return source_; + } + + const LineMap& getLineMap() const { + return lineMap_; + } + + private: + SourceRef source_; + LineMap lineMap_; +}; + +/** + * ScriptProfile is an underlying C++ implementation for TorchScript profiling. + * The profiling section is specified by calling enable() and disable(): + * + * ... + * scriptProfile.enable(); + * ... + * (scripts) + * ... + * scriptProfile.disable(); + * ... + * + * To retrieve collected runtime data, users may call dumpStats() and do + * arbitrary filtering on the data they want. Note that dumpStats() should + * not be called inside a profiling section. + * In general, stats are aggregated per source function body, and then by line + * number. + */ +class TORCH_API ScriptProfile : public CustomClassHolder { + // Aggregates datapoints by function source id, then by line number. + using LineMap = std::map; + using SourceMap = std::map>; + + public: + void enable(); + void disable(); + const SourceMap& dumpStats(); + void addDatapoint(std::shared_ptr); + ~ScriptProfile() override; + + private: + bool enabled_{false}; + std::vector> datapoints_; + SourceMap sourceMap_; +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 07e7da743624e9..d4965ed4470db3 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -51,8 +51,7 @@ void CheckGraphEligibility(const std::shared_ptr& graph) { } } // check output types - // Static Runtime supports output types include None, Tensor, List/Tuple - // of Tensor, or Dict + // Static Runtime doesn't support complex outputs such as List of Lists for (Value* output : graph->outputs()) { VLOG(1) << "output: %" << output->debugName() << " has type: " << output->type()->repr_str(); @@ -461,13 +460,11 @@ GenerateSameStorageValues( return same_storage_values; } -} // namespace - void PrepareGraphForStaticModule( std::shared_ptr graph, const StaticModuleOptions& opts) { - OptimizeGraph(graph, opts); CheckGraphEligibility(graph); + OptimizeGraph(graph, opts); RemoveSelfFromGraphInput(graph); } @@ -502,6 +499,8 @@ PrepareForStaticModule( return std::make_pair(graph, c10::nullopt); } +} // namespace + StaticModule::StaticModule( std::shared_ptr g, const StaticModuleOptions& opts) @@ -522,17 +521,16 @@ StaticModule::StaticModule( schema_(std::move(graph_and_schema.second)) { // check opt flags if (opts.optimize_graph_output_memory) { - if (!(opts_.optimize_memory && opts_.enable_out_variant)) { - throw std::runtime_error( - "When optimize_graph_output_memory is true, optimize_memory and enable_out_variant must be set to true"); - } + TORCH_CHECK( + opts_.enable_out_variant && opts_.optimize_memory, + "When optimize_graph_output_memory is true, enable_out_variant and optimize_memory must be set to true"); } if (opts_.optimize_memory) { - if (!opts_.enable_out_variant) { - throw std::runtime_error( - "When optimize_memory is true, enable_out_variant must be set to true"); - } + TORCH_CHECK( + opts_.enable_out_variant, + "When optimize_memory is true, enable_out_variant must be set to true"); } + // map Value* to IValue (from inputs or prim::Constant) or null std::unordered_map value_to_ivalue; // map Value* to its SSA definition IR @@ -598,18 +596,12 @@ StaticModule::StaticModule( output_ssa_defs_.emplace_back(value_to_ssa_def[output]); } + // Prepare for memory planning AliasDb alias_db(graph_); auto lm = GetLivenessInformation(graph_, alias_db); external_values_ = lm.second; if (opts_.optimize_memory) { auto values = GetMemoryPlanningCandidates(graph_); - // Note (penguin): since it does not make sense to have optimize_memory - // enabled but enable_out_variant disabled, we check the flag dependence - // during initialization of StaticModule so that the following condition - // would not be true. This would make the code easier to understand - // if (!opts_.enable_out_variant) { - // values.first = {}; - // } value_to_same_storage_values_ = GenerateSameStorageValues(lm, values, alias_db); } @@ -657,10 +649,8 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) { // NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (auto i = 0; i < n.inputs().size(); ++i) { if (n.inputs()[i] == nullptr) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int node_idx; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int out_idx; + int node_idx = 0; + int out_idx = 0; std::tie(node_idx, out_idx) = sm.index_map().at(idx)[i]; DCHECK(out_idx >= 0); // input @@ -669,6 +659,7 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) { } else if (node_idx == StaticModule::CONSTANT_VALUE) { n.set_input(i, &sm.constants()[out_idx]); } else { + DCHECK(node_idx >= 0); n.set_input(i, &(nodes_[node_idx].Output(out_idx))); } } @@ -676,10 +667,8 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) { } for (const auto& index_pair : sm.output_indices()) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int node_idx; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int out_idx; + int node_idx = 0; + int out_idx = 0; std::tie(node_idx, out_idx) = index_pair; if (node_idx == StaticModule::INPUT_VALUE) { outputs_.emplace_back(&inputs_[out_idx]); @@ -690,8 +679,7 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) outputs_.emplace_back(const_cast(&sm.constants()[out_idx])); } else { - auto& n = nodes_.at(node_idx); - auto* out = &n.Output(out_idx); + auto* out = &nodes_[node_idx].Output(out_idx); outputs_.emplace_back(out); } } @@ -760,9 +748,8 @@ c10::IValue StaticRuntime::operator()( if (static_module_.opts().cleanup_activations) { // MemoryPlanner is created after the first invocation of `run()`. This is - // done intentionally because MemoryPlanner uses `TensorStorageImpl` - // object sizes of the previous `run()` for memory planning of subsequent - // runs + // done intentionally because MemoryPlanner uses `Tensor` sizes of the + // previous `run()` for memory planning of subsequent runs if (!planner_) { planner_ = std::make_unique( this, diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 3b02d6641f4b68..50a84dade7ec43 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -13,9 +13,10 @@ namespace torch { namespace jit { struct TORCH_API StaticModuleOptions { - bool cleanup_activations{true}; // to batch allocate (deallocate) tensor storage for all non-escaping // temporary tensors + bool cleanup_activations{true}; + // enabling out variant allows Static Runtime to do memory planning bool enable_out_variant{true}; // to reuse tensor storage for tensors whose live-range do not overlap to // reduce memory footprint (enable_out_variant must be true) @@ -97,8 +98,8 @@ class TORCH_API StaticModule { const StaticModuleOptions& opts); // for - // if kind == CONSTANT_KIND: map to constants_[idx] - // if kind == INPUT_KIND: map to inputs_[idx] + // if kind == CONSTANT_VALUE: map to constants_[idx] + // if kind == INPUT_VALUE: map to inputs_[idx] // otherwise: map to nodes_[kind].outputs()[idx] using DefInfo = std::pair; @@ -119,60 +120,60 @@ class TORCH_API StaticModule { size_t num_inputs() const; size_t num_outputs() const; - inline const std::unordered_map>& index_map() - const { + const std::unordered_map>& index_map() const { return node_inputs_ssa_def_map_; } - inline const std::vector& output_indices() const { + const std::vector& output_indices() const { return output_ssa_defs_; } - inline const std::vector& constants() const { + const std::vector& constants() const { return constants_; } - inline const std::vector& nodes() const { + const std::vector& nodes() const { return nodes_; } - inline const c10::optional& schema() const { + const c10::optional& schema() const { return schema_; } - inline const std::unordered_map>& + const std::unordered_map>& values_share_same_storage() const { return value_to_same_storage_values_; } - inline const std::unordered_set& external_values() const { + const std::unordered_set& external_values() const { return external_values_; } StaticRuntime& runtime(); private: - // Static runtime states StaticModuleOptions opts_; + std::shared_ptr graph_; + c10::optional schema_; std::unique_ptr cached_runtime_; + + // Bookkeeping for creating new StaticRuntime instances // IValue table (defined by prim::Constant nodes) std::vector constants_; + // The nodes we need to run + std::vector nodes_; // a vector of ssa_defs corresponding to graph->outputs() std::vector output_ssa_defs_; // map a node idx (in graph order) to a vector of ssa_defs for node inputs std::unordered_map> node_inputs_ssa_def_map_; - // The nodes we need to run - std::vector nodes_; - // map a value to the set of values that may share the same storage with it - std::unordered_map> - value_to_same_storage_values_; + + // Bookkeeping for MemoryPlanner in StaticRuntime // values whose live-time exceeds that of running one inference (e.g., input, // output, prim::Constants, and their aliases) std::unordered_set external_values_; - - // Original input - std::shared_ptr graph_; - c10::optional schema_; + // map a value to the set of values that may share the same storage with it + std::unordered_map> + value_to_same_storage_values_; }; class TORCH_API StaticRuntime { @@ -236,11 +237,11 @@ class TORCH_API StaticRuntime { return outputs_; } - inline const std::vector& nodes() const { + const std::vector& nodes() const { return nodes_; } - inline std::vector& nodes() { + std::vector& nodes() { return nodes_; } @@ -254,10 +255,10 @@ class TORCH_API StaticRuntime { // Memory planning is only enabled if sm->opts().cleanup_activations is true. // Otherwise, the memory used by activations is cached inside the static // runtime. + const StaticModule& static_module_; std::unique_ptr planner_; std::vector inputs_; std::vector outputs_; - const StaticModule& static_module_; std::vector nodes_; }; @@ -266,25 +267,26 @@ class TORCH_API StaticRuntime { /// 2. view producing op /// 3. tensor producing op (could be replaced with type 1 by adding the _out /// variant to Static Runtime) -/// The memory planner only manages tensors that are outputs of type 1 ops, -/// because type 2 ops don't incur memory allocation and for type 3, the output -/// tensors are allocated inside the operator and can't be directly managed by -/// memory planner. +/// In Static Runtime, type 2 ops are replaced with their corespoinding copy +/// versions when enable_out_variant is enabled and become type 1 ops.The memory +/// planner only manages tensors that are outputs of type 1 ops. For type 3, the +/// output tensors are allocated inside the operator and can't be directly +/// managed by memory planner. /// /// Memory planner tries to minimize the number of memory allocations by -/// tracking the unique StorageImpls of the output tensors of ops with _out -/// variants. It tries to do this in several steps: -/// 1. record the max memory usage for each StorageImpl at the end of each -/// iteration +/// tracking the output tensors of ops with _out variants with unique DataPtr +/// (part of StorageImpl). It tries to do this in several steps: +/// 1. record the max memory usage for each Tensor with unique DataPtr at the +/// end of each iteration /// 2. in the next iteration, allocate the buffer for the max total usage and /// compute the offset of each allocation with regard to the single memory -/// buffer, optionally reusing memory. In the first iteration, we rely on +/// buffer, optionally reusing memory. In the first iteration, we rely on /// the default allocator for memory allocation. /// 3. free the buffer at the end of each iteration /// Steps 1 and 3 are handled by `deallocate()`, and step 2 by `allocate()`. /// Only models with simple output types are supported, i.e. None, Tensor or -/// List/Tuple of Tensors. Complex output types such as List of Lists are not -/// supported. +/// List/Tuple/Dict of Tensors. Complex output types such as List of Lists are +/// not supported. class MemoryPlanner { public: @@ -294,9 +296,15 @@ class MemoryPlanner { const std::unordered_set& external_values, bool enable_out_variant, bool manage_graph_output_memory); + // disable copying and moving + MemoryPlanner(const MemoryPlanner&) = delete; + MemoryPlanner& operator=(const MemoryPlanner&) = delete; + MemoryPlanner(MemoryPlanner&&) = delete; + MemoryPlanner& operator=(MemoryPlanner&&) = delete; void allocate(); void deallocate(); + size_t total_managed() const { return managed_bytes_; } @@ -312,13 +320,13 @@ class MemoryPlanner { // and a vector of Tensors that should be backed by that same data. // Thus, if memonger is disabled, all vectors are of size 1. std::vector>> managed_tensors_; + at::DataPtr buffer_; // allocated each time we call Run() size_t managed_bytes_{0}; size_t reused_tensors_{0}; - at::DataPtr buffer_; // allocated each time we call Run() // since output tensors are alive after one inference, their storage // is managed differently (e.g., deallocation happens at client side) - // std::vector>> + // std::vector>> // managed_output_storage_; // size_t managed_output_bytes_{0}; // size_t reused_output_tensors_{0}; @@ -344,10 +352,6 @@ class ProcessedNode { return node_; } - inline void set_input(size_t index, const IValue* ival) { - inputs_[index] = ival; - } - // Input is readonly const IValue& Input(size_t i) const { DCHECK(i < inputs_.size()); @@ -360,6 +364,10 @@ class ProcessedNode { return outputs_[i]; } + void set_input(size_t index, const IValue* ival) { + inputs_[index] = ival; + } + const std::vector& outputs() const { return outputs_; } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 2f3f8dc238c7ec..30f8bb31efdf47 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -464,9 +465,9 @@ void optimizePointwise( int width) { using namespace torch::jit::tensorexpr; std::vector loops = ln->getLoopStmtsFor(target); - For *outer, *inner, *tail; + For *inner, *tail; TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op"); - ln->splitWithTail(loops[0], width, &outer, &inner, &tail); + ln->splitWithTail(loops[0], width, &inner, &tail); ln->vectorize(inner); } @@ -1332,9 +1333,9 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { } return [](ProcessedNode* p_node) { const auto& in0_t = p_node->Input(0).toTensor(); - c10::optional rounding_mode = c10::nullopt; + c10::optional rounding_mode = c10::nullopt; if (p_node->inputs().size() > 2) { - rounding_mode = p_node->Input(2).toOptional(); + rounding_mode = p_node->Input(2).toOptional(); } if (p_node->Output(0).isNone()) { @@ -1594,5 +1595,32 @@ REGISTER_OPERATOR_FUNCTOR( }; }); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_OPERATOR_FUNCTOR( + aten::full_like, + aten_full_like, + [](Node* n) -> SROperator { + if (n->inputs().size() != 7) { + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto in1_s = p_node->Input(1).toScalar(); + if (p_node->Output(0).isNone()) { + const auto& in0_t = p_node->Input(0).toTensor(); + const auto dtype = p_node->Input(2).toOptional(); + const auto layout = p_node->Input(3).toOptional(); + const auto device = p_node->Input(4).toOptional(); + const auto pin_memory = p_node->Input(5).toOptional(); + const auto memory_format = + p_node->Input(6).toOptional(); + + p_node->Output(0) = at::native::empty_like( + in0_t, dtype, layout, device, pin_memory, memory_format); + } + auto& out_t = p_node->Output(0).toTensor(); + at::native::fill_out(out_t, in1_s); + }; + }); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index 8c6ea1d5841eb4..0480a3af2cfed0 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -5,6 +5,10 @@ namespace torch { namespace jit { +namespace { +const int64_t kInvalidSourceRangeTag = -1; +} // namespace + c10::IValue InlinedCallStackSerializer::serialize( const InlinedCallStackPtr& cs_ptr, const SourceRangeTagMap& source_range_tags) { @@ -15,13 +19,20 @@ c10::IValue InlinedCallStackSerializer::serialize( if (cs_it != serialized_inlined_callstack_.end()) { return cs_it->second; } - // Inlined callstack pointer is serialized as tuple of 3 elements - // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack)} + // Inlined callstack pointer is serialized as tuple of 4 elements + // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack), + // function name} Note function name is serialized separately because Function + // is only in memory structure. It gets constructed by JIT from serialized + // Code at runtime. As such even InlinedCallStack get constructed by JIT at + // runtime during graph inlining. However, we introduce + // serialization/deserialization of it in order to generate callstack debug + // information, _when_ equivalent InlinedCallStack cannot be constructed at + // runtime. For example, in lite interpreter or delegated backend. std::vector elements; - elements.reserve(3); + elements.reserve(4); elements.emplace_back( serialize_module_instance_info(cs_ptr->module_instance())); - int64_t source_range_tag{-1}; + int64_t source_range_tag{kInvalidSourceRangeTag}; const SourceRange& sr = cs_ptr->source_range().findSourceRangeThatGenerated() ? cs_ptr->source_range().findSourceRangeThatGenerated().value() : cs_ptr->source_range(); @@ -36,6 +47,16 @@ c10::IValue InlinedCallStackSerializer::serialize( } else { elements.emplace_back(c10::IValue()); } + if (cs_ptr->function()) { + elements.emplace_back(cs_ptr->function()->name()); + } else { + auto fn_name = cs_ptr->function_name(); + if (!fn_name.empty()) { + elements.emplace_back(fn_name); + } else { + elements.emplace_back("FunctionName_UNKNOWN"); + } + } c10::IValue serialized_cs = c10::ivalue::Tuple::create(elements); serialized_inlined_callstack_[cs_ptr] = serialized_cs; return serialized_cs; @@ -67,7 +88,7 @@ c10::IValue InlinedCallStackSerializer::serialize_module_instance_info( } std::vector CallStackDebugInfoPickler::pickle( - const std::unordered_map& callstack_ptrs, + const std::unordered_map& callstack_ptrs, const SourceRangeTagMap& source_range_tags) { std::vector ivalues; for (const auto& it : callstack_ptrs) { @@ -80,16 +101,21 @@ std::vector CallStackDebugInfoPickler::pickle( */ elements.reserve(3); elements.emplace_back(debug_handle); - int64_t source_range_tag{-1}; - const SourceRange& sr = it.second.first.findSourceRangeThatGenerated() - ? it.second.first.findSourceRangeThatGenerated().value() - : it.second.first; + int64_t source_range_tag{kInvalidSourceRangeTag}; + const auto source_range = + std::get(it.second); + const SourceRange& sr = source_range.findSourceRangeThatGenerated() + ? source_range.findSourceRangeThatGenerated().value() + : source_range; auto sr_it = source_range_tags.find(sr); if (sr_it != source_range_tags.end()) { source_range_tag = sr_it->second; } elements.emplace_back(source_range_tag); - elements.emplace_back(css_.serialize(it.second.second, source_range_tags)); + elements.emplace_back(std::get(it.second)); + const auto inlined_cs_ptr = + std::get(it.second); + elements.emplace_back(css_.serialize(inlined_cs_ptr, source_range_tags)); c10::IValue tuple = c10::ivalue::Tuple::create(elements); ivalues.emplace_back(tuple); } @@ -114,19 +140,25 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize( } auto tup_elems = tup->elements(); - TORCH_INTERNAL_ASSERT(tup_elems.size() == 3); - // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack)} + TORCH_INTERNAL_ASSERT(tup_elems.size() == 4); + // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack), + // function name} auto module_instance_info = deserialize_module_instance_info(tup_elems[0], cu); int64_t source_range_tag = tup_elems[1].toInt(); auto source_range_it = source_range_map.find(source_range_tag); TORCH_CHECK( - source_range_tag == -1 || source_range_it != source_range_map.end(), + source_range_tag == kInvalidSourceRangeTag || + source_range_it != source_range_map.end(), "Source range tag must exist in deserialized source range map." " Not found source range tag:", source_range_tag); - auto source_range = source_range_it->second; + SourceRange source_range; + if (source_range_tag != kInvalidSourceRangeTag) { + source_range = source_range_it->second; + } auto callee = deserialize(tup_elems[2], source_range_map, cu); + auto function_name = tup_elems[3].toStringRef(); InlinedCallStackPtr cs_ptr; if (callee) { cs_ptr = c10::make_intrusive( @@ -135,6 +167,7 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize( cs_ptr = c10::make_intrusive( nullptr, source_range, module_instance_info); } + cs_ptr->set_function_name(function_name); cached_inlined_callstacks_[tup] = cs_ptr; // Invoking move constructor // It is not clear if copy-ellision can happen since @@ -162,28 +195,44 @@ c10::optional InlinedCallStackDeserializer:: // type_name might be empty string "" // In that case type_ptr should be just nullptr auto type_ptr = cu->get_class(type_name); + if (!type_ptr) { + // We may have lost type information. For example in lowered backends + // original class type has no relevance. + // However, to correlate ops to their original modules + // we saved both type name and instance name. + // In such cases, when module is absorbed by lowered backend + // we augment instance name with type name instead of losing it. + auto last_dot_position = type_name.find_last_of('.'); + size_t substring_pos{0}; + if (last_dot_position != std::string::npos) { + substring_pos = last_dot_position + 1; + } + type_name = type_name.substr(substring_pos); + instance_name = instance_name + "(" + type_name + ")"; + } cached_module_instance_info_[tup] = ModuleInstanceInfo(type_ptr, instance_name); return cached_module_instance_info_[tup]; } -ska::flat_hash_map CallStackDebugInfoUnpickler:: +ska::flat_hash_map CallStackDebugInfoUnpickler:: unpickle( at::DataPtr&& data, size_t size, const ska::flat_hash_map& source_range_map, const std::shared_ptr& cu) { auto ival = jit::unpickle(reinterpret_cast(data.get()), size); - ska::flat_hash_map callstack_ptrs; + ska::flat_hash_map callstack_ptrs; auto& ivalues = ival.toTuple()->elements(); for (auto& val : ivalues) { const auto tup_elems = val.toTuple()->elements(); TORCH_CHECK( - tup_elems.size() == 3, - "Pickled map must have three elements: " - "debug_handle, source_range_tag, IValue(inlined_call_stack)"); + tup_elems.size() == 4, + "Pickled map must have four elements: " + "debug_handle, source_range_tag, op name, IValue(inlined_call_stack)"); int64_t debug_handle = tup_elems[0].toInt(); int64_t source_range_tag = tup_elems[1].toInt(); + const std::string& node_name = tup_elems[2].toStringRef(); auto source_range_it = source_range_map.find(source_range_tag); TORCH_CHECK( source_range_it != source_range_map.end(), @@ -192,8 +241,10 @@ ska::flat_hash_map CallStackDebugInfoUnpickler:: TORCH_CHECK( callstack_ptrs.count(debug_handle) == 0, "Debug handles should be unique."); - callstack_ptrs[debug_handle] = std::make_pair( - source_range, csds_.deserialize(tup_elems[2], source_range_map, cu)); + callstack_ptrs[debug_handle] = std::make_tuple( + source_range, + node_name, + csds_.deserialize(tup_elems[3], source_range_map, cu)); } return callstack_ptrs; } diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.h b/torch/csrc/jit/serialization/callstack_debug_info_serialization.h index 219b0713ca00d7..ac1bdf8d3b1d84 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.h +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.h @@ -49,7 +49,7 @@ class TORCH_API CallStackDebugInfoPickler { CallStackDebugInfoPickler() = default; std::vector pickle( - const std::unordered_map& callstack_ptrs, + const std::unordered_map& callstack_ptrs, const SourceRangeTagMap& source_range_tags); private: @@ -77,7 +77,7 @@ class InlinedCallStackDeserializer { class TORCH_API CallStackDebugInfoUnpickler { public: - ska::flat_hash_map unpickle( + ska::flat_hash_map unpickle( at::DataPtr&& data, size_t size, const ska::flat_hash_map& source_range_map, diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index c2a6fd729481c8..65b7cf467ef9a8 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -64,7 +65,7 @@ static IValue Table( std::pair getFunctionTuple( const Module& module, const Function& func, - BackendDebugHandleManager& debug_handle_manager) { + BackendDebugInfoRecorder& debug_info_recorder) { auto graph = func.graph()->copy(); Inline(*graph); @@ -145,8 +146,7 @@ std::pair getFunctionTuple( " is not supported in mobile module."); } auto node = code->instructions_source()[i]; - int64_t debug_handle = - debug_handle_manager.getNextDebugHandleForInlinedCallStackPtr(node); + int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node); // Note 1-to-1 correspondence between instructions and debug handles op_debug_handles.emplace_back(debug_handle); } @@ -275,7 +275,7 @@ void setstateTuple( std::vector& elements, std::unordered_set& qn_cache, std::vector& debug_info_elements, - BackendDebugHandleManager& debug_handle_manager) { + BackendDebugInfoRecorder& debug_info_recorder) { if (!ivalue.isObject()) return; auto obj = ivalue.toObject(); @@ -287,8 +287,7 @@ void setstateTuple( return; } if (setstate.isGraphFunction()) { - auto func_tuple = - getFunctionTuple(module, setstate, debug_handle_manager); + auto func_tuple = getFunctionTuple(module, setstate, debug_info_recorder); elements.push_back(func_tuple.first); qn_cache.emplace(qn); debug_info_elements.push_back(func_tuple.second); @@ -301,17 +300,82 @@ void setstateTuple( elements, qn_cache, debug_info_elements, - debug_handle_manager); + debug_info_recorder); } } } + +// Check if the global static map of backend debug info +// contains debug info for this module and any of its children. +// If so combine all the maps together and return one. +void getBackendDebugInfoMap( + const Module& m, + BackendDebugInfoMapType& debug_map) { + c10::QualifiedName type_name; + if (m.type()->name()) { + type_name = m.type()->name().value(); + } + if (c10::string_view(type_name.name()).ends_with("LoweredModule")) { + auto backend_debug_info = + m.attr("__backend_debug_info").toCustomClass(); + const auto& map = backend_debug_info->getDebugInfoMap(); + if (map) { + debug_map.insert(map.value().begin(), map.value().end()); + } + } + for (const auto& c : m.children()) { + getBackendDebugInfoMap(c, debug_map); + } +} + +SourceRangeRecords getBackendSourceRanges(const Module& m) { + SourceRangeRecords sr_records; + c10::QualifiedName type_name; + if (m.type()->name()) { + type_name = m.type()->name().value(); + } + if (c10::string_view(type_name.name()).ends_with("LoweredModule")) { + constexpr size_t kSourceRange = 1; + auto backend_debug_info = + m.attr("__backend_debug_info").toCustomClass(); + const auto& map = backend_debug_info->getDebugInfoMap(); + if (map) { + const auto& map_val = map.value(); + // This map is map of debug handle-to-DebugInfoTuple + // DebugInfoTuple= + for (const auto& it : map_val) { + auto& source_range = + std::get(it.second); + sr_records.emplace_back( + std::numeric_limits::max(), source_range); + auto cs_ptr = std::get(it.second); + if (cs_ptr) { + for (const auto& e : cs_ptr->vec()) { + const auto sr = std::get(e); + sr_records.emplace_back(std::numeric_limits::max(), sr); + } + } + } + } + } + for (const auto& c : m.children()) { + const auto& child_sr_records = getBackendSourceRanges(c); + sr_records.reserve(sr_records.size() + child_sr_records.size()); + std::move( + child_sr_records.begin(), + child_sr_records.end(), + std::back_inserter(sr_records)); + } + return sr_records; +} + } // namespace void moduleMethodsTuple( const Module& module, std::vector& elements, // note: appended to in-place std::vector& debug_info_elements, - BackendDebugHandleManager& debug_handle_manager) { + BackendDebugInfoRecorder& debug_info_recorder) { auto methods = module.get_methods(); std::unordered_set qn_cache; // top level methods @@ -321,7 +385,7 @@ void moduleMethodsTuple( continue; } auto func_tuple = - getFunctionTuple(module, method.function(), debug_handle_manager); + getFunctionTuple(module, method.function(), debug_info_recorder); elements.push_back(func_tuple.first); qn_cache.emplace(qn); debug_info_elements.push_back(func_tuple.second); @@ -334,7 +398,7 @@ void moduleMethodsTuple( elements, qn_cache, debug_info_elements, - debug_handle_manager); + debug_info_recorder); } void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) { @@ -564,7 +628,7 @@ void ScriptModuleSerializer::writeByteCode( const Module& module, const bool save_mobile_debug_info) { std::vector elements; - BackendDebugHandleManager debug_handle_manager; + BackendDebugInfoRecorder debug_info_recorder; elements.emplace_back( static_cast(caffe2::serialize::kProducedBytecodeVersion)); std::vector debug_info_elements; @@ -573,7 +637,7 @@ void ScriptModuleSerializer::writeByteCode( static_cast(caffe2::serialize::kProducedBytecodeVersion)); moduleMethodsTuple( - module, elements, debug_info_elements, debug_handle_manager); + module, elements, debug_info_elements, debug_info_recorder); auto telements = Tup(std::move(elements)); writeArchive( telements, @@ -605,15 +669,41 @@ void ScriptModuleSerializer::writeByteCode( /*archive_name=*/"mobile_debug_handles", /*archive_dir=*/"", /*tensor_dir=*/"mobile_debug_handles/"); + static constexpr size_t kMinToCompress = 200; + // For delegated backends get source ranges that are in the debug info + // map. Since delegated backend replace original module with lowered + // module we will not serialize original module's code which is what would + // have contained source range. Since we dont have that anymore, extract + // source ranges out of delegated module and store in a separate archive. + // Note that we must do this first because in order to serialize inlined + // CS appropriate source_range_tags must have been generated. + auto backend_source_range_records = getBackendSourceRanges(module); + SourceRangePickler source_range_pickler; + updateSourceRangeTags(backend_source_range_records); + auto range_data = source_range_pickler.pickle( + backend_source_range_records, source_range_tags_); + std::string debugFilename = "delegated_backends.debug_pkl"; + writer_.writeRecord( + debugFilename, + range_data.data(), + range_data.size(), + range_data.size() > kMinToCompress /*compress*/); + + // For delegated backends get debug_info_map + // This is merged with other debug_info_map of other modules + // which were not delegated. + BackendDebugInfoMapType backend_debug_info_map; + getBackendDebugInfoMap(module, backend_debug_info_map); // Now get the debug-handles-to-inlined-cs-ptr-map // And serialize that in a separate archive - auto debug_handle_cs_ptr_map = debug_handle_manager.getCallStackPtrMap(); + auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording(); + debug_handle_cs_ptr_map.insert( + backend_debug_info_map.begin(), backend_debug_info_map.end()); CallStackDebugInfoPickler cs_debug_info_pickler; auto cs_data = cs_debug_info_pickler.pickle( debug_handle_cs_ptr_map, source_range_tags_); // Write out map: [debug-handle, {source range, InlinedCallStack}] std::string filename = "callstack_debug_map.pkl"; - static constexpr size_t kMinToCompress = 200; writer_.writeRecord( filename, cs_data.data(), @@ -726,7 +816,7 @@ namespace { void export_opnames(const script::Module& m, std::set& opnames) { std::vector elements; std::vector debug_info_elements; - BackendDebugHandleManager dummy; + BackendDebugInfoRecorder dummy; moduleMethodsTuple(m, elements, debug_info_elements, dummy); for (const auto& element : elements) { auto table = element.toTuple()->elements()[1]; diff --git a/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp b/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp index eb83ed11a89b22..2446cacf587193 100644 --- a/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp @@ -1233,6 +1233,23 @@ void nnc_aten_silu( } catch (...) { } } +void nnc_aten_mish( + int64_t bufs_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int8_t* buf_dtypes, + int64_t args_num, + int64_t* extra_args) { + std::vector tensors = + constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes); + at::Tensor& r = tensors[0]; + const at::Tensor& self = tensors[1]; + try { + at::mish_out(r, self); + } catch (...) { + } +} void nnc_aten_sigmoid( int64_t bufs_num, void** buf_data, @@ -2913,6 +2930,9 @@ const static RegisterNNCExternalFunction nnc_rsqrt( const static RegisterNNCExternalFunction nnc_silu( "nnc_aten_silu", nnc_aten_silu); +const static RegisterNNCExternalFunction nnc_mish( + "nnc_aten_mish", + nnc_aten_mish); const static RegisterNNCExternalFunction nnc_sigmoid( "nnc_aten_sigmoid", nnc_aten_sigmoid); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 1fe0844a6b9341..2284fc018c4025 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1466,6 +1466,8 @@ Tensor* computeMatmul( auto size_a = a.dims(); auto size_b = b.dims(); + // We currently only support rank 2 matmuls + TORCH_INTERNAL_ASSERT(size_a.size() == 2 && size_b.size() == 2); auto total_size = dynamic_cast( IRSimplifier::simplify( cast(size_a[0]) * cast(size_a[1]) * @@ -1852,6 +1854,21 @@ Tensor* tensorexpr::computeOperandValue( }); } break; + case aten::leaky_relu: { + return computeTwoOperand( + "aten_leaky_relu", + inputs, + outputShape, + outputType, + [](const ExprHandle& a, const ExprHandle& negative_slope) { + auto neg_slope = Cast::make(a.dtype(), negative_slope); + auto zero = Cast::make(a.dtype(), 0); + auto one = Cast::make(a.dtype(), 1); + auto cs = CompareSelect::make(a, zero, one, neg_slope, kGT); + return a * cs; + }); + } break; + case aten::gelu: { return computeOneOperand( "aten_gelu", @@ -2407,7 +2424,8 @@ Tensor* tensorexpr::computeOperandValue( "aten_slice", c10::fmap(outputShape), [&](const std::vector& axes) { - int64_t dim = c10::get(inputs[1]); + int64_t dim = + at::maybe_wrap_dim(c10::get(inputs[1]), axes.size()); ExprHandle start = constant(inputs[2]); ExprHandle stride = constant(inputs[4]); @@ -2477,7 +2495,8 @@ Tensor* tensorexpr::computeOperandValue( std::vector new_axes; assert(permute_dims.size() == axes.size()); for (auto i : permute_dims) { - new_axes.push_back(axes[i]); + auto new_dim = at::maybe_wrap_dim(i, A.ndim()); + new_axes.push_back(axes[new_dim]); } return A.load(new_axes); }); @@ -2561,6 +2580,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::neg: case aten::isnan: case aten::relu: + case aten::leaky_relu: case aten::hardswish: case aten::gelu: case aten::batch_norm: @@ -2779,35 +2799,29 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { int blockSize = getTECudaPointwiseBlockSize(); if (loopLevels == 2) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; const int kDefaultBlockSize = 512; if (blockSize < 0) { blockSize = kDefaultBlockSize; } - l.splitWithMask(flattened, blockSize, &outer, &inner); - l.setGPUBlockIndex(outer, 0); + l.splitWithMask(flattened, blockSize, &inner); + l.setGPUBlockIndex(flattened, 0); l.setGPUThreadIndex(inner, 0); } else if (loopLevels == 3) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* inner1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* inner2; // TODO: change the number of microprocessors const int kDefaultBlockCount = 1280; const int kDefaultBlockSize = 256; blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount; blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize; - l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner); - l.splitWithMask(inner, blockSize, &inner1, &inner2); - l.setGPUBlockIndex(inner1, 0); - l.setGPUThreadIndex(inner2, 0); + l.splitWithMask(flattened, blockCount * blockSize, &inner); + l.splitWithMask(inner, blockSize, &inner1); + l.setGPUBlockIndex(inner, 0); + l.setGPUThreadIndex(inner1, 0); } else { throw std::runtime_error( "Invalid loop-level: " + c10::to_string(loopLevels)); @@ -2830,12 +2844,11 @@ Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) { LoopNest::flatten(loops, &flattened); assert(flattened); - For* outer = nullptr; For* inner = nullptr; - l.splitWithMask(flattened, blockSize, &outer, &inner); - l.setGPUBlockIndex(outer, 0); + l.splitWithMask(flattened, blockSize, &inner); + l.setGPUBlockIndex(flattened, 0); l.setGPUThreadIndex(inner, 0); - l.setBufferMap(outer, block_analysis->getBufferMap()); + l.setBufferMap(flattened, block_analysis->getBufferMap()); } } @@ -2988,7 +3001,7 @@ Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) { break; } default: { - throw unsupported_dtype(); + throw unsupported_dtype(t->repr_str()); break; } } diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index f16fa816ba76e0..7cdad587796802 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1211,26 +1211,22 @@ void LoopNest::vectorizeInnerLoops() { // vectorize inner loops. for (For* loop : innerLoops) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer1; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* split1; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* tail1; static const int kBodyVectorWidth = 8; - splitWithTail(loop, kBodyVectorWidth, &outer1, &split1, &tail1); + splitWithTail(loop, kBodyVectorWidth, &split1, &tail1); vectorize(split1); if (tail1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For* outer2; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* split2; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) For* tail2; static const int kTailVectorWidth = 4; - splitWithTail(tail1, kTailVectorWidth, &outer2, &split2, &tail2); + splitWithTail(tail1, kTailVectorWidth, &split2, &tail2); vectorize(split2); } } @@ -1326,16 +1322,11 @@ void LoopNest::sliceTail(For* f, int factor) { void LoopNest::splitWithTail(For* f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *outer, *inner, *tail; - splitWithTail(f, factor, &outer, &inner, &tail); + For *inner, *tail; + splitWithTail(f, factor, &inner, &tail); } -void LoopNest::splitWithTail( - For* f, - int factor, - For** outer, - For** inner, - For** tail) { +void LoopNest::splitWithTail(For* f, int factor, For** inner, For** tail) { if (!f) { throw malformed_input("splitWithTail attempted on null loop", f); } @@ -1371,16 +1362,6 @@ void LoopNest::splitWithTail( // x -> x.outer * inner.size + x.inner const Expr* combined_index1 = new Add(new Mul(i_outer, factor_expr), i_inner); - Stmt* body_inner = - Substitute(Stmt::clone(f->body()), {{f->var(), combined_index1}}); - - *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); - *outer = - new For(i_outer, new IntImm(0), split_count, *inner, f->loop_options()); - - // TODO: cleanup API for adding/removing statements - p->replace_stmt(f, *outer); - if (tail_is_needed) { const Var* i_tail = new Var(loop_var_name + "_tail", loop_var_dtype); // x -> x.tail + outer.size * inner.size @@ -1391,19 +1372,28 @@ void LoopNest::splitWithTail( Substitute(Stmt::clone(f->body()), {{f->var(), combined_index2}}); *tail = new For(i_tail, new IntImm(0), tail_size, body_tail); - p->insert_stmt_after(*tail, *outer); + p->insert_stmt_after(*tail, f); } else { *tail = nullptr; } + + Stmt* body_inner = Substitute(f->removeBody(), {{f->var(), combined_index1}}); + + *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); + // The input loop `f` will be the outer loop after split. + f->setVar(i_outer); + f->setStart(new IntImm(0)); + f->setStop(split_count); + f->setBody(*inner); } void LoopNest::splitWithMask(For* f, int factor) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - For *outer, *inner; - splitWithMask(f, factor, &outer, &inner); + For* inner; + splitWithMask(f, factor, &inner); } -void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { +void LoopNest::splitWithMask(For* f, int factor, For** inner) { Block* p = dynamic_cast(f->get_parent()); if (!p) { std::cerr << "Parent is not a Block!\n"; @@ -1438,7 +1428,7 @@ void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { // x -> x.outer * inner.size + x.inner const Expr* combined_index = new Add(new Mul(i_outer, factor_expr), i_inner); - Stmt* body_inner = Stmt::clone(f->body()); + Stmt* body_inner = f->removeBody(); // TODO: is it ok that we're doing it eagerly? In the other implementation we // are only materializing predicates at the last, lowering, step. if (tail_is_needed) { @@ -1455,10 +1445,11 @@ void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { body_inner = Substitute(body_inner, {{f->var(), combined_index}}); *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner); - *outer = - new For(i_outer, new IntImm(0), split_count, *inner, f->loop_options()); - - p->replace_stmt(f, *outer); + // The input loop `f` will be the outer loop after split. + f->setVar(i_outer); + f->setStart(new IntImm(0)); + f->setStop(split_count); + f->setBody(*inner); } std::vector LoopNest::distributeLoop( diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 0bf2a5dd14530c..bb5db759162e93 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -113,16 +113,62 @@ class TORCH_API LoopNest { // TODO: Remove this constraint. bool optimizeConditionals(); + // Splits the given loop into 2 nested loops with the given factor as the + // inner loop bound. If the factor does not evenly divide the loop bound, + // then the remainining iterations are extracted into a tail loop that is + // added after the given loop. + // + // For example, consider the following code: + // for (int i = 0; i < 100; ++i) { + // A[i] = + // } + // + // splitWithTail(i, 8, ...) will result in: + // for (int i_outer = 0; i_outer < 12; ++i_outer) { + // for (int i_inner = 0; i_inner < 8; ++i_inner) { + // A[i_outer * 8 + i_inner] = + // } + // } + // for (int i_tail = 0; i_tail < 4; ++i_tail) { + // A[i_tail + 96] = + // } + // + // The given loop will be transformed to the outer loop after splitting. + // So, the pointer to the input loop should be valid after splitting and + // will point to the outer loop. The `inner` and `tail` parameters will be + // set to point to the inner and tail loops that are generated. + static void splitWithTail(For* f, int factor, For** inner, For** tail); + // A convenience wrapper when the caller does not need to access the + // split loops. static void splitWithTail(For* f, int factor); - static void splitWithTail( - For* f, - int factor, - For** outer, - For** inner, - For** tail); + // Splits the given loop into 2 nested loops with the given factor as the + // inner loop bound. If the factor does not evenly divide the loop bound, + // then a conditional is inserted into the body to handle the remaining + // iterations appropriately. + // + // For example, consider the following code: + // for (int i = 0; i < 100; ++i) { + // A[i] = + // } + // + // splitWithMask(i, 8, ...) will result in: + // for (int i_outer = 0; i_outer < 13; ++i_outer) { + // for (int i_inner = 0; i_inner < 8; ++i_inner) { + // if (i_outer * 8 + i_inner < 100) { + // A[i_outer * 8 + i_inner] = + // } + // } + // } + // + // The given loop will be transformed to the outer loop after splitting. + // So, the pointer to the input loop should be valid after splitting and + // will point to the outer loop. The `inner` parameter will be set to point + // to the inner loop that is generated. + static void splitWithMask(For* f, int factor, For** inner); + // A convenience wrapper when the caller does not need to access the + // split loops. static void splitWithMask(For* f, int factor); - static void splitWithMask(For* f, int factor, For** outer, For** inner); // The following methods support loop distribution. // For example, consider the following code. This will be used to diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index 593f07a83618c9..5a6c2460fce53f 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -413,17 +413,17 @@ void initTensorExprBindings(PyObject* module) { .def( "split_with_tail", [](const LoopNest& self, For* f, int factor) { - For *outer = nullptr, *inner = nullptr, *tail = nullptr; - self.splitWithTail(f, factor, &outer, &inner, &tail); - return std::make_tuple(outer, inner, tail); + For *inner = nullptr, *tail = nullptr; + self.splitWithTail(f, factor, &inner, &tail); + return std::make_tuple(inner, tail); }, py::return_value_policy::reference) .def( "split_with_mask", [](const LoopNest& self, For* f, int factor) { - For *outer = nullptr, *inner = nullptr; - self.splitWithMask(f, factor, &outer, &inner); - return std::make_tuple(outer, inner); + For* inner = nullptr; + self.splitWithMask(f, factor, &inner); + return inner; }, py::return_value_policy::reference) .def( diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 7e4b24b9a83277..6d3031b7588ac3 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -39,6 +39,7 @@ static std::unordered_map type_map = { {"Device", ParameterType::DEVICE}, {"Stream", ParameterType::STREAM}, {"std::string", ParameterType::STRING}, + {"c10::string_view", ParameterType::STRING}, {"Dimname", ParameterType::DIMNAME}, {"DimnameList", ParameterType::DIMNAME_LIST}, {"ScalarList", ParameterType::SCALAR_LIST}, @@ -674,7 +675,7 @@ void FunctionParameter::set_default_str(const std::string& str) { } else if (type_ == ParameterType::SCALARTYPE) { if (str == "None") { default_scalartype = at::ScalarType::Undefined; - } else if (str == "torch.int64") { + } else if (str == "torch.int64" || str == "at::kLong") { default_scalartype = at::ScalarType::Long; } else { throw std::runtime_error("invalid default value for ScalarType: " + str); diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index c9d1628e48b742..abd0c167e1bca2 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -197,6 +197,9 @@ struct PythonArgs { inline std::string string(int i); inline std::string stringWithDefault(int i, const std::string& default_str); inline c10::optional stringOptional(int i); + inline c10::string_view stringView(int i); + inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str); + inline c10::optional stringViewOptional(int i); inline PyObject* pyobject(int i); inline int64_t toInt64(int i); inline int64_t toInt64WithDefault(int i, int64_t default_int); @@ -592,6 +595,20 @@ inline c10::optional PythonArgs::stringOptional(int i) { return THPUtils_unpackString(args[i]); } +inline c10::string_view PythonArgs::stringView(int i) { + return stringViewWithDefault(i, signature.params[i].default_string); +} + +inline c10::string_view PythonArgs::stringViewWithDefault(int i, const c10::string_view default_str) { + if (!args[i]) return default_str; + return THPUtils_unpackStringView(args[i]); +} + +inline c10::optional PythonArgs::stringViewOptional(int i) { + if (!args[i]) return c10::nullopt; + return THPUtils_unpackStringView(args[i]); +} + inline int64_t PythonArgs::toInt64(int i) { if (!args[i]) return signature.params[i].default_int; if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h index e3fc5693cb795d..d373c7625f1c7d 100644 --- a/torch/csrc/utils/python_strings.h +++ b/torch/csrc/utils/python_strings.h @@ -35,6 +35,31 @@ inline std::string THPUtils_unpackString(PyObject* obj) { throw std::runtime_error("unpackString: expected bytes or unicode object"); } +// Unpacks PyBytes (PyString) or PyUnicode as c10::string_view +// PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. +// NOTE: If `obj` is destroyed, then the non-owning c10::string_view will +// become invalid. If the string needs to be accessed at any point after +// `obj` is destroyed, then the c10::string_view should be copied into +// a std::string, or another owning object, and kept alive. For an example, +// look at how IValue and autograd nodes handle c10::string_view arguments. +// NOTE: this method requires the GIL +inline c10::string_view THPUtils_unpackStringView(PyObject* obj) { + if (PyBytes_Check(obj)) { + size_t size = PyBytes_GET_SIZE(obj); + return c10::string_view(PyBytes_AS_STRING(obj), size); + } + if (PyUnicode_Check(obj)) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + Py_ssize_t size; + const char* data = PyUnicode_AsUTF8AndSize(obj, &size); + if (!data) { + throw std::runtime_error("error unpacking string as utf-8"); + } + return c10::string_view(data, (size_t)size); + } + throw std::runtime_error("unpackString: expected bytes or unicode object"); +} + inline PyObject* THPUtils_packString(const char* str) { return PyUnicode_FromString(str); } diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 82789c147bfa94..3cb44e6d1ad97f 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -88,7 +88,7 @@ Tensor new_with_tensor(c10::TensorOptions options, at::ScalarType scalar_type, c options = options.dtype(scalar_type); TORCH_CHECK_TYPE(other.options().type_equal(options), "expected ", options, " (got ", other.options(), ")"); - return other.slice(); + return other.alias(); } std::vector compute_sizes(PyObject* seq) { diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 61574877ac1619..85e5f57d78fdd6 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -531,3 +531,21 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: mem = p.usedGpuMemory / (1024 * 1024) lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") return "\n".join(lines) + +def mem_get_info(device: Union[Device, int] = None) -> int: + r"""Returns the global free and total GPU memory occupied for a given + device using cudaMemGetInfo. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + return torch.cuda.cudart().cudaMemGetInfo(device) diff --git a/torch/custom_class.h b/torch/custom_class.h index f5b5b07b729120..cbbbae1a38697f 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -18,38 +18,6 @@ namespace torch { -/// This struct is used to represent default values for arguments -/// when registering methods for custom classes. -/// static auto register_foo = torch::class_("myclasses", "Foo") -/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); -struct arg { - // Static method for representing a default value of None. This is meant to - // be used like so: - // torch::arg("name") = torch::arg::none - // and is identical to: - // torch::arg("name") = IValue() - static c10::IValue none() { - return c10::IValue(); - } - - // Explicit constructor. - explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {} - // Assignment operator. This enables the pybind-like syntax of - // torch::arg("name") = value. - arg& operator=(const c10::IValue& rhs) { - value_ = rhs; - return *this; - } - - // The name of the argument. This is copied to the schema; argument - // names cannot be extracted from the C++ declaration. - std::string name_; - // IValue's default constructor makes it None, which is not distinguishable from - // an actual, user-provided default value that is None. This boolean - // helps distinguish between the two cases. - c10::optional value_; -}; - /// This function is used in conjunction with `class_::def()` to register /// a constructor for a given C++ class type. For example, /// `torch::init()` would register a two-argument constructor @@ -93,7 +61,7 @@ decltype(auto) init(Func&& f) { /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()` /// is registered with a C++ lambda expression. template -class class_ { +class class_ : public ::torch::detail::class_base { static_assert(std::is_base_of::value, "torch::class_ requires T to inherit from CustomClassHolder"); @@ -105,25 +73,8 @@ class class_ { /// see this class exposed as in Python and TorchScript. For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript - explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") { - detail::checkValidIdent(namespaceName, "Namespace name"); - detail::checkValidIdent(className, "Class name"); - qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className; - - classTypePtr = at::ClassType::create( - c10::QualifiedName(qualClassName), - std::weak_ptr(), - /*is_module=*/false, - std::move(doc_string)); - classTypePtr->addAttribute("capsule", at::CapsuleType::get()); - - c10::getCustomClassTypeMap().insert( - {std::type_index(typeid(c10::intrusive_ptr)), classTypePtr}); - c10::getCustomClassTypeMap().insert( - {std::type_index(typeid(c10::tagged_capsule)), classTypePtr}); - - registerCustomClass(classTypePtr); - } + explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") + : class_base(namespaceName, className, std::move(doc_string), typeid(c10::intrusive_ptr), typeid(c10::tagged_capsule)) {} /// def() can be used in conjunction with `torch::init()` to register /// a constructor for a given C++ class type. For example, passing @@ -419,31 +370,15 @@ class class_ { // extracted by inferFunctionSchemaSingleReturn, and so there must be a // torch::arg instance in default_args even for arguments that do not // have an actual default value provided. - TORCH_CHECK( - default_args.size() == 0 || - default_args.size() == schema.arguments().size() - 1, - "Default values must be specified for none or all arguments"); + TORCH_CHECK( + default_args.size() == 0 || + default_args.size() == schema.arguments().size() - 1, + "Default values must be specified for none or all arguments"); // If there are default args, copy the argument names and default values to the // function schema. if (default_args.size() > 0) { - const auto& old_args = schema.arguments(); - std::vector new_args; - new_args.reserve(old_args.size()); - std::vector default_args_v(default_args); - - new_args.emplace_back(old_args[0]); - for (size_t i = 0; i < default_args_v.size(); ++i) { - // Skip self. - auto& arg = old_args[i+1]; - new_args.emplace_back(c10::Argument( - std::move(default_args_v[i].name_), - arg.type(), - arg.N(), - default_args_v[i].value_.has_value() ? std::move(*default_args_v[i].value_) : c10::nullopt)); - } - - schema = schema.cloneWithArguments(new_args); + schema = withNewArguments(schema, default_args); } auto wrapped_func = @@ -467,9 +402,6 @@ class class_ { registerCustomClassMethod(std::move(method)); return method_val; } - - std::string qualClassName; - at::ClassTypePtr classTypePtr; }; /// make_custom_class() is a convenient way to create an instance of a registered diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h index 4d5ed3f3556c88..6984d9f0996226 100644 --- a/torch/custom_class_detail.h +++ b/torch/custom_class_detail.h @@ -7,6 +7,38 @@ namespace torch { +/// This struct is used to represent default values for arguments +/// when registering methods for custom classes. +/// static auto register_foo = torch::class_("myclasses", "Foo") +/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); +struct arg { + // Static method for representing a default value of None. This is meant to + // be used like so: + // torch::arg("name") = torch::arg::none + // and is identical to: + // torch::arg("name") = IValue() + static c10::IValue none() { + return c10::IValue(); + } + + // Explicit constructor. + explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {} + // Assignment operator. This enables the pybind-like syntax of + // torch::arg("name") = value. + arg& operator=(const c10::IValue& rhs) { + value_ = rhs; + return *this; + } + + // The name of the argument. This is copied to the schema; argument + // names cannot be extracted from the C++ declaration. + std::string name_; + // IValue's default constructor makes it None, which is not distinguishable from + // an actual, user-provided default value that is None. This boolean + // helps distinguish between the two cases. + c10::optional value_; +}; + namespace detail { // Argument type utilities @@ -134,6 +166,22 @@ inline void checkValidIdent(const std::string& str, const char *type) { } } +class TORCH_API class_base { + protected: + explicit class_base( + const std::string& namespaceName, + const std::string& className, + std::string doc_string, + const std::type_info& intrusivePtrClassTypeid, + const std::type_info& taggedCapsuleClass); + + static c10::FunctionSchema withNewArguments( + const c10::FunctionSchema& schema, + std::initializer_list default_args); + std::string qualClassName; + at::ClassTypePtr classTypePtr; +}; + } // namespace detail TORCH_API void registerCustomClass(at::ClassTypePtr class_type); diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index c814c590f8426b..ddd4b84a655774 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -53,4 +53,8 @@ def is_available() -> bool: # See the comment in `distributed_c10d.py` above `_backend` on why we expose # this. - from .distributed_c10d import _backend, _all_gather_base + from .distributed_c10d import ( + _backend, + _all_gather_base, + _create_process_group_wrapper + ) diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py new file mode 100644 index 00000000000000..4164c4673d949b --- /dev/null +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -0,0 +1,50 @@ +import torch +from torch.distributed._sharding_spec import ShardingSpec +from .api import ( + ShardedTensor, +) + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None,): + """ + Creates an empty :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group) diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py new file mode 100644 index 00000000000000..89125cd64d285e --- /dev/null +++ b/torch/distributed/_sharded_tensor/api.py @@ -0,0 +1,206 @@ +from typing import List + +import torch +import torch.distributed as dist +from torch.distributed._sharding_spec import ( + ChunkShardingSpec, + ShardMetadata, + ShardingSpec, +) +from torch.distributed._sharding_spec._internals import is_valid_device +from torch.distributed.utils import _parse_remote_device + + +class Shard(object): + """ + Container which holds the data for a shard as a Tensor and also + the associated metadata for that shard. + """ + __slots__ = ['_tensor', '_metadata'] + + def __init__(self, tensor: torch.Tensor, metadata: ShardMetadata): + self._tensor = tensor + self._metadata = metadata + + @property + def tensor(self) -> torch.Tensor: + return self._tensor + + @property + def metadata(self) -> ShardMetadata: + return self._metadata + + +class ShardedTensor(object): + """ + ShardedTensor is an abstraction to represent Tensors that are sharded + across multiple devices and multiple processes. + + ShardedTensor is initialized in an SPMD like fashion where each rank + initializes the ShardedTensor. The ShardedTensor object on each rank + then only stores the local shard for the Tensor and provides global + metadata for all the shards. + + ShardedTensor doesn't provide any Tensor like operations but is a wrapper + providing the Tensor representing the local shard and the global metadata. + Using these, users can build their custom distributed sharded computations + on top of this primitive. The local shards are all initialized using + :meth:`torch.empty`. + + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. If specified the ShardedTensor is only + built on ranks that are part of this process group and the provided ``sharding_spec`` + is applied in the context of this process group. + """ + + def __init__( + self, + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + ): + if layout != torch.strided: + raise ValueError('Only torch.strided layout is currently supported') + + if memory_format != torch.contiguous_format: + raise ValueError('Only torch.contiguous_format memory_format is currently supported') + + self._sharding_spec = sharding_spec + self._dims = list(size) + self._process_group = ( + process_group + if process_group is not None + else torch.distributed.distributed_c10d._get_default_group() + ) + + if torch.distributed.distributed_c10d._rank_not_in_group(self._process_group): + raise ValueError(f'Global rank: {dist.get_rank()} not part of process group') + + self._local_shards: List[Shard] = [] + self._sharding_metadata: List[ShardMetadata] = [] + if isinstance(self._sharding_spec, ChunkShardingSpec): + self._init_chunked( + self._sharding_spec, + self._dims, + dtype, + layout, + requires_grad, + pin_memory, + memory_format, + process_group, + ) + + def _init_chunked( + self, + sharding_spec: ChunkShardingSpec, + dims, + dtype, + layout, + requires_grad, + pin_memory, + memory_format, + process_group, + ): + current_rank = dist.get_rank(process_group) + sharding_dim = sharding_spec.dim + + # Validate the sharding spec. + if not isinstance(sharding_dim, int): + raise ValueError( + f"Sharding dim needs to be an integer, found: {sharding_dim}" + ) + if sharding_dim >= len(dims) or sharding_dim < -len(dims): + raise ValueError(f"Invalid sharding dim: {sharding_dim}") + + dim_size = dims[sharding_dim] + devices = sharding_spec.placements + chunks = len(devices) + # split_size computed similar to 'torch.chunk' + split_size = (dim_size + chunks - 1) // chunks + + for idx, device in enumerate(devices): + if not is_valid_device(device): + raise ValueError(f"{device} is not a valid device") + + rank, local_device = _parse_remote_device(device) # type: ignore[arg-type] + + # Validate rank. + if not isinstance(rank, int) or (rank < 0 or rank >= dist.get_world_size(process_group)): + raise ValueError(f'Invalid rank: {rank}') + + # Adjust the sharding dim for this rank. + sharded_dim_size = min(dim_size, split_size * (idx + 1)) - split_size * idx + + if sharded_dim_size > 0: + # Build sharding_metadata. + + # deepcopy for modification. + rank_dims = dims.copy() + + rank_offsets = [0] * len(dims) + rank_offsets[sharding_dim] = split_size * idx + rank_dims[sharding_dim] = sharded_dim_size + + shard_metadata = ShardMetadata(rank_offsets, rank_dims, device) + self._sharding_metadata.append(shard_metadata) + + # Build the local shard for the current rank if it is involved in the sharding spec. + if current_rank == rank: + # Initialize the local shard. + local_shard = torch.empty( + *rank_dims, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, + ) + + self._local_shards.append(Shard(local_shard, shard_metadata)) + + def sharding_spec(self) -> ShardingSpec: + """ + Returns the ShardingSpec for the tensor. + """ + return self._sharding_spec + + def __torch_function__(self, func, types, args=(), kwargs=None): + raise RuntimeError(f"torch function '{func.__name__}' not supported for ShardedTensor!") + + def sharding_metadata(self) -> List[ShardMetadata]: + """ + Returns a list of :class:`ShardeMetadata` objects corresponding to the + metadata for each shard. + """ + return self._sharding_metadata + + def local_shards(self) -> List[Shard]: + """ + Returns a list of :class:`Shard' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py new file mode 100644 index 00000000000000..f25c849559d7b9 --- /dev/null +++ b/torch/distributed/_sharding_spec/__init__.py @@ -0,0 +1,8 @@ +from .api import ( + ChunkShardingSpec, + DevicePlacementSpec, + EnumerableShardingSpec, + PlacementSpec, + ShardMetadata, + ShardingSpec, +) diff --git a/torch/distributed/_sharding_spec/_internals.py b/torch/distributed/_sharding_spec/_internals.py new file mode 100644 index 00000000000000..86dc7087477ce3 --- /dev/null +++ b/torch/distributed/_sharding_spec/_internals.py @@ -0,0 +1,22 @@ +import torch +from torch.distributed.utils import _parse_remote_device + +def is_valid_device(device): + """ + Checks if this is a valid local/remote device. + """ + # Check for torch.device + try: + torch.device(device) + return True + except Exception: + pass + + # Check for remote device. + try: + _parse_remote_device(device) + return True + except Exception: + pass + + return False diff --git a/torch/distributed/_sharding_spec/api.py b/torch/distributed/_sharding_spec/api.py new file mode 100644 index 00000000000000..1bc1219d7d7a6b --- /dev/null +++ b/torch/distributed/_sharding_spec/api.py @@ -0,0 +1,292 @@ +from abc import ABC +import torch +from typing import List, Union + +from ._internals import is_valid_device + +Device = Union[torch.device, int, str] + +class PlacementSpec(ABC): + """ + Base class representing the placement of an entity. Subclasses of this + class can be used to specify customized placements which might not be + covered by existing APIs. + """ + pass + +class DevicePlacementSpec(PlacementSpec): + """ + Associates placement of an entity with a single device. The device can be a + local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + Args: + device(str, :class:`torch.device`): The device to place the entity on. + """ + def __init__(self, device: Device): + super(DevicePlacementSpec, self).__init__() + if not is_valid_device(device): + raise ValueError(f'{device} is not a valid device') + self._device = device + + @property + def device(self) -> Device: + """ + Retrieves the device for placement. + """ + return self._device + + +class ShardingSpec(PlacementSpec): + """ + Base class representing sharding specifications. It is special type of + PlacementSpec. + """ + pass + + +class ChunkShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that defines the placement as being sharded + across multiple devices. In particular, it represents sharding a Tensor + along a single dimension into equal chunks (similar to :meth:`torch.chunk`). + + The semantics of how a tensor is partitioned is inline with + :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the + specified ``dim`` and ``chunks`` in torch.chunk is the number of elements + in the placement specified. + + Args: + dim (int or str): + The dimension to shard on, could be an integer representing the + dimension or a string in case of named tensors where dimensions are + named. + placement(List[Device] or List[PlacementSpec]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This + parameter can be a list of devices + (ex: ["rank:0/cuda:0", "rank:1/cuda:1"]) or a list of custom + placement specs. + + The device can be a local device or a remote device specified by one + of the following remote formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + """ + + ShardingDim = Union[int, str] + ShardPlacements = List[Union[Device, PlacementSpec]] + + def __init__(self, dim: ShardingDim, placements: ShardPlacements): + super(ChunkShardingSpec, self).__init__() + self._verify_dim(dim) + self._verify_devices(placements) + self._dim = dim + self._placements = placements + + @staticmethod + def _verify_devices(placements): + if placements is None or len(placements) == 0: + raise ValueError(f'None/Empty placement provided: {placements}') + for dev in placements: + if not isinstance(dev, PlacementSpec) and not is_valid_device(dev): + raise ValueError(f'{dev} is not a valid device') + + @staticmethod + def _verify_dim(dim): + if not (isinstance(dim, int) or isinstance(dim, str)): + raise ValueError(f'{dim} needs to either be an int or str') + + @property + def dim(self) -> ShardingDim: + """ + Retrieves the dimension to shard on. + """ + return self._dim + + @property + def placements(self) -> ShardPlacements: + """ + Retrieves the shard placements. + """ + return self._placements + +class ShardMetadata(object): + """ + Represents a shard of the overall Tensor including its + offsets, lengths and device placement. + + Args: + shard_offsets(List[int]): Offsets in the orignal tensor indicating + the start offsets for this shard. Should have the same rank as + the original tensor. + shard_lengths(List[int]): Lengths indicating the length of each + dimension for this shard. Should have the same rank as the + original tensor. + placement(List[Device or PlacementSpec]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This + parameter can be a list of devices + (ex: ["rank:0/cuda:0", "rank:1/cuda:1"]) or a list of custom + placement specs. + + The device can be a local device or a remote device specified by one + of the following remote formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + """ + + ShardPlacement = Union[Device, PlacementSpec] + + __slots__ = ['_shard_offsets', '_shard_lengths', '_placement'] + + def __init__( + self, + shard_offsets: List[int], + shard_lengths: List[int], + placement: ShardPlacement): + + if not isinstance(placement, PlacementSpec) and not is_valid_device(placement): + raise ValueError(f'{placement} is not a valid device') + + if len(shard_offsets) != len(shard_lengths): + raise ValueError( + f'shard_offsets and shard_lengths should have ' + f'the same number of elements, found {len(shard_offsets)} ' + f'and {shard_lengths} respectively') + + for i in range(len(shard_offsets)): + if shard_offsets[i] < 0: + raise ValueError('shard_offsets should be >=0') + if shard_lengths[i] <= 0: + raise ValueError('shard_lengths should be > 0') + + self._shard_offsets = shard_offsets + self._shard_lengths = shard_lengths + self._placement = placement + + def __repr__(self): + return ( + f'ShardMetadata(shard_offsets: {self._shard_offsets}, ' + f'shard_lengths: {self._shard_lengths}, placement: {self._placement})' + ) + + @property + def shard_offsets(self): + return self._shard_offsets + + @property + def shard_lengths(self): + return self._shard_lengths + + @property + def placement(self): + return self._placement + + +class EnumerableShardingSpec(ShardingSpec): + + def __init__(self, shards: List[ShardMetadata]): + """ + This is a type of PlacementSpec that allows users to specify a generic + sharding scheme by enumerating exactly how each shard is laid out. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. + """ + super(EnumerableShardingSpec, self).__init__() + if len(shards) == 0: + raise ValueError(f'Empty shard list provided: {shards}') + + # Validate each shard has same rank. + rank = -1 + for shard in shards: + if rank != -1 and rank != len(shard.shard_offsets): + raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}') + rank = len(shard.shard_offsets) + + self._validate_non_overlapping(shards) + + self._shards = shards + + @staticmethod + def _validate_non_overlapping(shards: List[ShardMetadata]): + """ + Ensures none of the shards overlap with each other. + """ + # TODO: evaluate optimizing this if needed. + for i in range(len(shards)): + for j in range(i + 1, len(shards)): + if EnumerableShardingSpec._check_shard_pair_overlap(shards[i], shards[j]): + raise ValueError(f'Shards {shards[i]} and {shards[j]} overlap') + + @staticmethod + def _check_shard_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata): + """ + Checks if two shards overlap. + """ + + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.shard_offsets) + for i in range(ndims): + if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_lengths[i]: + return False + if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_lengths[i]: + return False + + return True + + @property + def shards(self): + return self._shards + + def check_tensor(self, tensor: torch.Tensor) -> None: + """ + Checks if the sharding spec is compatible with the provided tensor. + + Args: + tensor(torch.Tensor): Tensor to verify. + Raises: + ``ValueError`` if not compatible. + """ + + # If the tensor's volume matches the total volume of all shards and + # all shard boundaries are within tensor dims, we have a compatible + # sharding spec for this tensor. Note that we have already verified + # we don't have overlapping shards. + tensor_rank = len(tensor.size()) + shards_rank = len(self._shards[0].shard_offsets) + if tensor_rank != shards_rank: + raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}') + + total_shard_volume = 0 + tensor_dims = tensor.size() + for shard in self._shards: + shard_volume = 1 + for i, shard_length in enumerate(shard.shard_lengths): + shard_volume *= shard_length + if shard.shard_offsets[i] + shard.shard_lengths[i] > tensor_dims[i]: + raise ValueError( + f'Shard offset {shard.shard_offsets[i]} and length' + f'{shard.shard_lengths[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}') + total_shard_volume += shard_volume + + tensor_volume = 1 + for size in tensor_dims: + tensor_volume *= size + + if total_shard_volume != tensor_volume: + # TODO: Can we improve this error message to point out the gaps? + raise ValueError( + f'Total volume of shards: {total_shard_volume}' + f'does not match tensor volume: {tensor_volume}, in other words' + f' all the individual shards do not cover the entire tensor') diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c28664cf72abfb..c8e70fc14a46a7 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -30,6 +30,7 @@ from .constants import default_pg_timeout from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401 +from torch._C._distributed_c10d import _get_debug_mode, _DistributedDebugLevel _MPI_AVAILABLE = True _NCCL_AVAILABLE = True @@ -50,12 +51,15 @@ try: from torch._C._distributed_c10d import ProcessGroupGloo + from torch._C._distributed_c10d import _ProcessGroupWrapper except ImportError: _GLOO_AVAILABLE = False logger = logging.getLogger(__name__) +PG_WRAPPER_STORE_PREFIX = "pg_wrapper" + # Some reduce ops are not supported by complex numbers and will result in an error. # We currently provide complex support to the distributed API by viewing @@ -657,7 +661,28 @@ def _new_process_group_helper( if backend == Backend.GLOO: if pg_options is not None: raise RuntimeError("GLOO options not supported") - pg = ProcessGroupGloo(prefix_store, rank, world_size, timeout=timeout) + pg = ProcessGroupGloo( + prefix_store, + rank, + world_size, + timeout=timeout) + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debugability. + if _get_debug_mode() == _DistributedDebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info("""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. Build with Gloo to + create a wrapper process group in debug mode + to aid collective desynchronization debugging.""") + else: + pg = _create_process_group_wrapper( + wrapped_pg=pg, + store_prefix=group_name, + store=store, + rank=rank, + world_size=world_size, + timeout=timeout + ) _pg_map[pg] = (Backend.GLOO, store) _pg_names[pg] = group_name elif backend == Backend.NCCL: @@ -673,7 +698,28 @@ def _new_process_group_helper( pg_options.is_high_priority_stream = False pg_options._timeout = timeout - pg = ProcessGroupNCCL(prefix_store, rank, world_size, pg_options) + pg = ProcessGroupNCCL( + prefix_store, + rank, + world_size, + pg_options) + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debugability. + if _get_debug_mode() == _DistributedDebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info("""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. Build with Gloo to + create a wrapper process group in debug mode + to aid collective desynchronization debugging.""") + else: + pg = _create_process_group_wrapper( + wrapped_pg=pg, + store_prefix=group_name, + store=store, + rank=rank, + world_size=world_size, + timeout=timeout + ) _pg_map[pg] = (Backend.NCCL, store) _pg_names[pg] = group_name else: @@ -2538,7 +2584,7 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False): """ - Synchronizes all processes similar to torch.distributed.barrier, but takes + Synchronizes all processes similar to ``torch.distributed.barrier``, but takes a configurable timeout and is able to report ranks that did not pass this barrier within that timeout. Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. Rank 0 will block until all send @@ -2550,7 +2596,7 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals This collective will block all processes/ranks in the group, until the whole group exits the function successfully, making it useful for debugging and synchronizing. However, it can have a performance impact and should only - be used for debugging or scenarios that require full synhcronization points + be used for debugging or scenarios that require full synchronization points on the host-side. For debugging purposees, this barrier can be inserted before the application's collective calls to check if any ranks are desynchronized. @@ -2599,6 +2645,26 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals group_to_use = _get_default_group() if group is None else group return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks) +def _create_process_group_wrapper( + wrapped_pg: ProcessGroup, + store_prefix: str, + store: Store, + rank: int, + world_size: int, + timeout: timedelta = default_pg_timeout +): + # Create a separate prefix store for the helper process group. + prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" + store = PrefixStore(prefix, store) + helper_pg = ProcessGroupGloo( + store, + rank, + world_size, + timeout=timeout + ) + # Wrap the underlying pg with ProcessGroupWrapper. + wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) + return wrapped_pg def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None): """ diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 73bded145a89a1..87761fca287a7e 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -11,6 +11,7 @@ Optional, Set, Tuple, + Type, TypeVar, Union, ) @@ -20,7 +21,7 @@ from torch import Tensor, device, dtype, nn from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.internal import _internal_rpc_pickler -from torch.distributed.rpc.utils import _parse_remote_device +from torch.distributed.utils import _parse_remote_device from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle @@ -110,7 +111,7 @@ class _RemoteModule(nn.Module): def __init__( self, remote_device: str, - module_cls: nn.Module, + module_cls: Type[nn.Module], args: Tuple = None, kwargs: Dict[str, Any] = None, _module_interface_cls: Any = None, @@ -154,8 +155,12 @@ def __init__( Args: remote_device (str): Device on the destination worker where we'd like to place this module. - The format should be "/", where the device field can be parsed as torch.device type. - E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): diff --git a/torch/distributed/rpc/utils.py b/torch/distributed/rpc/utils.py deleted file mode 100644 index afdde21f3c5690..00000000000000 --- a/torch/distributed/rpc/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -def _parse_remote_device(remote_device: str): - r""" - Parses the remote device. - - Args: - remote_device (str): Device on the destination worker where we'd like to place this module. - The format should be "/", where the device field can be parsed as torch.device type. - E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". - In addition, the device field can be optional and the default value is "cpu". - - Returns: - A workername and a device. - """ - fields = remote_device.split("/") - if len(fields) == 2: - [on, device] = fields - elif len(fields) == 1: - on = fields[0] - device = "cpu" - else: - raise RuntimeError( - "Could not parse remote_device: {}. The valid format is '/'".format( - remote_device - ) - ) - - # Since the workername in the input remote device won't be validated until the created remote module is executed, - # only do some very basic sanity check on workername at the module creation time. - # As currently there is no regex to describe the format of workername, just check whether the workername is empty. - if not on: - raise RuntimeError( - "The workername in remote_device '{}' cannot be empty. The valid format is '/'".format( - remote_device - ) - ) - - return on, device diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py new file mode 100644 index 00000000000000..323b3608023abd --- /dev/null +++ b/torch/distributed/utils.py @@ -0,0 +1,56 @@ +import torch + +def _parse_remote_device(remote_device: str): + r""" + Parses the remote device. + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The format should be one of the following: + + 1. "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + 2. "rank:/", where is the rank of the + process and device can be parsed as torch.device type. + E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0" + + Returns: + A workername/rank and a device. + """ + + PARSE_ERROR = ( + f"Could not parse remote_device: {remote_device}. The valid format is " + "'/' or 'rank:/'" + ) + + fields = remote_device.split("/") + if len(fields) == 2: + [on, device] = fields + elif len(fields) == 1: + on = fields[0] + device = "cpu" + else: + raise ValueError(PARSE_ERROR) + + # Since the workername in the input remote device won't be validated until the created remote module is executed, + # only do some very basic sanity check on workername at the module creation time. + # As currently there is no regex to describe the format of workername, just check whether the workername is empty. + if not on: + raise ValueError(PARSE_ERROR) + + # Validate the device. + torch.device(device) + + # Check for rank based format + fields = on.split(':') + if len(fields) == 2: + # rank:/device format, extract rank + if fields[0] == 'rank' and fields[1].isdigit(): + on = int(fields[1]) # type: ignore[assignment] + else: + raise ValueError(PARSE_ERROR) + elif len(fields) > 2: + raise ValueError(PARSE_ERROR) + + return on, device diff --git a/torch/functional.py b/torch/functional.py index dc92caa60add76..0ca423e717b4cc 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -155,7 +155,7 @@ def split(tensor, split_size_or_sections, dim=0): return tensor.split(split_size_or_sections, dim) -def einsum(equation, *operands): +def einsum(*args): r"""einsum(equation, *operands) -> Tensor Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation @@ -171,7 +171,7 @@ def einsum(equation, *operands): Equation: - The :attr:`equation` string specifies the subscripts (lower case letters `['a', 'z']`) for each dimension of + The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is @@ -211,7 +211,7 @@ def einsum(equation, *operands): Args: equation (string): The subscripts for the Einstein summation. - operands (Tensor): The operands to compute the Einstein sum of. + operands (List[Tensor]): The tensors to compute the Einstein summation of. Examples:: @@ -259,8 +259,43 @@ def einsum(equation, *operands): tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]]) """ + if len(args) < 2: + raise ValueError('einsum(): must specify the equation string and at least one operand, ' + 'or at least one operand and its subscripts list') + + equation = None + operands = None + + if isinstance(args[0], torch.Tensor): + # Convert the subscript list format which is an interleaving of operand and its subscripts + # list with an optional output subscripts list at the end (see documentation for more details on this) + # to the equation string format by creating the equation string from the subscripts list and grouping the + # input operands into a tensorlist (List[Tensor]). + def parse_subscript(n: int) -> str: + if n == Ellipsis: + return '...' + if n >= 0 and n < 26: + return chr(n + ord('a')) + if n >= 26 and n < 52: + return chr(n - 26 + ord('A')) + raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)') + + # Parse subscripts for input operands + equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2]) + + # Parse optional output subscripts (provided when the number of arguments is odd) + if len(args) % 2 == 1: + equation += '->' + ''.join(parse_subscript(s) for s in args[-1]) + operands = args[:-1:2] + else: + operands = args[::2] + else: + equation = args[0] + operands = args[1:] + if has_torch_function(operands): return handle_torch_function(einsum, operands, equation, *operands) + if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument _operands = operands[0] diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 92c63965cdb0be..5594b43d0bf7cb 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -122,6 +122,7 @@ def create_name(self, candidate: str, obj: Optional[Any]) -> str: # delete all characters that are illegal in a Python identifier candidate = self._illegal_char_regex.sub('_', candidate) + if candidate[0].isdigit(): candidate = f'_{candidate}' @@ -826,6 +827,7 @@ def add_global(name_hint: str, obj: Any): # normalize the name hint to get a proper identifier global_name = namespace.create_name(name_hint, obj) + if global_name in globals_: assert globals_[global_name] is obj return global_name @@ -837,6 +839,10 @@ def add_global(name_hint: str, obj: Any): add_global(name, obj) def type_repr(o : Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + typename = _type_repr(o) # This is a generic type, e.g. typing.List[torch.Tensor] @@ -846,6 +852,7 @@ def type_repr(o : Any): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] + return f'{origin_typename}[{",".join(args)}]' # Common case: this is a regular module name like 'foo.bar.baz' diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 201df4743606b7..43ee0a36754cb4 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -52,7 +52,6 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: return _custom_builtins[name].import_str if _is_from_torch(name): return 'import torch' - module_name, attr_name = importer.get_name(obj) return f'from {module_name} import {attr_name} as {name}' diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 6b83eb2a0787a9..014da33d89e3f9 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import hashlib import torch import torch.fx import pydot @@ -9,13 +10,30 @@ _COLOR_MAP = { "placeholder": '"AliceBlue"', "call_module": "LemonChiffon1", - "call_function": "PeachPuff1", "get_param": "Yellow2", "get_attr": "LightGrey", - "call_method": "LavenderBlush1", "output": "PowderBlue", } +_HASH_COLOR_MAP = [ + "CadetBlue1", + "Coral", + "DarkOliveGreen1", + "DarkSeaGreen1", + "GhostWhite", + "Khaki1", + "LavenderBlush1", + "LightSkyBlue", + "MistyRose1", + "MistyRose2", + "PaleTurquoise2", + "PeachPuff1", + "Salmon", + "Thistle1", + "Thistle3", + "Wheat1", +] + _WEIGHT_TEMPLATE = { "shape": "record", "fillcolor": "Salmon", @@ -64,7 +82,13 @@ def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: "style": '"filled,rounded"', "fontcolor": "#000000", } - template["fillcolor"] = _COLOR_MAP[node.op] + if node.op in _COLOR_MAP: + template["fillcolor"] = _COLOR_MAP[node.op] + else: + # Use a random color for each node; based on its name so it's stable. + target_name = node._pretty_print_target(node.target) + target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) + template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] return template def _get_leaf_node( @@ -106,16 +130,18 @@ def _get_node_label(self, module: torch.fx.GraphModule, node: torch.fx.Node) -> label += "|" + self._typename(node.target) + r"\l" tensor_meta = node.meta.get('tensor_meta') - if tensor_meta: - dtype_ = tensor_meta.dtype if hasattr(tensor_meta, "dtype") else "none" - shape_ = tensor_meta.shape if hasattr(tensor_meta, "shape") else "none" - stride_ = tensor_meta.stride if hasattr(tensor_meta, "stride") else "none" - if dtype_: - label += "|" + "dtype" + "=" + str(dtype_) + r"\l" - if shape_: - label += "|" + "shape" + "=" + str(shape_) + r"\l" - if stride_: - label += "|" + "stride" + "=" + str(stride_) + r"\l" + if tensor_meta is not None: + label += "|" + "dtype" + "=" + str(tensor_meta.dtype) + r"\l" + label += "|" + "shape" + "=" + str(tuple(tensor_meta.shape)) + r"\l" + label += "|" + "stride" + "=" + str(tensor_meta.stride) + r"\l" + if tensor_meta.is_quantized: + if tensor_meta.qscheme in { + torch.per_tensor_affine, + torch.per_tensor_symmetric, + }: + label += "|" + "q_scale" + "=" + str(tensor_meta.q_scale) + r"\l" + label += "|" + "q_zero_point" + "=" + str(tensor_meta.q_zero_point) + r"\l" + label += "|" + "qscheme" + "=" + str(tensor_meta.qscheme) + r"\l" return label + "}" diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index bb5a45ea41acdd..cd260c819884a1 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -29,6 +29,7 @@ interface, CompilationUnit, ScriptFunction, + _ScriptProfile, _unwrap_optional, ) from torch.jit._trace import ( diff --git a/torch/jit/_script.py b/torch/jit/_script.py index f50215203443f6..3d4e8ed86e9364 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -43,6 +43,7 @@ JitTypeTraceConfig , JitTypeTraceStore ) +from torch._classes import classes type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType @@ -1262,6 +1263,111 @@ def _recursive_compile_class(obj, loc): CompilationUnit = torch._C.CompilationUnit set_module(CompilationUnit, "torch.jit") + +def pad(s: str, padding: int, offset: int = 0, char: str = ' '): + if padding >= len(s): + padding -= len(s) + return ''.join([char for _ in range(padding + offset)]) + s + + +class _ScriptProfileColumn: + def __init__(self, header: str, alignment: int = 4, offset: int = 0): + self.header = header + self.alignment = alignment + self.offset = offset + self.rows: Dict[int, Any] = {} + + def add_row(self, lineno: int, value: Any): + self.rows[lineno] = value + + def materialize(self): + max_length = len(self.header) + rows: List[Tuple[int, str]] = [] + for (key, value) in self.rows.items(): + cell = str(value) + rows.append((key, cell)) + max_length = max(len(cell), max_length) + + if self.alignment > 0: + padding = max_length + self.alignment + padding -= padding % self.alignment + else: + padding = 0 + + rows = [(key, pad(cell, padding, self.offset)) for key, cell in rows] + return pad(self.header, padding, self.offset), rows + + +class _ScriptProfileTable: + def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]): + self.cols = cols + self.source_range = source_range + + def dump_string(self): + outputs: List[str] = [] + cells: List[Tuple[str, Dict[int, str]]] = [] + header_buffer = '' + for col in self.cols: + header, rows = col.materialize() + header_buffer += header + cells.append((header, dict(rows))) + + outputs.append(header_buffer) + outputs.append(pad('', len(header_buffer), 0, '=')) + for line in self.source_range: + row_buffer = '' + for header, rows in cells: + cell = rows.get(line) + if cell is None: + row_buffer += pad('', len(header)) + else: + row_buffer += cell + outputs.append(row_buffer) + return '\n'.join(outputs) + + +class _ScriptProfile: + def __init__(self): + self.profile = classes.profiling._ScriptProfile() + + def enable(self): + self.profile.enable() + + def disable(self): + self.profile.disable() + + def dump_string(self) -> str: + outputs: List[str] = [] + for source_stats in self.profile._dump_stats(): + source_ref = source_stats.source() + source_lines = source_ref.text().splitlines() + dedent = min([len(line) - len(line.lstrip(' ')) for line in source_lines]) + source_lines = [line[dedent:] for line in source_lines] + + start_line = source_ref.starting_lineno() + end_line = start_line + len(source_lines) + source_range = range(start_line, end_line) + lineno = _ScriptProfileColumn("Line #") + hits = _ScriptProfileColumn("Hits") + time_ns = _ScriptProfileColumn("Time (ns)") + line_contents = _ScriptProfileColumn("Line Contents", 0, 1) + stats = source_stats.line_map() + for line in source_range: + lineno.add_row(line, line) + line_contents.add_row(line, source_lines[line - start_line]) + stat = stats.get(line) + if stat is not None: + hits.add_row(line, stat.count()) + time_ns.add_row(line, stat.duration_ns()) + + table = _ScriptProfileTable([lineno, hits, time_ns, line_contents], list(source_range)) + outputs.append(table.dump_string()) + return '\n\n'.join(outputs) + + def dump(self): + print(self.dump_string()) + + def _unwrap_optional(x): assert x is not None, "Unwrapping null optional" return x diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index cc48b13b2795ce..ced99581b6f357 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -300,6 +300,9 @@ def try_ann_to_type(ann, loc): if inspect.isclass(ann) and is_tensor(ann): return TensorType.get() if is_tuple(ann): + # Special case for the empty Tuple type annotation `Tuple[()]` + if len(ann.__args__) == 1 and ann.__args__[0] == (): + return TupleType([]) return TupleType([try_ann_to_type(a, loc) for a in ann.__args__]) if is_list(ann): elem_type = try_ann_to_type(ann.__args__[0], loc) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 49d2bfad24f3f5..a78f8ab507fd1b 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -19,7 +19,7 @@ ) from torch._utils_internal import get_source_lines_and_file from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name -from torch._jit_internal import SourceContext, should_drop, is_static_fn, FunctionModifiers # noqa: F401 +from torch._jit_internal import make_source_context, should_drop, is_static_fn, FunctionModifiers # noqa: F401 import torch.jit.annotations _IS_ASTUNPARSE_INSTALLED = False @@ -207,7 +207,7 @@ def is_classmethod(fn): dedent_src = dedent(source) py_ast = ast.parse(dedent_src) leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) - ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, False) + ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False) class_ast = py_ast.body[0] assert isinstance(class_ast, ast.ClassDef) assigns = get_class_assigns(ctx, class_ast) @@ -275,7 +275,7 @@ def _forward(self): raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) type_line = torch.jit.annotations.get_type_line(source) - ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True) + ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True) fn_def = py_ast.body[0] if is_classmethod: @@ -386,7 +386,8 @@ def build_param_list(ctx, py_args, self_name, pdt_arg_types=None): arg_and_types_kwonlyargs = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) for arg in py_args.kwonlyargs] - result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type) for arg, arg_type in arg_and_types] + result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type) + for arg, arg_type in arg_and_types] result += [build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=arg_type) for arg, arg_type in arg_and_types_kwonlyargs] return result @@ -892,7 +893,6 @@ def build_ExtSlice(ctx, base, extslice): "slicing multiple dimensions with " "{} not supported".format(sub_type)) return sub_exprs - base = build_expr(ctx, expr.value) sub_type = type(expr.slice) if sub_type is ast.Index: @@ -900,6 +900,15 @@ def build_ExtSlice(ctx, base, extslice): # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] # XXX: Indexing using a list is **different**! It triggers advanced indexing. indices = [build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts] + if not indices: + # `col_offset` is an int, but `end_col_offset` is + # `Optional[int]`. The magic number is here to make + # sure we can parse `()` on any machine + r = ctx.make_range(expr.lineno, + expr.slice.value.col_offset, + expr.slice.value.col_offset + 2) + tup = TupleLiteral(r, []) + indices.append(tup) return Subscript(base, indices) else: return Subscript(base, [build_expr(ctx, expr.slice.value)]) @@ -916,6 +925,14 @@ def build_ExtSlice(ctx, base, extslice): indices.append(build_SliceExpr(ctx, base, index_expr)) else: indices.append(build_expr(ctx, index_expr)) + # Special-case logic for `typing.Tuple[()]` + if not indices: + # See note above r.e. magic number + r = ctx.make_range(expr.lineno, + expr.slice.col_offset, + expr.slice.col_offset + 2) + tup = TupleLiteral(r, []) + indices.append(tup) return Subscript(base, indices) return Subscript(base, [build_expr(ctx, expr.slice)]) else: # Ellipsis (can only happen in Python 2) diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py index 6aba4852c64e53..ef77e78c8db2bf 100644 --- a/torch/jit/mobile/__init__.py +++ b/torch/jit/mobile/__init__.py @@ -145,3 +145,43 @@ def _backport_for_mobile_to_buffer(f_input, to_version): return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version) # type: ignore[attr-defined] else: return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version) # type: ignore[attr-defined] + +def _get_model_ops_and_info(f_input): + r""" + A function to retrieve the root (top level) operators of a model and their corresponding + compatibility info. These root operators can call other operators within them (traced ops), and + a root op can call many different traced ops depending on internal code paths in the root op. + These traced ops are not returned by this function. Those operators are abstracted into the + runtime as an implementation detail (and the traced ops themselves can also call other operators) + making retrieving them difficult and their value from this api negligible since they will differ + between which runtime version the model is run on. Because of this, there is a false positive this + api can't prevent in a compatibility usecase. All the root ops of a model are present in a + target runtime, but not all the traced ops are which prevents a model from being able to run. + Args: + f_input: a file-like object (has to implement read, readline, tell, and seek), + or a string containing a file name + + Returns: + Operators and info: A Dictionary mapping strings (the qualified names of the root operators) + of the model to their OperatorInfo structs. + + Example: + + .. testcode:: + + from torch.jit.mobile import _get_model_ops_and_info + + # Get bytecode version from a saved file path + ops_and_info = _get_model_ops_and_info("path/to/model.ptl") + + """ + if isinstance(f_input, str): + if not os.path.exists(f_input): + raise ValueError(f"The provided filename {f_input} does not exist") + if os.path.isdir(f_input): + raise ValueError(f"The provided filename {f_input} is a directory") + + if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)): + return torch._C._get_model_ops_and_info(str(f_input)) # type: ignore[attr-defined] + else: + return torch._C._get_model_ops_and_info(f_input.read()) # type: ignore[attr-defined] diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt index 3c11526f0c3b03..4ee7834a2fa8d8 100644 --- a/torch/lib/c10d/CMakeLists.txt +++ b/torch/lib/c10d/CMakeLists.txt @@ -74,7 +74,7 @@ if(USE_C10D_MPI) endif() if(USE_C10D_GLOO) - list(APPEND C10D_SRCS ProcessGroupGloo.cpp GlooDeviceFactory.cpp) + list(APPEND C10D_SRCS ProcessGroupGloo.cpp GlooDeviceFactory.cpp ProcessGroupWrapper.cpp) list(APPEND C10D_LIBS gloo) if(USE_CUDA) list(APPEND C10D_LIBS gloo_cuda) @@ -137,6 +137,7 @@ copy_header(sequence_num.hpp) if(USE_GLOO) copy_header(ProcessGroupGloo.hpp) copy_header(GlooDeviceFactory.hpp) + copy_header(ProcessGroupWrapper.hpp) endif() if(NOT WIN32) copy_header(HashStore.hpp) diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index d688830e09d466..68b0bd6ddc009d 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -2833,7 +2833,10 @@ void ProcessGroupGloo::monitoredBarrier( } // If we are collecting all failed ranks, check if we need to throw if // some ranks have not responded. - if (waitAllRanks && processedRanks.size() != size_) { + // Ensure all ranks from 1, ... WORLD_SIZE -1 have been successfully + // processed. + auto rankFailure = (processedRanks.size() != size_ - 1); + if (waitAllRanks && rankFailure) { std::vector failedRanks; for (int i = 1; i < size_; ++i) { if (std::find(processedRanks.begin(), processedRanks.end(), i) == @@ -2889,10 +2892,9 @@ void ProcessGroupGloo::setSequenceNumberForGroup() { } uint64_t ProcessGroupGloo::getSequenceNumberForGroup() { - TORCH_CHECK( - sequenceNum_ != c10::nullopt, - "Sequence number is not set for rank ", - rank_); + if (sequenceNum_ == c10::nullopt) { + return 0; + } return sequenceNum_->get(); } diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 1de1305e6c00f0..bbffe6737c9b6e 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -505,10 +505,9 @@ void ProcessGroupNCCL::setSequenceNumberForGroup() { } uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { - TORCH_CHECK( - sequenceNum_ != c10::nullopt, - "Sequence number is not set for rank ", - rank_); + if (sequenceNum_ == c10::nullopt) { + return 0; + } return sequenceNum_->get(); } diff --git a/torch/lib/c10d/ProcessGroupWrapper.cpp b/torch/lib/c10d/ProcessGroupWrapper.cpp new file mode 100644 index 00000000000000..885a484e446410 --- /dev/null +++ b/torch/lib/c10d/ProcessGroupWrapper.cpp @@ -0,0 +1,322 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10d { + +namespace { +// A container for information about a particular collective, including optype +// and input tensors (if applicable.) +struct CollectiveFingerPrint { + // Current collective's operation type. + OpType op_type_; + // Ref to input tensors, if given, of the collective. If given, shapes will be + // checked across processes to ensure valid input into the collective. + const std::vector& input_tensors_; + explicit CollectiveFingerPrint( + OpType op_type, + const std::vector& input_tensors) + : op_type_(op_type), input_tensors_(input_tensors) {} + + // Verifies a given int is the same across processes. + void verify_num( + int64_t value, + const c10::intrusive_ptr& pg, + const std::string& failureMsg) { + auto tensor = at::full({1}, value, at::TensorOptions().dtype(at::kLong)); + std::vector tensors; + tensors.reserve(pg->getSize()); + for (int i = 0; i < pg->getSize(); ++i) { + tensors.emplace_back(at::zeros_like(tensor)); + } + std::vector> out_tensors({tensors}); + std::vector inp_tensors({tensor}); + pg->allgather(out_tensors, inp_tensors)->wait(); + std::unordered_set gathered; + for (const auto& t : out_tensors[0]) { + auto n = t.item().to(); + gathered.insert(n); + if (gathered.size() > 1) { + TORCH_CHECK(false, failureMsg); + } + } + } + + // Verifies that shapes are consistent across processes. + // shape_tensors_to_report should be specified as the tensors to report when a + // shape inconsistency is found. This is not necessarily shape_tensors such as + // in the case we are checking shape dimensionality. + void verify_shapes( + std::vector shape_tensors, + std::vector shape_tensors_to_report, + c10::intrusive_ptr& pg) { + std::vector> output_tensors; + output_tensors.reserve(shape_tensors.size()); + for (auto & tensor_shape : shape_tensors) { + std::vector outputs; + outputs.reserve(pg->getSize()); + for (int i = 0; i < pg->getSize(); ++i) { + outputs.emplace_back(at::zeros_like(tensor_shape)); + } + output_tensors.emplace_back(outputs); + } + // Allgather tensor shapes. + pg->allgather(output_tensors, shape_tensors)->wait(); + // Verify equivalence + for (int i = 0; i < output_tensors.size(); ++i) { + auto world_tensor_shapes = output_tensors[i]; + auto reference_shape_tensor = shape_tensors[i]; + for (const auto& rank_tensor_shape : world_tensor_shapes) { + if (!rank_tensor_shape.equal(reference_shape_tensor)) { + TORCH_CHECK( + false, + c10::str( + "Error when verifying shape tensors for collective ", + opTypeToString(op_type_), + " on rank ", + pg->getRank(), + ". This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: ", + shape_tensors_to_report)); + } + } + } + } + + // Executes and verifies the collective fingerprint. + void verify(c10::intrusive_ptr pg) { + // For collectives, all ranks should participate and call into them in the + // same order. Verify the same operation type is being requested. + int64_t op_type_int = static_cast(op_type_); + verify_num( + op_type_int, + pg, + c10::str( + "Mismatch between collective operation types across ranks.", + "This likely indicates an application bug where different ranks are ", + "calling different collectives. ", + "Rank ", + pg->getRank(), + " is calling collective: ", + opTypeToString(op_type_))); + // Retrieve input tensor shapes. + std::vector shape_tensors = + c10d::getTensorShapes(input_tensors_); + // If input_tensors_ is empty we would get no shape tensors back, but still + // do verification in case input_tensors_.empty() is + // inconsistent across ranks. In this case, sub in a single zeros tensor and + // ensure all ranks agree, because gloo pg does not allow collectives with + // empty inputs. + if (shape_tensors.size() == 0) { + shape_tensors = {at::zeros(1)}; + } + // Verify dimensionality of shapes. This catches errors where tensor shapes + // have different dimensions such as torch.randn(2, 3) vs torch.randn(2, 3, + // 4). If we did not do this step and instead proceeded directly with + // verifying tensor shapes, we would have malformed input into allgather() + // and crash with an unhelpful error. + std::vector meta_shape_tensors = + c10d::getTensorShapes(shape_tensors); + + verify_shapes( + meta_shape_tensors, /* shape_tensors_to_report= */ shape_tensors, pg); + + // If all meta shapes are 0 then we can skip the below verification since + // it is not possible that there would be a difference. This happens only + // when the tensor wraps a single scalar. + bool skip = true; + for (auto & t : meta_shape_tensors) { + if (t.item().to() != 0) { + skip = false; + break; + } + } + if (!skip) { + verify_shapes( + shape_tensors, /* shape_tensors_to_report= */ shape_tensors, pg); + } + } +}; +} // namespace + +ProcessGroupWrapper::ProcessGroupWrapper( + c10::intrusive_ptr pg, + c10::intrusive_ptr glooPg) + : ProcessGroup(pg->getRank(), pg->getSize()), pg_(pg), glooPg_(glooPg) { + // Set the sequence number for the underlying process group. + pg_->setSequenceNumberForGroup(); +} + +const std::string ProcessGroupWrapper::getBackendName() const { + return pg_->getBackendName(); +} + +c10::intrusive_ptr ProcessGroupWrapper::broadcast( + std::vector& data, + const BroadcastOptions& opts) { + runCollectiveChecks(OpType::BARRIER, data); + return pg_->broadcast(data, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::allreduce( + std::vector& data, + const AllreduceOptions& opts) { + runCollectiveChecks(OpType::ALLREDUCE, data); + return pg_->allreduce(data, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + // NOTE: We don't enforce shape checking for allreduce_coalesced because + // the implementation itself does not enforce it we have tests that use + // inconsistent shapes, see python implementation in distributed_c10d for + // details. + runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {}); + return pg_->allreduce_coalesced(tensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + runCollectiveChecks(OpType::REDUCE, tensors); + return pg_->reduce(tensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + runCollectiveChecks(OpType::ALLGATHER, inputTensors); + return pg_->allgather(outputTensors, inputTensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::_allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts) { + std::vector inputTensors({inputBuffer}); + runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors); + return pg_->_allgather_base(outputBuffer, inputBuffer, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const AllgatherOptions& opts) { + // NOTE: We don't enforce shape checking for allgather_coalesced because + // the implementation itself does not enforce it we have tests that use + // inconsistent shapes, see python implementation in distributed_c10d for + // details. + runCollectiveChecks(OpType::ALLGATHER_COALESCED, {}); + return pg_->allgather_coalesced(outputTensorLists, inputTensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + runCollectiveChecks(OpType::GATHER, inputTensors); + return pg_->gather(outputTensors, inputTensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + runCollectiveChecks(OpType::SCATTER, outputTensors); + return pg_->scatter(outputTensors, inputTensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors); + return pg_->reduce_scatter(outputTensors, inputTensors, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts) { + // alltoall supports uneven split, so don't enforce shape checking. + runCollectiveChecks(OpType::ALLTOALL_BASE, {}); + return pg_->alltoall_base( + outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts) { + // alltoall supports uneven split, so don't enforce shape checking. + runCollectiveChecks(OpType::ALLTOALL, {}); + return pg_->alltoall(outputTensors, inputTensors, opts); +} + +void ProcessGroupWrapper::monitoredBarrier( + const BarrierOptions& opts, + bool waitAllRanks) { + return pg_->monitoredBarrier(opts, waitAllRanks); +} + +void ProcessGroupWrapper::setSequenceNumberForGroup() { + // Set underlying pg's sequence number if it is not set. + if (pg_->getSequenceNumberForGroup() == 0) { + // Set the sequence number for the underlying process group. + pg_->setSequenceNumberForGroup(); + } +} + +uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() { + return pg_->getSequenceNumberForGroup(); +} + +c10::intrusive_ptr ProcessGroupWrapper::send( + std::vector& tensors, + int dstRank, + int tag) { + return pg_->send(tensors, dstRank, tag); +} + +c10::intrusive_ptr ProcessGroupWrapper::recv( + std::vector& tensors, + int srcRank, + int tag) { + return pg_->recv(tensors, srcRank, tag); +} + +c10::intrusive_ptr ProcessGroupWrapper::recvAnysource( + std::vector& tensors, + int tag) { + return pg_->recvAnysource(tensors, tag); +} + +c10::intrusive_ptr ProcessGroupWrapper::barrier( + const BarrierOptions& opts) { + runCollectiveChecks(OpType::BARRIER, {}); + return pg_->barrier(opts); +} + +void ProcessGroupWrapper::runCollectiveChecks( + OpType op_type, + const std::vector& tensors) const { + // first perform a monitored barrier to ensure all ranks can synchronize. + c10d::BarrierOptions options; + // TODO: we should use wrapped pg_'s timeout here, but C++ ProcessGroup API + // does not expose timeout. + glooPg_->monitoredBarrier(options, /* waitAllRanks */ true); + auto finger_print = CollectiveFingerPrint(op_type, tensors); + // Will throw if an ill-formed collective is detected. + finger_print.verify(glooPg_); +} + +} // namespace c10d diff --git a/torch/lib/c10d/ProcessGroupWrapper.hpp b/torch/lib/c10d/ProcessGroupWrapper.hpp new file mode 100644 index 00000000000000..ea80ea04a82f06 --- /dev/null +++ b/torch/lib/c10d/ProcessGroupWrapper.hpp @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10d { + +class ProcessGroupWrapper : public ProcessGroup { + public: + explicit ProcessGroupWrapper( + c10::intrusive_ptr pg, + c10::intrusive_ptr glooPg); + + const std::string getBackendName() const override; + + c10::intrusive_ptr broadcast( + std::vector& data, + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr allreduce( + std::vector& data, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override; + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + + // This function is deprecated and will be moved out of ProcessGroup to comms: + // * do not add dependencies on this function, + // * do not implement it in your ProcessGroup, implement _allgather_base + // instead. + c10::intrusive_ptr allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override; + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + + void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false) + override; + + // Agrees on an initial sequence number for the whole group by having rank 0 + // create it and broadcast it to other ranks using the store. Only implemented + // for GLOO and NCCL backends currently. + // dont implement this + void setSequenceNumberForGroup() override; + + // Retrieves the current sequence number for the whole group, which should be + // in sync. If the returned number is not consistent across the group, it + // may indicate that there is some sort of collective desynchronization. + uint64_t getSequenceNumberForGroup() override; // just call underlying + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + c10::intrusive_ptr recvAnysource( + std::vector& tensors, + int tag) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + private: + // Underlying process group that actual application collectives will be + // dispatched to + c10::intrusive_ptr pg_; + // Gloo process group responsible for internal coordination such as monitored + // barrier, sequence number checking, collective fingerprint collecting. + c10::intrusive_ptr glooPg_; + // Conducts several checks to ensure that the underlying collective is well + // formed with the goal of notifying the user about incorrect collective use + // in the application. + void runCollectiveChecks( + OpType op_type, + const std::vector& tensors) const; +}; +} // namespace c10d diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp index 0d70d5a8d8d6a1..f8adc58746c663 100644 --- a/torch/lib/c10d/Utils.cpp +++ b/torch/lib/c10d/Utils.cpp @@ -71,6 +71,21 @@ namespace c10d { return it->second; } +std::vector getTensorShapes(const std::vector& tensors) { + std::vector shapeTensors; + shapeTensors.reserve(tensors.size()); + for (const auto& tensor : tensors) { + auto shapesVec = tensor.sizes().vec(); + int64_t shapes_size = shapesVec.size(); + // Need to clone here otherwise the shapesVec.data() memory is not copied + // and can be released under the hood. + at::Tensor shapesTensor = at::from_blob( + shapesVec.data(), {shapes_size}, at::TensorOptions().dtype(at::kLong)).clone(); + shapeTensors.emplace_back(std::move(shapesTensor)); + } + return shapeTensors; +} + namespace tcputil { diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp index 60314451e1e4e0..06380c1c34a674 100644 --- a/torch/lib/c10d/Utils.hpp +++ b/torch/lib/c10d/Utils.hpp @@ -48,6 +48,9 @@ std::string parse_env(const char* env_var_name); DistributedDebugLevel parseDistDebugLevel(); +// Retrieve tensor shapes from a given tensor. +std::vector getTensorShapes(const std::vector& tensors); + // Turns at::IntArrayRef into "(1, 2, 3, 4)". inline std::string toString(at::IntArrayRef l) { std::stringstream ss; diff --git a/torch/lib/c10d/logger.cpp b/torch/lib/c10d/logger.cpp index 4e5fbd2793811a..e165491dbdf6f0 100644 --- a/torch/lib/c10d/logger.cpp +++ b/torch/lib/c10d/logger.cpp @@ -241,8 +241,7 @@ void Logger::set_runtime_stats_and_log() { // unused_parameters_ is calculated in forward call of // each iteration. for (const auto& unused_index : reducer_->unused_parameters_) { - const auto& v = reducer_->replicas_[unused_index.replica_index] - [unused_index.variable_index]; + const auto& v = reducer_->replicas_[0][unused_index]; ddp_logging_data_->ints_map["unused_parameter_size"] += v.numel() * v.element_size(); } @@ -261,12 +260,11 @@ void Logger::set_runtime_stats_and_log() { if (reducer_->replicas_[0][0].is_cuda()) { #ifdef USE_CUDA - // Cuda time stats are only collected for single process single - // device and single device module. - if (reducer_->replicas_.size() > 1 || reducer_->is_multi_device_module_) { + // Cuda time stats are only collected for single device modules. + if (reducer_->is_multi_device_module_) { TORCH_WARN_ONCE( - "Cuda time stats are not collected for single process " - "multiple device program or multi-device modules."); + "Cuda time stats are not collected for multi-device modules." + ); return; } // Check events on the replicas_[0][0].device(). diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index 4ebc2c0020a975..1bdd7459aae0d0 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -97,65 +97,63 @@ Reducer::Reducer( { const auto replica_count = replicas_.size(); grad_accumulators_.resize(replica_count); - for (size_t replica_index = 0; replica_index < replica_count; - replica_index++) { - const auto variable_count = replicas_[replica_index].size(); - grad_accumulators_[replica_index].resize(variable_count); - for (size_t variable_index = 0; variable_index < variable_count; - variable_index++) { - auto& variable = replicas_[replica_index][variable_index]; - const auto index = VariableIndex(replica_index, variable_index); - - // The gradient accumulator function is lazily initialized once. - // Therefore we can use its presence in the autograd graph as - // evidence that the parameter has participated in an iteration. - auto grad_accumulator = - torch::autograd::impl::grad_accumulator(variable); + // TODO: get rid of replica_index and nested + // containers such as replicas_, grad_accumulators_, etc. + size_t replica_index = 0; + const auto variable_count = replicas_[replica_index].size(); + grad_accumulators_[replica_index].resize(variable_count); + for (size_t variable_index = 0; variable_index < variable_count; + variable_index++) { + auto& variable = replicas_[replica_index][variable_index]; + + // The gradient accumulator function is lazily initialized once. + // Therefore we can use its presence in the autograd graph as + // evidence that the parameter has participated in an iteration. + auto grad_accumulator = torch::autograd::impl::grad_accumulator(variable); #ifndef _WIN32 - using torch::distributed::autograd::ThreadLocalDistAutogradContext; + using torch::distributed::autograd::ThreadLocalDistAutogradContext; #endif - // Hook to execute after the gradient accumulator has executed. - hooks_.emplace_back( - grad_accumulator->add_post_hook( - torch::make_unique( - [=](const torch::autograd::variable_list& outputs, - const torch::autograd::variable_list& /* unused */) { + // Hook to execute after the gradient accumulator has executed. + hooks_.emplace_back( + grad_accumulator->add_post_hook( + torch::make_unique( + [=](const torch::autograd::variable_list& outputs, + const torch::autograd::variable_list& /* unused */) { #ifndef _WIN32 - this->rpc_context_.set( - ThreadLocalDistAutogradContext::getContextPtr()); + this->rpc_context_.set( + ThreadLocalDistAutogradContext::getContextPtr()); #endif - this->autograd_hook(index); - return outputs; - })), - grad_accumulator); - - // Map raw function pointer to replica index and parameter index. - // This is used later on when the autograd graph is traversed - // to check for parameters for which no gradient is computed, if - // find_unused_parameters=True. - // Note that the mapping of gradient accumulator to variable should be - // one to one as we deduplicate shared parameters before constructing - // Reducer. - if (find_unused_parameters_) { - gradAccToVariableMap_[grad_accumulator.get()] = index; - } + this->autograd_hook(variable_index); + return outputs; + })), + grad_accumulator); + + // Map raw function pointer to replica index and parameter index. + // This is used later on when the autograd graph is traversed + // to check for parameters for which no gradient is computed, if + // find_unused_parameters=True. + // Note that the mapping of gradient accumulator to variable should be + // one to one as we deduplicate shared parameters before constructing + // Reducer. + if (find_unused_parameters_) { + gradAccToVariableMap_[grad_accumulator.get()] = variable_index; + } - numGradHooksTriggeredMap_[index] = 0; + numGradHooksTriggeredMap_[variable_index] = 0; - // The gradient accumulator is stored as weak_ptr in the autograd - // metadata of the variable, so we have to keep it alive here for - // the raw pointer to be valid. - TORCH_CHECK( - grad_accumulators_[replica_index][variable_index] == nullptr, - c10::str( - "Reducer tried to register duplicate grad accumulator for replica ", - replica_index, - " variable ", - variable_index)); - grad_accumulators_[replica_index][variable_index] = - std::move(grad_accumulator); - } + // The gradient accumulator is stored as weak_ptr in the autograd + // metadata of the variable, so we have to keep it alive here for + // the raw pointer to be valid. + TORCH_CHECK( + grad_accumulators_[replica_index][variable_index] == nullptr, + c10::str( + "Reducer tried to register duplicate grad accumulator for replica ", + replica_index, + " variable ", + variable_index)); + grad_accumulators_[replica_index][variable_index] = + std::move(grad_accumulator); } } @@ -300,9 +298,8 @@ void Reducer::copy_grad_to_bucket( } } -void Reducer::mark_variable_ready_dense(VariableIndex index) { - const auto replica_index = index.replica_index; - const auto variable_index = index.variable_index; +void Reducer::mark_variable_ready_dense(size_t variable_index) { + const auto replica_index = 0; const auto& bucket_index = variable_locators_[variable_index]; auto& bucket = buckets_[bucket_index.bucket_index]; auto& replica = bucket.replicas[replica_index]; @@ -343,7 +340,7 @@ void Reducer::mark_variable_ready_dense(VariableIndex index) { // instead of not touching .grad field of parameter. if (this->dynamic_graph_find_unused() || this->static_graph_first_iteration()) { TORCH_CHECK( - local_used_maps_[index.replica_index][index.variable_index] + local_used_maps_[0][variable_index] .item() == 0, "Encountered gradient which is undefined, but still allreduced by DDP reducer. This indicates a bug in DDP implementation, please report a bug with a repro to PyTorch."); } @@ -354,9 +351,8 @@ void Reducer::mark_variable_ready_dense(VariableIndex index) { }); } -void Reducer::mark_variable_ready_sparse(VariableIndex index) { - const auto replica_index = index.replica_index; - const auto variable_index = index.variable_index; +void Reducer::mark_variable_ready_sparse(size_t variable_index) { + const auto replica_index = 0; const auto& bucket_index = variable_locators_[variable_index]; auto& bucket = buckets_[bucket_index.bucket_index]; auto& replica = bucket.replicas[replica_index]; @@ -417,21 +413,16 @@ void Reducer::push_rebuilt_params_for_all_indices() { return; } const auto replica_count = replicas_.size(); - for (size_t replica_index = 0; replica_index < replica_count; - ++replica_index) { - const auto variable_count = replicas_[replica_index].size(); - for (size_t variable_index = 0; variable_index < variable_count; - ++variable_index) { - const auto index = VariableIndex(replica_index, variable_index); - push_rebuilt_params(index); - } + const auto variable_count = replicas_[0].size(); + for (size_t variable_index = 0; variable_index < variable_count; + ++variable_index) { + push_rebuilt_params(variable_index); } } -void Reducer::push_rebuilt_params(const VariableIndex& index) { - rebuilt_params_.push_back( - replicas_[index.replica_index][index.variable_index]); - rebuilt_param_indices_.push_back(index.variable_index); +void Reducer::push_rebuilt_params(const size_t& index) { + rebuilt_params_.push_back(replicas_[0][index]); + rebuilt_param_indices_.push_back(index); } void Reducer::set_divide_factor() { @@ -455,7 +446,7 @@ const c10::Stream Reducer::get_current_stream() { const auto& device = replicas_[0][0].device(); c10::DeviceType deviceType = device.type(); const c10::impl::VirtualGuardImpl guard = - c10::impl::VirtualGuardImpl{deviceType}; + c10::impl::VirtualGuardImpl{deviceType}; return guard.getStream(device); } @@ -483,27 +474,25 @@ void Reducer::delay_all_reduce() { unused_parameters_.clear(); // copy all gradients to buckets - for (size_t replica_index = 0; replica_index < replicas_.size(); - replica_index++) { - for (size_t variable_index = 0; variable_index < replicas_[replica_index].size(); - variable_index++) { - const auto index = VariableIndex(replica_index, variable_index); - // set unused_parameters_ - if (numGradHooksTriggeredMap_[index] == 0) { - unused_parameters_.push_back(index); - } - require_finalize_ = true; - set_divide_factor(); - if (expect_sparse_gradients_[replica_index][variable_index]) { - mark_variable_ready_sparse(index); - } else { - mark_variable_ready_dense(index); - } + size_t replica_index = 0; + for (size_t variable_index = 0; + variable_index < replicas_[replica_index].size(); + variable_index++) { + // set unused_parameters_ + if (numGradHooksTriggeredMap_[variable_index] == 0) { + unused_parameters_.push_back(variable_index); + } + require_finalize_ = true; + set_divide_factor(); + if (expect_sparse_gradients_[replica_index][variable_index]) { + mark_variable_ready_sparse(variable_index); + } else { + mark_variable_ready_dense(variable_index); } } // launch all reduces for all buckets - for (auto & bucket : buckets_) { + for (auto& bucket : buckets_) { all_reduce_bucket(bucket); } @@ -513,7 +502,7 @@ void Reducer::delay_all_reduce() { // The function `autograd_hook` is called after the gradient for a // model parameter has been accumulated into its gradient tensor. // This function is only to be called from the autograd thread. -void Reducer::autograd_hook(VariableIndex index) { +void Reducer::autograd_hook(size_t index) { std::lock_guard lock(this->mutex_); // Carry over thread local state from main thread. This allows for @@ -536,11 +525,10 @@ void Reducer::autograd_hook(VariableIndex index) { // Only set it as locally used if the grad is defined. Otherwise, hooks // could sometimes be triggered with undefined grads, and if this happens // globally, we don't want to touch the .grad field of the param. - auto& variable = get_param_from_index(index); runGradCallbackForVariable(variable, [&](auto& grad) { if (grad.defined()) { - local_used_maps_[index.replica_index][index.variable_index] = 1; + local_used_maps_[0][index] = 1; } // The gradient is never modified. return false; @@ -564,7 +552,15 @@ void Reducer::autograd_hook(VariableIndex index) { } } - // If it is static graph, after 1st iteration, check a avariable + // Rebuild bucket only if 1) it is the first time to rebuild bucket 2) + // static_graph_ is true or find_unused_parameters_ is false, + // 3) this backward pass needs to run allreduce. + // Here, we just dump tensors and their parameter indices into + // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving + // order, and then at the end of finalize_backward(), buckets will be + // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then + // will be broadcasted and initialized. + // If it is static graph, after 1st iteration, check if a variable // is ready for communication based on numGradHooksTriggeredMap_. if (static_graph_after_first_iteration()) { TORCH_CHECK( @@ -591,95 +587,91 @@ void Reducer::autograd_hook(VariableIndex index) { void Reducer::all_reduce_local_used_map() { // See Note [Skip allreducing local_used_maps_dev] - // H2D from local_used_maps_ to local_used_maps_dev_ - for (size_t i = 0; i < local_used_maps_.size(); i++) { - if (local_used_maps_dev_[i].is_cuda()) { - // Note [local_used_maps_ -> local_used_maps_dev copying] - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // We do async H2D to avoid the blocking overhead. The async copy and - // allreduce respect the current stream, so will be sequenced - // correctly. - // - // Correct sequencing with respect to host operations is also - // essential. The H2D copy_ is stream ordered, while the host's - // changes to local_used_maps_ are host ordered. If a large backlog of - // cuda-stream work pushes the copy_ far into the future, and if no - // blocking calls occur between now and finalize_backward()** such - // that finalize_backward() re-zeroes local_used_maps_ on the host - // before the stream executes the copy_, copy_ will read those zeros - // instead of the values we thought we told it to read here. Copying - // local_used_maps_[i] to a pinned temporary (which the pinned caching - // allocator should supply asynchronously) avoids this nasty, rare - // race condition. - // - // ** In the hoped-for case where all params are used, DDP itself - // won't do any blocking work between now and the re-zeroing, so the - // danger is real. - // - // Defensively ensures local_used_maps_tmp is distinct from - // local_used_maps_[i] - auto local_used_maps_tmp = at::native::empty_like( - local_used_maps_[i], - optTypeMetaToScalarType(local_used_maps_[i].options().dtype_opt()), - local_used_maps_[i].options().layout_opt(), - local_used_maps_[i].options().device_opt(), - true /* pinned_memory */); - // Paranoid asserts here because in some workloads, the pinned - // allocator behaves in a way we don't understand, and may be bugged. - // See https://github.com/pytorch/pytorch/pull/54474 - TORCH_INTERNAL_ASSERT(local_used_maps_tmp.is_pinned()); - TORCH_INTERNAL_ASSERT( - local_used_maps_tmp.data_ptr() != local_used_maps_[i].data_ptr()); - local_used_maps_tmp.copy_(local_used_maps_[i]); - local_used_maps_dev_[i].copy_(local_used_maps_tmp, true); - } else { - local_used_maps_dev_[i].copy_(local_used_maps_[i], true); - } + // H2D from local_used_maps_ to local_used_maps_dev_ + for (size_t i = 0; i < local_used_maps_.size(); i++) { + if (local_used_maps_dev_[i].is_cuda()) { + // Note [local_used_maps_ -> local_used_maps_dev copying] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // We do async H2D to avoid the blocking overhead. The async copy and + // allreduce respect the current stream, so will be sequenced + // correctly. + // + // Correct sequencing with respect to host operations is also + // essential. The H2D copy_ is stream ordered, while the host's + // changes to local_used_maps_ are host ordered. If a large backlog of + // cuda-stream work pushes the copy_ far into the future, and if no + // blocking calls occur between now and finalize_backward()** such + // that finalize_backward() re-zeroes local_used_maps_ on the host + // before the stream executes the copy_, copy_ will read those zeros + // instead of the values we thought we told it to read here. Copying + // local_used_maps_[i] to a pinned temporary (which the pinned caching + // allocator should supply asynchronously) avoids this nasty, rare + // race condition. + // + // ** In the hoped-for case where all params are used, DDP itself + // won't do any blocking work between now and the re-zeroing, so the + // danger is real. + // + // Defensively ensures local_used_maps_tmp is distinct from + // local_used_maps_[i] + auto local_used_maps_tmp = at::native::empty_like( + local_used_maps_[i], + optTypeMetaToScalarType(local_used_maps_[i].options().dtype_opt()), + local_used_maps_[i].options().layout_opt(), + local_used_maps_[i].options().device_opt(), + true /* pinned_memory */); + // Paranoid asserts here because in some workloads, the pinned + // allocator behaves in a way we don't understand, and may be bugged. + // See https://github.com/pytorch/pytorch/pull/54474 + TORCH_INTERNAL_ASSERT(local_used_maps_tmp.is_pinned()); + TORCH_INTERNAL_ASSERT( + local_used_maps_tmp.data_ptr() != local_used_maps_[i].data_ptr()); + local_used_maps_tmp.copy_(local_used_maps_[i]); + local_used_maps_dev_[i].copy_(local_used_maps_tmp, true); + } else { + local_used_maps_dev_[i].copy_(local_used_maps_[i], true); } - local_used_work_ = process_group_->allreduce(local_used_maps_dev_); + } + local_used_work_ = process_group_->allreduce(local_used_maps_dev_); } -at::Tensor& Reducer::get_param_from_index(VariableIndex index) { - const auto replica_index = index.replica_index; - const auto variable_index = index.variable_index; - const auto& bucket_index = variable_locators_[variable_index]; +at::Tensor& Reducer::get_param_from_index(size_t index) { + const auto& bucket_index = variable_locators_[index]; auto& bucket = buckets_[bucket_index.bucket_index]; - auto& replica = bucket.replicas[replica_index]; + auto& replica = bucket.replicas[0]; // Cannot simply access variable via replicas_[replica_index][variable_index] // as the callback does not accept const tensors. auto& variable = replica.variables[bucket_index.intra_bucket_index]; return variable; } -void Reducer::checkAndRaiseMarkedTwiceError(size_t curVariableIndex) { +void Reducer::checkAndRaiseMarkedTwiceError(size_t index) { // Something is wrong if all variables contained in this bucket replica have // already been marked as ready. // We don't expect the same variable to be marked ready twice. - bool marked_twice = perIterationReadyParams_.find(curVariableIndex) != perIterationReadyParams_.end(); + bool marked_twice = + perIterationReadyParams_.find(index) != perIterationReadyParams_.end(); if (marked_twice) { // Report index of param that has been marked twice. In debug mode, also // report fully qualified parameter name. - auto param_name = param_names_.find(curVariableIndex); + auto param_name = param_names_.find(index); const bool found_param_name = param_name != param_names_.end(); TORCH_INTERNAL_ASSERT( - ddp_debug_level_ == c10d::DistributedDebugLevel::OFF - || found_param_name, - "Expected to find parameter name in debug mode." - ); + ddp_debug_level_ == c10d::DistributedDebugLevel::OFF || + found_param_name, + "Expected to find parameter name in debug mode."); std::string paramInfo = c10::str( "Parameter at index ", - curVariableIndex, + index, found_param_name ? c10::str(" with name ", param_name->second) : "", " has been marked as ready twice. This means that multiple autograd engine ", - " hooks have fired for this particular parameter during this iteration." - ); + " hooks have fired for this particular parameter during this iteration."); // param_names_ is empty in debug mode. if (!found_param_name) { paramInfo += c10::str( - " You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either", - " INFO or DETAIL to print parameter names for further debugging." - ); + " You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either", + " INFO or DETAIL to print parameter names for further debugging."); } std::string common_error = c10::str( "Expected to mark a variable ready only once. ", @@ -722,28 +714,14 @@ void Reducer::checkAndRaiseMarkedTwiceError(size_t curVariableIndex) { } } -void Reducer::mark_variable_ready(VariableIndex index) { - // Rebuild bucket only if 1) it is the first time to rebuild bucket 2) - // static_graph_ is true or find_unused_parameters_ is false, - // 3) this backward pass needs to run allreduce. - // Here, we just dump tensors and their parameter indices into - // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving - // order, and then at the end of finalize_backward(), buckets will be - // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then - // will be broadcasted and initialized. Also we only need to dump tensors - // and parameter indices of one replica. - const auto replica_index = index.replica_index; - const auto variable_index = index.variable_index; - TORCH_CHECK(replica_index < replicas_.size(), "Out of range replica index."); +void Reducer::mark_variable_ready(size_t variable_index) { TORCH_CHECK( variable_index < variable_locators_.size(), "Out of range variable index."); - if (replica_index == 0) { - checkAndRaiseMarkedTwiceError(variable_index); - perIterationReadyParams_.insert(variable_index); - } - backward_stats_[replica_index][variable_index] = + checkAndRaiseMarkedTwiceError(variable_index); + perIterationReadyParams_.insert(variable_index); + backward_stats_[0][variable_index] = current_time_in_nanos() - cpu_timer_.backward_compute_start_time; // Any time we mark a variable ready (be it in line due to unused parameters, @@ -754,15 +732,14 @@ void Reducer::mark_variable_ready(VariableIndex index) { const auto& bucket_index = variable_locators_[variable_index]; auto& bucket = buckets_[bucket_index.bucket_index]; - auto& replica = bucket.replicas[replica_index]; - + auto& replica = bucket.replicas[0]; set_divide_factor(); if (bucket.expect_sparse_gradient) { - mark_variable_ready_sparse(index); + mark_variable_ready_sparse(variable_index); } else { - mark_variable_ready_dense(index); + mark_variable_ready_dense(variable_index); } // TODO(@pietern): Make this work for both CPU/CUDA tensors. @@ -782,7 +759,6 @@ void Reducer::mark_variable_ready(VariableIndex index) { // Run finalizer function and kick off reduction for local_used_maps once the // final bucket was marked ready. if (next_bucket_ == buckets_.size()) { - if (dynamic_graph_find_unused()) { all_reduce_local_used_map(); } @@ -920,105 +896,102 @@ void Reducer::initialize_buckets( } // Iterate over model replicas. - for (size_t replica_index = 0; replica_index < replica_count; - replica_index++) { - BucketReplica replica; - - if (bucket.expect_sparse_gradient) { - const auto variable_index = bucket_indices[bucket_index].front(); + BucketReplica replica; + size_t replica_index = 0; + if (bucket.expect_sparse_gradient) { + const auto variable_index = bucket_indices[bucket_index].front(); + const auto& variable = replicas_[replica_index][variable_index]; + TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1); + replica.variables = {variable}; + } else { + at::TensorOptions options; + // The start index of the variable in the flattened tensor. + size_t offset = 0; + + // Reserve enough space for the per-variable fields stored in bucket + // replica for efficiency. + const size_t num_variables = bucket_indices[bucket_index].size(); + replica.variables.reserve(num_variables); + replica.offsets.reserve(num_variables); + replica.lengths.reserve(num_variables); + replica.sizes_vec.reserve(num_variables); + + // Iterate over bucket variables. + for (const auto variable_index : bucket_indices[bucket_index]) { + TORCH_CHECK( + variable_index < replicas_[replica_index].size(), + "Out of range variable index specified."); const auto& variable = replicas_[replica_index][variable_index]; - TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1); - replica.variables = {variable}; - } else { - at::TensorOptions options; - // The start index of the variable in the flattened tensor. - size_t offset = 0; - - // Reserve enough space for the per-variable fields stored in bucket - // replica for efficiency. - const size_t num_variables = bucket_indices[bucket_index].size(); - replica.variables.reserve(num_variables); - replica.offsets.reserve(num_variables); - replica.lengths.reserve(num_variables); - replica.sizes_vec.reserve(num_variables); - - // Iterate over bucket variables. - for (const auto variable_index : bucket_indices[bucket_index]) { + if (!options.has_device()) { + options = options.device(variable.device()); + } else { TORCH_CHECK( - variable_index < replicas_[replica_index].size(), - "Out of range variable index specified."); - const auto& variable = replicas_[replica_index][variable_index]; - if (!options.has_device()) { - options = options.device(variable.device()); - } else { - TORCH_CHECK( - variable.device() == options.device(), - "All parameters in a bucket must be ", - "placed on the same device."); - } - if (!options.has_dtype()) { - options = options.dtype(variable.dtype()); - } else { - TORCH_CHECK( - variable.dtype() == options.dtype(), - "All parameters in a bucket must have the same dtype."); - } - const auto length = variable.numel(); - replica.variables.push_back(variable); - replica.offsets.push_back(offset); - replica.lengths.push_back(length); - replica.sizes_vec.push_back(variable.sizes()); - offset += length; + variable.device() == options.device(), + "All parameters in a bucket must be ", + "placed on the same device."); } - - // Allocate bucket contents tensor. - replica.contents = at::empty({static_cast(offset)}, options); - - // Note: "Gradient Layout Contract" - // - // Here, create views into the contents tensor for each variable's grad. - // Views serve as entry points to copy_ each grad's data in/out of the - // flat contents tensor. - // - // Gradients may have dense memory but non-row-major-contiguous strides - // (e.g. channels_last or channels_last_3d). For coalesced accesses - // during copy_s, it's beneficial for each view's layout to match its - // grad's layout. - // - // Specifically, we expect torch/csrc/autograd/AccumulateGrad.h produces - // grads that obey there "Gradient Layout Contract": - // (1) if variable.is_non_overlapping_and_dense(), the stashed grad's - // strides match variable. - // (2) else, stashed grad is rowmajor contiguous. - // and create views to match. - // - // If AccumulateGrad breaks the contract, and produces a grad with an - // unexpected layout, performance will degrade due to poor memory access - // patterns when copy_ing grad data in and out of its bucket view. - // However, numerics remain correct, because the bucket view is the same - // on either end of the raw allreduce. bucket_view_in.copy(grad) - // tranposes - // (+ densifies) to the bucket view's layout, the data is allreduced, - // then grad.copy_(bucket_view_out) transposes it back to grad's layout. - // - // The only way the numerics can go haywire is if the bucket views - // themselves have different layouts across processes (or replicas). - // Bucket views' sizes and strides are set based on param layouts, using - // the same logic that (we expect) AccumulateGrad uses for their grads. - // Therefore, the only way a bucket view could have different layouts in - // different processes is if its param has a different layout in - // different processes. We can check that param layouts match across - // processes and replicas in Reducer's constructor by allreducing some - // metadata. Checking just once won't catch if someone messes with - // param layouts over time, but not messing with params after DDP - // construction is already a documented constraint. - initialize_bucket_views(replica, replica.contents); + if (!options.has_dtype()) { + options = options.dtype(variable.dtype()); + } else { + TORCH_CHECK( + variable.dtype() == options.dtype(), + "All parameters in a bucket must have the same dtype."); + } + const auto length = variable.numel(); + replica.variables.push_back(variable); + replica.offsets.push_back(offset); + replica.lengths.push_back(length); + replica.sizes_vec.push_back(variable.sizes()); + offset += length; } - // Add bucket replica to enclosing bucket. - bucket.replicas.push_back(std::move(replica)); + // Allocate bucket contents tensor. + replica.contents = at::empty({static_cast(offset)}, options); + + // Note: "Gradient Layout Contract" + // + // Here, create views into the contents tensor for each variable's grad. + // Views serve as entry points to copy_ each grad's data in/out of the + // flat contents tensor. + // + // Gradients may have dense memory but non-row-major-contiguous strides + // (e.g. channels_last or channels_last_3d). For coalesced accesses + // during copy_s, it's beneficial for each view's layout to match its + // grad's layout. + // + // Specifically, we expect torch/csrc/autograd/AccumulateGrad.h produces + // grads that obey there "Gradient Layout Contract": + // (1) if variable.is_non_overlapping_and_dense(), the stashed grad's + // strides match variable. + // (2) else, stashed grad is rowmajor contiguous. + // and create views to match. + // + // If AccumulateGrad breaks the contract, and produces a grad with an + // unexpected layout, performance will degrade due to poor memory access + // patterns when copy_ing grad data in and out of its bucket view. + // However, numerics remain correct, because the bucket view is the same + // on either end of the raw allreduce. bucket_view_in.copy(grad) + // tranposes + // (+ densifies) to the bucket view's layout, the data is allreduced, + // then grad.copy_(bucket_view_out) transposes it back to grad's layout. + // + // The only way the numerics can go haywire is if the bucket views + // themselves have different layouts across processes (or replicas). + // Bucket views' sizes and strides are set based on param layouts, using + // the same logic that (we expect) AccumulateGrad uses for their grads. + // Therefore, the only way a bucket view could have different layouts in + // different processes is if its param has a different layout in + // different processes. We can check that param layouts match across + // processes and replicas in Reducer's constructor by allreducing some + // metadata. Checking just once won't catch if someone messes with + // param layouts over time, but not messing with params after DDP + // construction is already a documented constraint. + initialize_bucket_views(replica, replica.contents); } + // Add bucket replica to enclosing bucket. + bucket.replicas.push_back(std::move(replica)); + // Map participating variables to this bucket. // This is identical across replicas so we only need to do this once. size_t intra_bucket_index = 0; @@ -1266,8 +1239,10 @@ std::vector Reducer::getUnmarkedParamsForIteration() { std::vector Reducer::getUnmarkedParamIndicesForIteration() { std::vector unmarked_param_indices; const auto variable_count = replicas_[0].size(); - for (size_t variable_index = 0; variable_index < variable_count; variable_index++) { - if (perIterationReadyParams_.find(variable_index) == perIterationReadyParams_.end()) { + for (size_t variable_index = 0; variable_index < variable_count; + variable_index++) { + if (perIterationReadyParams_.find(variable_index) == + perIterationReadyParams_.end()) { unmarked_param_indices.push_back(variable_index); } } @@ -1276,94 +1251,91 @@ std::vector Reducer::getUnmarkedParamIndicesForIteration() { // A bucket with one or more dense tensors needs to be unflattened. void Reducer::finalize_bucket_dense(Bucket& bucket) { - for (size_t replica_index = 0; replica_index < bucket.replicas.size(); - replica_index++) { - auto& replica = bucket.replicas[replica_index]; - for (size_t intra_bucket_index = 0; - intra_bucket_index < replica.variables.size(); - intra_bucket_index++) { - auto& variable = replica.variables[intra_bucket_index]; - const auto offset = replica.offsets[intra_bucket_index]; - const auto length = replica.lengths[intra_bucket_index]; - - bool global_unused = false; - // See Note [Skip allreducing local_used_maps_dev] - if (static_graph_ || find_unused_parameters_) { - // Determine if this param has been used globally or not. - // - // If the variable was used locally, it is also used globally and then - // we don't need to wait for the reduction. Otherwise we lazily wait for - // the reduction to complete, only when we see a variable that was - // unused locally. Then we end up delaying the synchronization point - // that local_used_work_->wait() implies. If we don't have any unused - // parameters at all, we can skip waiting for the work to complete - // altogether, and cause negligible performance overhead for models - // where all parameters are used. Such lazily waiting means minimizing - // performance impact for the big majority of models where all - // parameters are always used. Then we only pay the overhead cost if - // there is indeed a parameter that is locally unused, because we need - // to check if it's also globally unused. - size_t variable_index = bucket.variable_indices[intra_bucket_index]; - // Note: global_unused might not be global yet. As we lazily wait for - // the reduction to complete, it becomes really global only if we get to - // the point as below where we wait for the reduction work, make D2H - // copy, and update global_unused with the real global consensus, i.e. - // local_used_maps_reduced_ is true. + size_t replica_index = 0; + auto& replica = bucket.replicas[replica_index]; + for (size_t intra_bucket_index = 0; + intra_bucket_index < replica.variables.size(); + intra_bucket_index++) { + auto& variable = replica.variables[intra_bucket_index]; + const auto offset = replica.offsets[intra_bucket_index]; + const auto length = replica.lengths[intra_bucket_index]; + + bool global_unused = false; + // See Note [Skip allreducing local_used_maps_dev] + if (static_graph_ || find_unused_parameters_) { + // Determine if this param has been used globally or not. + // + // If the variable was used locally, it is also used globally and then + // we don't need to wait for the reduction. Otherwise we lazily wait for + // the reduction to complete, only when we see a variable that was + // unused locally. Then we end up delaying the synchronization point + // that local_used_work_->wait() implies. If we don't have any unused + // parameters at all, we can skip waiting for the work to complete + // altogether, and cause negligible performance overhead for models + // where all parameters are used. Such lazily waiting means minimizing + // performance impact for the big majority of models where all + // parameters are always used. Then we only pay the overhead cost if + // there is indeed a parameter that is locally unused, because we need + // to check if it's also globally unused. + size_t variable_index = bucket.variable_indices[intra_bucket_index]; + // Note: global_unused might not be global yet. As we lazily wait for + // the reduction to complete, it becomes really global only if we get to + // the point as below where we wait for the reduction work, make D2H + // copy, and update global_unused with the real global consensus, i.e. + // local_used_maps_reduced_ is true. + global_unused = + local_used_maps_[replica_index][variable_index].item() == 0; + if (global_unused && !local_used_maps_reduced_) { + // Wait for local_used_maps reduction to complete. + local_used_work_->wait(); + // D2H from local_used_maps_dev_ to local_used_maps_ + for (size_t i = 0; i < local_used_maps_.size(); i++) { + // Blocking copy, if local_used_maps_dev_ is cuda + local_used_maps_[i].copy_(local_used_maps_dev_[i]); + } global_unused = local_used_maps_[replica_index][variable_index].item() == 0; - if (global_unused && !local_used_maps_reduced_) { - // Wait for local_used_maps reduction to complete. - local_used_work_->wait(); - // D2H from local_used_maps_dev_ to local_used_maps_ - for (size_t i = 0; i < local_used_maps_.size(); i++) { - // Blocking copy, if local_used_maps_dev_ is cuda - local_used_maps_[i].copy_(local_used_maps_dev_[i]); - } - global_unused = - local_used_maps_[replica_index][variable_index].item() == 0; - local_used_maps_reduced_ = true; - } + local_used_maps_reduced_ = true; } + } - if (!gradient_as_bucket_view_) { - copy_bucket_to_grad( - variable, replica, intra_bucket_index, global_unused); - } else { - const auto& bucket_view_out = - replica.bucket_views_out[intra_bucket_index]; - auto& bucket_view_in = replica.bucket_views_in[intra_bucket_index]; - // If communication_hook is registered, bucket_view_out stores - // allreduced results in a newly allocated tensor, copy bucket_view_out - // back to bucket_view_in that referring to replica.content tensor and - // grad. - if (!bucket_view_in.is_alias_of(bucket_view_out)) { - bucket_view_in.copy_(bucket_view_out); - } - runGradCallbackForVariable(variable, [&](auto& grad) { - // If a parameter is globally unused, we keep its grad untouched. - if (!global_unused) { - // If grad is globally used but locally unused, let grad point to - // bucket_view_in - if (!grad.defined()) { - grad = bucket_view_in; - } else { - if (!grad.is_alias_of(bucket_view_in)) { - TORCH_CHECK( - false, - "Detected at least one parameter gradient is not the " - "expected DDP bucket view with gradient_as_bucket_view=True. " - "This may happen (for example) if multiple allreduce hooks " - "were registered onto the same parameter. If you hit this error, " - "please file an issue with a minimal repro."); - } + if (!gradient_as_bucket_view_) { + copy_bucket_to_grad(variable, replica, intra_bucket_index, global_unused); + } else { + const auto& bucket_view_out = + replica.bucket_views_out[intra_bucket_index]; + auto& bucket_view_in = replica.bucket_views_in[intra_bucket_index]; + // If communication_hook is registered, bucket_view_out stores + // allreduced results in a newly allocated tensor, copy bucket_view_out + // back to bucket_view_in that referring to replica.content tensor and + // grad. + if (!bucket_view_in.is_alias_of(bucket_view_out)) { + bucket_view_in.copy_(bucket_view_out); + } + runGradCallbackForVariable(variable, [&](auto& grad) { + // If a parameter is globally unused, we keep its grad untouched. + if (!global_unused) { + // If grad is globally used but locally unused, let grad point to + // bucket_view_in + if (!grad.defined()) { + grad = bucket_view_in; + } else { + if (!grad.is_alias_of(bucket_view_in)) { + TORCH_CHECK( + false, + "Detected at least one parameter gradient is not the " + "expected DDP bucket view with gradient_as_bucket_view=True. " + "This may happen (for example) if multiple allreduce hooks " + "were registered onto the same parameter. If you hit this error, " + "please file an issue with a minimal repro."); } - // The grad is modified and needs to be written back. - return true; } - // The grad is not modified. - return false; - }); - } + // The grad is modified and needs to be written back. + return true; + } + // The grad is not modified. + return false; + }); } } } @@ -1641,21 +1613,21 @@ void Reducer::ensure_prior_reduction_finished() { // have been computed and reduction of all buckets has been kicked off. if (require_finalize_) { TORCH_CHECK( - !static_graph_, - "Expected to have finished reduction in the prior iteration before " - "starting a new one. " - "This error indicates that your training graph has changed ", - "in this iteration, e.g., one parameter is used in first ", - "iteration, but then got unused in the second iteration. ", - "this is not compatible with static_graph set to True." - ); + !static_graph_, + "Expected to have finished reduction in the prior iteration before " + "starting a new one. " + "This error indicates that your training graph has changed ", + "in this iteration, e.g., one parameter is used in first ", + "iteration, but then got unused in the second iteration. ", + "this is not compatible with static_graph set to True."); // Collect unmarked parameter indices, additionally, in debug mode retrieve // parameter names. auto unmarked_param_indices = getUnmarkedParamIndicesForIteration(); // We should have some unmarked parameter indices, otherwise we would not // have run into this error branch. TORCH_INTERNAL_ASSERT(unmarked_param_indices.size() > 0); - const std::string unmarkedParamIndices = c10::Join(", ", unmarked_param_indices); + const std::string unmarkedParamIndices = + c10::Join(", ", unmarked_param_indices); std::string kBaseErrorMsg = "Expected to have finished reduction in the prior iteration before " @@ -1699,18 +1671,17 @@ void Reducer::ensure_prior_reduction_finished() { "Parameter indices which did not receive grad for rank ", process_group_->getRank(), ": ", - unmarked_param_indices - ); + unmarked_param_indices); if (ddp_debug_level_ == DistributedDebugLevel::OFF) { - // Without debug mode, log unmarked_param_indices, as well as recommendation - // to use debug mode to print parameter names. + // Without debug mode, log unmarked_param_indices, as well as + // recommendation to use debug mode to print parameter names. kBaseErrorMsg += unmarked_param_indices_info; kBaseErrorMsg += - "\n In addition, you can set the environment variable " - "TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information " - "about which particular parameters did not receive gradient on this rank " - "as part of this error"; + "\n In addition, you can set the environment variable " + "TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information " + "about which particular parameters did not receive gradient on this rank " + "as part of this error"; } else { // Retrieve set of parameter names that did not receive gradient. auto unmarkedParams = getUnmarkedParamsForIteration(); @@ -1722,14 +1693,12 @@ void Reducer::ensure_prior_reduction_finished() { } const std::string unmarkedParamInfo = c10::Join(", ", unmarkedParams); // In debug mode, log param names and indices that went unused. - kBaseErrorMsg += - c10::str( + kBaseErrorMsg += c10::str( "\n", "Parameters which did not receive grad for rank ", process_group_->getRank(), ": ", - unmarkedParamInfo - ); + unmarkedParamInfo); kBaseErrorMsg += unmarked_param_indices_info; } TORCH_CHECK(false, kBaseErrorMsg); @@ -1756,9 +1725,8 @@ bool Reducer::should_collect_runtime_stats() { void Reducer::record_forward_compute_start_time() { if (replicas_[0][0].is_cuda()) { #ifdef USE_CUDA - // Record event only for single process single device - // and single device module. - if (replicas_.size() == 1 && !is_multi_device_module_) { + // Record event only for single device module. + if (!is_multi_device_module_) { // Create and record event on the replicas_[0][0].device(). at::DeviceGuard g(replicas_[0][0].device()); gpu_timer_.forward_start.record(); @@ -1772,9 +1740,8 @@ void Reducer::record_forward_compute_start_time() { void Reducer::record_backward_compute_start_time() { if (replicas_[0][0].is_cuda()) { #ifdef USE_CUDA - // Record event only for single process single device - // and single device module. - if (replicas_.size() == 1 && !is_multi_device_module_) { + // Record event only for single device module. + if (!is_multi_device_module_) { // Create and record event on the replicas_[0][0].device(). at::DeviceGuard g(replicas_[0][0].device()); gpu_timer_.backward_compute_start.record(); @@ -1786,9 +1753,8 @@ void Reducer::record_backward_compute_start_time() { void Reducer::record_backward_compute_end_time() { if (replicas_[0][0].is_cuda()) { #ifdef USE_CUDA - // Record event only for single process single device - // and single device module. - if (replicas_.size() == 1 && !is_multi_device_module_) { + // Record event only for single device module. + if (!is_multi_device_module_) { at::DeviceGuard g(replicas_[0][0].device()); gpu_timer_.backward_compute_end.record(); } @@ -1801,9 +1767,8 @@ void Reducer::record_backward_compute_end_time() { void Reducer::record_backward_comm_start_time() { if (replicas_[0][0].is_cuda()) { #ifdef USE_CUDA - // Record event only for single process single device - // and single device module. - if (replicas_.size() == 1 && !is_multi_device_module_) { + // Record event only for single device module + if (!is_multi_device_module_) { at::DeviceGuard g(replicas_[0][0].device()); gpu_timer_.backward_comm_start.record(); } @@ -1816,9 +1781,8 @@ void Reducer::record_backward_comm_start_time() { void Reducer::record_backward_comm_end_time() { if (replicas_[0][0].is_cuda()) { #ifdef USE_CUDA - // Record event only for single process single device - // and single device module. - if (replicas_.size() == 1 && !is_multi_device_module_) { + // Record event only for single device module. + if (!is_multi_device_module_) { at::DeviceGuard g(replicas_[0][0].device()); gpu_timer_.backward_comm_end.record(); } diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index d6dd0b0d95fc5f..43d408cb23bba6 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -12,8 +12,8 @@ #include #include -#include #include +#include #include #include #include @@ -26,28 +26,6 @@ constexpr int kDefaultBucketBytesCap = int(25 * 1024 * 1024); // Collect runtime stats once for every kDDPRuntimeLoggingSampleRate iterations. constexpr int kDDPRuntimeLoggingSampleRate = 100; -// Locates a specific variable by replica index and variable index. -struct VariableIndex { - size_t replica_index; - size_t variable_index; - - VariableIndex() = default; - - VariableIndex(size_t replica_index_, size_t variable_index_) { - replica_index = replica_index_; - variable_index = variable_index_; - } - - static size_t hash(const VariableIndex& key) { - return c10::get_hash(key.replica_index, key.variable_index); - } -}; - -inline bool operator==(const VariableIndex& lhs, const VariableIndex& rhs) { - return lhs.replica_index == rhs.replica_index - && lhs.variable_index == rhs.variable_index; -} - class Reducer { public: // The constructor takes a list of variables for every model replica. @@ -62,8 +40,7 @@ class Reducer { int64_t bucket_bytes_cap, bool find_unused_parameters, bool gradient_as_bucket_view, - std::unordered_map - paramNames); + std::unordered_map paramNames); ~Reducer() noexcept(false); @@ -157,7 +134,7 @@ class Reducer { // Forward declaration. struct Bucket; - void push_rebuilt_params(const VariableIndex& index); + void push_rebuilt_params(const size_t& index); mutable std::mutex mutex_; const std::vector> replicas_; @@ -166,8 +143,7 @@ class Reducer { std::vector>> grad_accumulators_; - std::unordered_map - gradAccToVariableMap_; + std::unordered_map gradAccToVariableMap_; std::vector>> hooks_; @@ -178,7 +154,7 @@ class Reducer { bool has_marked_unused_parameters_; const bool find_unused_parameters_; const bool gradient_as_bucket_view_; - std::vector unused_parameters_; + std::vector unused_parameters_; // Locally used parameter maps indicating if parameters are used locally // during the current iteration or no_sync session if no_sync is on. One // tensor for each model replica and each tensor is one-dim int32 tensor of @@ -197,13 +173,13 @@ class Reducer { // Work handle for allreduce on local_used_maps_ c10::intrusive_ptr local_used_work_; - void mark_variable_ready_dense(VariableIndex index); + void mark_variable_ready_dense(size_t variable_index); - void mark_variable_ready_sparse(VariableIndex index); + void mark_variable_ready_sparse(size_t variable_index); - void mark_variable_ready(VariableIndex index); + void mark_variable_ready(size_t variable_index); - void autograd_hook(VariableIndex index); + void autograd_hook(size_t index); void mark_bucket_ready(size_t bucket_index); @@ -437,14 +413,15 @@ class Reducer { bool static_graph_; - // Key: VariableIndex, Value: the number of times that a variable's autograd_hook() - // should be triggered before marking this variable's grad as ready for communication. - // Map will not change after 1st iteration. - std::unordered_map> numGradHooksTriggeredMap_; - // Key: VariableIndex, Value: the number of times that a variable's autograd_hook() - // are left to be triggered before marking this variable's grad as ready for communication. - // Map will change after 1st iteration to track a grad is ready for communication or not. - std::unordered_map> numGradHooksTriggeredMapPerIteration_; + // Key: size_t (index), Value: the number of times that a variable's + // autograd_hook() should be triggered before marking this variable's grad as + // ready for communication. Map will not change after 1st iteration. + std::unordered_map numGradHooksTriggeredMap_; + // Key: size_t (index), Value: the number of times that a variable's + // autograd_hook() are left to be triggered before marking this variable's + // grad as ready for communication. Map will change after 1st iteration to + // track a grad is ready for communication or not. + std::unordered_map numGradHooksTriggeredMapPerIteration_; private: // reset counting for buckets before backward starts @@ -455,7 +432,8 @@ class Reducer { void set_divide_factor(); // kick off all reduce for the ready bucket void all_reduce_bucket(Bucket& bucket); - // kick off all reduce to local used map, it can help find global unused parameters + // kick off all reduce to local used map, it can help find global unused + // parameters void all_reduce_local_used_map(); // initialize locally used parameter maps void initialize_local_used_map(); @@ -487,7 +465,7 @@ class Reducer { // variable twice, which is unexpected. void checkAndRaiseMarkedTwiceError(size_t curVariableIndex); // Retrieves parameter corresponding to the given VariableIndex. - at::Tensor& get_param_from_index(VariableIndex index); + at::Tensor& get_param_from_index(size_t index); friend class Logger; }; diff --git a/torch/nn/functional.py b/torch/nn/functional.py index bcbd17f6d1632c..db5fe5960097d5 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1898,6 +1898,25 @@ def silu(input: Tensor, inplace: bool = False) -> Tensor: return torch._C._nn.silu(input) +def mish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Applies the Mish function, element-wise. + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + See :class:`~torch.nn.Mish` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(mish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.mish_(input) + return torch._C._nn.mish(input) + + def hardswish(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the hardswish function, element-wise, as described in the paper: @@ -4248,7 +4267,7 @@ def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6, one_hot = _add_docstr( torch._C._nn.one_hot, r""" -one_hot(tensor, num_classes=-1) -> LongTensor +one_hot(tensor, num_classes=-1, dtype=torch.long) -> LongTensor Takes LongTensor with index values of shape ``(*)`` and returns a tensor of shape ``(*, num_classes)`` that have zeros everywhere except where the @@ -4265,6 +4284,8 @@ def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6, num_classes (int): Total number of classes. If set to -1, the number of classes will be inferred as one greater than the largest class value in the input tensor. + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Default: ``torch.int64``. Returns: LongTensor that has one more dimension with 1 values at the diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index a7c5d3a3fa2abd..828f8df2185b52 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -182,6 +182,7 @@ def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tens def silu(input: Tensor, inplace: bool = False) -> Tensor: ... +def mish(input: Tensor, inplace: bool = False) -> Tensor: ... def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ... diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index cc525b35fa1657..c566d2b87aa9e2 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -6,7 +6,7 @@ from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \ Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \ Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \ - Hardsigmoid, Hardswish, SiLU + Hardsigmoid, Hardswish, SiLU, Mish from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \ CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \ MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, HuberLoss, \ @@ -60,5 +60,5 @@ 'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d', 'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d', - 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle' + 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle' ] diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 09ce1dd0fec310..f0d04441e9be59 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -27,15 +27,31 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, num_channels = input.shape[1] # C, C, 1 -> (2C + 1) combined = torch.cat([mean, invstd, count], dim=0) - # world_size * (2C + 1) - combined_list = [ - torch.empty_like(combined) for k in range(world_size) - ] - # Use allgather instead of allreduce since I don't trust in-place operations .. - dist.all_gather(combined_list, combined, process_group, async_op=False) - combined = torch.stack(combined_list, dim=0) - # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 - mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + # Use allgather instead of allreduce because count could be different across + # ranks, simple all reduce op can not give correct results. + # batch_norm_gather_stats_with_counts calculates global mean & invstd based on + # all gathered mean, invstd and count. + # for nccl backend, use the optimized version of all gather. + if process_group._get_backend_name() == 'nccl': + # world_size * (2C + 1) + combined_size = combined.numel() + combined_flat = torch.empty(1, + combined_size * world_size, + dtype=combined.dtype, + device=combined.device) + dist._all_gather_base(combined_flat, combined, process_group, async_op=False) + combined = torch.reshape(combined_flat, (world_size, combined_size)) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) + else: + # world_size * (2C + 1) + combined_list = [ + torch.empty_like(combined) for k in range(world_size) + ] + dist.all_gather(combined_list, combined, process_group, async_op=False) + combined = torch.stack(combined_list, dim=0) + # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 + mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) # calculate global mean & invstd mean, invstd = torch.batch_norm_gather_stats_with_counts( diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index d4485f203286e2..54deaa0b499f75 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -398,6 +398,41 @@ def extra_repr(self) -> str: inplace_str = 'inplace=True' if self.inplace else '' return inplace_str +class Mish(Module): + r"""Applies the Mish function, element-wise. + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Examples:: + + >>> m = nn.Mish() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] + inplace: bool + + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.mish(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str + class Hardswish(Module): r"""Applies the hardswish function, element-wise, as described in the paper: diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index be05e8d73032d0..ba1e6e0d09bdc7 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -138,6 +138,9 @@ def __init__(self, self.reset_parameters() def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 5aa97c93e15605..3739bb2c8848c4 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1651,6 +1651,9 @@ def eval(self: T) -> T: This is equivalent with :meth:`self.train(False) `. + See :ref:`locally-disable-grad-doc` for a comparison between + `.eval()` and several similar mechanisms that may be confused with it. + Returns: Module: self """ @@ -1666,6 +1669,9 @@ def requires_grad_(self: T, requires_grad: bool = True) -> T: This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training). + See :ref:`locally-disable-grad-doc` for a comparison between + `.requires_grad_()` and several similar mechanisms that may be confused with it. + Args: requires_grad (bool): whether autograd should record operations on parameters in this module. Default: ``True``. diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index ea338fbf020aba..7a6fde26ba9d99 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -373,50 +373,42 @@ class RNN(RNNBase): bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided - as `(batch, seq, feature)`. Default: ``False`` + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each RNN layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` Inputs: input, h_0 - - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features - of the input sequence. The input can also be a packed variable length - sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` - or :func:`torch.nn.utils.rnn.pack_sequence` - for details. - - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial hidden state for each element in the batch. - Defaults to zero if not provided. If the RNN is bidirectional, - num_directions should be 2, else it should be 1. - - Outputs: output, h_n - - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor - containing the output features (`h_t`) from the last layer of the RNN, - for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has - been given as the input, the output will also be a packed sequence. + * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for each element in the batch. Defaults to zeros if not provided. - For the unpacked case, the directions can be separated - using ``output.view(seq_len, batch, num_directions, hidden_size)``, - with forward and backward being direction `0` and `1` respectively. - Similarly, the directions can be separated in the packed case. - - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len`. + where: - Like *output*, the layers can be separated using - ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} - Shape: - - Input1: :math:`(L, N, H_{in})` tensor containing input features where - :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. - - Input2: :math:`(S, N, H_{out})` tensor - containing the initial hidden state for each element in the batch. - :math:`H_{out}=\text{hidden\_size}` - Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` - If the RNN is bidirectional, num_directions should be 2, else it should be 1. - - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` - - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state - for each element in the batch + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the RNN, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. Attributes: weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, @@ -433,6 +425,11 @@ class RNN(RNNBase): All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` + .. note:: + For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + .. include:: ../cudnn_rnn_determinism.rst .. include:: ../cudnn_persistent_rnn.rst @@ -518,7 +515,9 @@ class LSTM(RNNBase): bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 @@ -526,41 +525,40 @@ class LSTM(RNNBase): proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 Inputs: input, (h_0, c_0) - - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features - of the input sequence. - The input can also be a packed variable length sequence. + * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or :func:`torch.nn.utils.rnn.pack_sequence` for details. - - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial hidden state for each element in the batch. - If the LSTM is bidirectional, num_directions should be 2, else it should be 1. - If ``proj_size > 0`` was specified, the shape has to be - `(num_layers * num_directions, batch, proj_size)`. - - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial cell state for each element in the batch. - - If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. - + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the + initial hidden state for each element in the batch. + Defaults to zeros if (h_0, c_0) is not provided. + * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + initial cell state for each element in the batch. + Defaults to zeros if (h_0, c_0) is not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{cell} ={} & \text{hidden\_size} \\ + H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ + \end{aligned} Outputs: output, (h_n, c_n) - - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor - containing the output features `(h_t)` from the last layer of the LSTM, - for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been - given as the input, the output will also be a packed sequence. If ``proj_size > 0`` - was specified, output shape will be `(seq_len, batch, num_directions * proj_size)`. - - For the unpacked case, the directions can be separated - using ``output.view(seq_len, batch, num_directions, hidden_size)``, - with forward and backward being direction `0` and `1` respectively. - Similarly, the directions can be separated in the packed case. - - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len`. If ``proj_size > 0`` - was specified, ``h_n`` shape will be `(num_layers * num_directions, batch, proj_size)`. - - Like *output*, the layers can be separated using - ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*. - - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the cell state for `t = seq_len`. + * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the LSTM, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the + final hidden state for each element in the batch. + * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + final cell state for each element in the batch. Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer @@ -581,6 +579,11 @@ class LSTM(RNNBase): All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` + .. note:: + For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + .. include:: ../cudnn_rnn_determinism.rst .. include:: ../cudnn_persistent_rnn.rst @@ -724,49 +727,42 @@ class GRU(RNNBase): bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each GRU layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` Inputs: input, h_0 - - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features - of the input sequence. The input can also be a packed variable length - sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` - for details. - - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the initial hidden state for each element in the batch. - Defaults to zero if not provided. If the RNN is bidirectional, - num_directions should be 2, else it should be 1. - - Outputs: output, h_n - - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor - containing the output features h_t from the last layer of the GRU, - for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been - given as the input, the output will also be a packed sequence. - For the unpacked case, the directions can be separated - using ``output.view(seq_len, batch, num_directions, hidden_size)``, - with forward and backward being direction `0` and `1` respectively. + * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for each element in the batch. Defaults to zeros if not provided. - Similarly, the directions can be separated in the packed case. - - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len` + where: - Like *output*, the layers can be separated using - ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} - Shape: - - Input1: :math:`(L, N, H_{in})` tensor containing input features where - :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. - - Input2: :math:`(S, N, H_{out})` tensor - containing the initial hidden state for each element in the batch. - :math:`H_{out}=\text{hidden\_size}` - Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` - If the RNN is bidirectional, num_directions should be 2, else it should be 1. - - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` - - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state - for each element in the batch + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the GRU, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer @@ -783,6 +779,11 @@ class GRU(RNNBase): All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{hidden\_size}}` + .. note:: + For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + .. include:: ../cudnn_persistent_rnn.rst Examples:: diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index f12b7ebb5ecdaf..ac6fafb634f9d9 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -918,8 +918,6 @@ def gather(self, outputs, output_device): def train(self, mode=True): super(DistributedDataParallel, self).train(mode) - for module in self._module_copies[1:]: - module.train(mode) return self # When running in join mode, schedules an allreduce to match the one in the diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 7a2d1e4e839b77..df562e27382c67 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -18,7 +18,7 @@ class Parameter(torch.Tensor): Args: data (Tensor): parameter tensor. requires_grad (bool, optional): if the parameter requires gradient. See - :ref:`excluding-subgraphs` for more details. Default: `True` + :ref:`locally-disable-grad-doc` for more details. Default: `True` """ def __new__(cls, data=None, requires_grad=True): if data is None: diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py index b108aebffb12c1..32ef82e561f39e 100644 --- a/torch/nn/utils/__init__.py +++ b/torch/nn/utils/__init__.py @@ -6,3 +6,4 @@ from .fusion import fuse_conv_bn_eval, fuse_conv_bn_weights from .memory_format import convert_conv2d_weight_memory_format from . import parametrizations +from .init import skip_init diff --git a/torch/nn/utils/init.py b/torch/nn/utils/init.py new file mode 100644 index 00000000000000..acb7d02156f9fa --- /dev/null +++ b/torch/nn/utils/init.py @@ -0,0 +1,51 @@ +import inspect +import torch + + +def skip_init(module_cls, *args, **kwargs): + r""" + Given a module class object and args / kwargs, instantiates the module without initializing + parameters / buffers. This can be useful if initialization is slow or if custom initialization will + be performed, making the default initialization unnecessary. There are some caveats to this, due to + the way this function is implemented: + + 1. The module must accept a `device` arg in its constructor that is passed to any parameters + or buffers created during construction. + + 2. The module must not perform any computation on parameters in its constructor except + initialization (i.e. functions from :mod:`torch.nn.init`). + + If these conditions are satisfied, the module can be instantiated with parameter / buffer values + uninitialized, as if having been created using :func:`torch.empty`. + + Args: + module_cls: Class object; should be a subclass of :class:`torch.nn.Module` + args: args to pass to the module's constructor + kwargs: kwargs to pass to the module's constructor + + Returns: + Instantiated module with uninitialized parameters / buffers + + Example:: + + >>> import torch + >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) + >>> m.weight + Parameter containing: + tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], + requires_grad=True) + >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) + >>> m2.weight + Parameter containing: + tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, + 4.5915e-41]], requires_grad=True) + + """ + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError('Expected a Module; got {}'.format(module_cls)) + if 'device' not in inspect.signature(module_cls).parameters: + raise RuntimeError('Module must support a \'device\' arg to skip initialization') + + final_device = kwargs.pop('device', 'cpu') + kwargs['device'] = 'meta' + return module_cls(*args, **kwargs).to_empty(device=final_device) diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 8639742fff757b..2af83c7604d1f7 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -793,6 +793,8 @@ def _set_training_mode(training_mode): _training_mode = training_mode _onnx_shape_inference = False +# This function is for debug use only. +# onnx_shape_inference = True by default. def _set_onnx_shape_inference(onnx_shape_inference): global _onnx_shape_inference _onnx_shape_inference = onnx_shape_inference diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 3ec82d0e507266..c0aba4fb8b3594 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -869,7 +869,7 @@ def prim_ConstantChunk(g, self, chunks, dim): start = end return res -def repeat_interleave(g, self, repeats, dim=None): +def repeat_interleave(g, self, repeats, dim=None, output_size=None): from torch.onnx.symbolic_opset9 import reshape input = self final_dim = dim diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 0584fd9ac7b0e6..100641a61487dc 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -722,6 +722,8 @@ def prelu(g, self, weight): def silu(g, input): return g.op("Mul", input, g.op("Sigmoid", input)) +def mish(g, input): + return g.op('Mul', input, g.op('Tanh', g.op('Softplus', input))) def relu(g, input): return g.op("Relu", input) @@ -1978,11 +1980,22 @@ def to(g, self, *args): # aten::to(Tensor, Device, bool, bool, memory_format) return self else: - dtype = sym_help._maybe_get_const(args[0], "i") - if sym_help._is_value(dtype): + # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() + # In this case, the constant value is a tensor not int, + # so sym_help._maybe_get_const(args[0], 'i') would not work. + dtype = args[0] + if sym_help._is_value(args[0]) and args[0].node().kind() == "onnx::Constant": + tval = args[0].node()["value"] + if isinstance(tval, torch.Tensor): + if len(tval.shape) == 0: + tval = tval.item() + dtype = int(tval) + else: + dtype = tval + + if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor): # aten::to(Tensor, Tensor, bool, bool, memory_format) - other = args[0] - dtype = other.type().scalarType() + dtype = args[0].type().scalarType() return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype]) else: # aten::to(Tensor, ScalarType, bool, bool, memory_format) @@ -2014,7 +2027,7 @@ def repeat(g, self, repeats): return g.op("Tile", self, repeats) -def repeat_interleave(g, self, repeats, dim=None): +def repeat_interleave(g, self, repeats, dim=None, output_size=None): input = self # if dim is None flatten # By default, use the flattened input array, and return a flat output array @@ -2569,11 +2582,15 @@ def prim_tolist(g, input, dim_val, elem_ty_val): return input -@parse_args("v", "i") -def one_hot(g, self, num_classes): +@parse_args("v", "i", "v") +def one_hot(g, self, num_classes, dtype): values = g.op("Constant", value_t=torch.LongTensor([0, 1])) depth = g.op("Constant", value_t=torch.LongTensor([num_classes])) - return g.op("OneHot", self, depth, values, axis_i=-1) + one_hot_tensor = g.op("OneHot", self, depth, values, axis_i=-1) + dtype = sym_help._maybe_get_const(dtype, 'i') + if sym_help._is_value(dtype): + dtype = 4 # default to int64 + return g.op("Cast", one_hot_tensor, to_i=sym_help.scalar_type_to_onnx[dtype]) @parse_args("v", "i", "v", "v") diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index b1ade48ed18729..973122513d6620 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -151,6 +151,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_peephole(graph, True) torch._C._jit_pass_fuse_addmm(graph) torch._C._jit_pass_lint(graph) + from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version if operator_export_type != OperatorExportTypes.RAW: torch._C._jit_pass_peephole(graph, True) @@ -192,7 +193,6 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa # onnx only supports tensors, so we turn all out number types into tensors torch._C._jit_pass_erase_number_types(graph) - from torch.onnx.symbolic_helper import _onnx_shape_inference if _onnx_shape_inference: input_names = [] if input_names is None else input_names dynamic_axes = {} if dynamic_axes is None else dynamic_axes @@ -200,7 +200,6 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa graph = torch._C._jit_pass_onnx(graph, operator_export_type) torch._C._jit_pass_lint(graph) - from torch.onnx.symbolic_helper import _export_onnx_opset_version torch._C._jit_pass_onnx_scalar_type_analysis(graph, True, _export_onnx_opset_version) torch._C._jit_pass_lint(graph) @@ -218,7 +217,6 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_lint(graph) graph = torch._C._jit_pass_canonicalize(graph) torch._C._jit_pass_lint(graph) - from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version if _onnx_shape_inference: torch._C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, _export_onnx_opset_version) return graph @@ -1017,35 +1015,82 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat elif op_name == "device" and n.output().type().kind() == "DeviceObjType": return None elif op_name == "Loop" or op_name == "If": - new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize()) - new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node() - for b in n.blocks(): - new_block = new_node.addBlock() - # Copy input metadata to subblock + static_if = (op_name == "If" and inputs[0].node().kind() == "onnx::Constant") + is_sub_block = False + if static_if: + # Fold static if # - # If format: - # prim::If(cond) - # block0() - # block1() + # The torch IR + # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), + # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... + # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %21 : Long(device=cpu) = aten::eq(%20, %64) + # %22 : Long(device=cpu) = prim::If(%21) + # block0(): + # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) + # -> (%23) + # block1(): + # -> (%65) + # %input.53 : Tensor, %weight : Tensor = prim::If(%22) + # block0(): + # -> (%embedding_matrix.1, %input.1) + # block1(): + # -> (%input.1, %embedding_matrix.1) + # %26 : int[] = aten::size(%input.53) # - # Loop format: - # prim::Loop(iter, cond, input_1, ..., input_n) - # block0(iter, input_1, ..., input_n) - # - # For `If` node, there is nothing to copy. - # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. - for i, b_in in enumerate(b.inputs()): - if i == 0 and i < len(inputs): - b_in.setType(inputs[i].type()) - if i > 0 and (i + 1) < len(inputs): - b_in.setType(inputs[i + 1].type()) - torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env) - new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version) - # Process Loop and If after subblock is converted. - from torch.onnx.symbolic_helper import _onnx_shape_inference - if _onnx_shape_inference: - torch._C._jit_pass_onnx_node_shape_type_inference(new_node, _params_dict, opset_version) - return new_op_outputs + # The converted ONNX graph + # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() + # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) + # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() + # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) + input_flag = inputs[0].node()['value'].tolist() + const_value = all(input_flag) if isinstance(input_flag, list) else bool(input_flag) + block_idx = 0 if const_value else 1 + current_b = list(n.blocks())[block_idx] + is_sub_block = True + env = torch._C._jit_pass_onnx_block(current_b, block, operator_export_type, env, + is_sub_block) + if_output_list = list(n.outputs()) + current_b_list = list(current_b.outputs()) + + final_b_list = [] + for idx in range(len(if_output_list)): + if current_b_list[idx] not in env: + raise RuntimeError("The sub block ATen output " + current_b_list[idx] + " is not in env.") + onnx_b = env[current_b_list[idx]] + final_b_list.append(onnx_b) + return final_b_list + else: + new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize()) + new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node() + for b in n.blocks(): + new_block = new_node.addBlock() + # Copy input metadata to subblock + # + # If format: + # prim::If(cond) + # block0() + # block1() + # + # Loop format: + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `If` node, there is nothing to copy. + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. + for i, b_in in enumerate(b.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) + if i > 0 and (i + 1) < len(inputs): + b_in.setType(inputs[i + 1].type()) + torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env, + is_sub_block) + new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(new_node, opset_version) + # Process Loop and If after subblock is converted. + from torch.onnx.symbolic_helper import _onnx_shape_inference + if _onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference(new_node, _params_dict, opset_version) + return new_op_outputs else: symbolic_name = "prim_" + op_name domain = '' diff --git a/torch/overrides.py b/torch/overrides.py index acf1e44a443282..e2fee94cb259e4 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -184,6 +184,10 @@ def get_ignored_functions() -> Set[Callable]: torch.set_autocast_enabled, torch.is_autocast_enabled, torch.clear_autocast_cache, + torch.set_autocast_cpu_enabled, + torch.is_autocast_cpu_enabled, + torch.set_autocast_cpu_dtype, + torch.get_autocast_cpu_dtype, torch.autocast_increment_nesting, torch.autocast_decrement_nesting, torch.nn.functional.hardswish, @@ -714,6 +718,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, torch.nn.functional.selu: lambda input, inplace=False: -1, torch.nn.functional.silu: lambda input, inplace=False: -1, + torch.nn.functional.mish: lambda input, inplace=False: -1, torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1, torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1, torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1, diff --git a/torch/package/_importlib.py b/torch/package/_importlib.py index d6d52311987a67..63c9d7024bfb80 100644 --- a/torch/package/_importlib.py +++ b/torch/package/_importlib.py @@ -1,6 +1,5 @@ -import os.path - import _warnings +import os.path # note: implementations # copied from cpython's import code diff --git a/torch/package/_package_unpickler.py b/torch/package/_package_unpickler.py index 51bcd4c03764a5..c55905850a83f6 100644 --- a/torch/package/_package_unpickler.py +++ b/torch/package/_package_unpickler.py @@ -1,6 +1,5 @@ -import pickle - import _compat_pickle +import pickle from .importer import Importer diff --git a/torch/package/analyze/is_from_package.py b/torch/package/analyze/is_from_package.py index d47a6ea6d45943..82ff5896b6ffcc 100644 --- a/torch/package/analyze/is_from_package.py +++ b/torch/package/analyze/is_from_package.py @@ -1,6 +1,7 @@ -from .._mangling import is_mangled -from typing import Any from types import ModuleType +from typing import Any + +from .._mangling import is_mangled def is_from_package(obj: Any) -> bool: diff --git a/torch/package/analyze/trace_dependencies.py b/torch/package/analyze/trace_dependencies.py index ed96a1b02fb2ed..7ee4e8ca27f15d 100644 --- a/torch/package/analyze/trace_dependencies.py +++ b/torch/package/analyze/trace_dependencies.py @@ -2,7 +2,9 @@ from typing import Any, Callable, Iterable, List, Tuple -def trace_dependencies(callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]]) -> List[str]: +def trace_dependencies( + callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]] +) -> List[str]: """Trace the execution of a callable in order to determine which modules it uses. Args: diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py index 440ab23cbcedc6..bc0cfe386ca4fd 100644 --- a/torch/package/file_structure_representation.py +++ b/torch/package/file_structure_representation.py @@ -14,7 +14,7 @@ def __init__(self, name: str, is_dir: bool): self.is_dir = is_dir self.children: Dict[str, Directory] = {} - def _get_dir(self, dirs: List[str]) -> 'Directory': + def _get_dir(self, dirs: List[str]) -> "Directory": """Builds path of Directories if not yet built and returns last directory in list. @@ -107,7 +107,7 @@ def _create_directory_from_file_list( include: "GlobPattern" = "**", exclude: "GlobPattern" = (), ) -> Directory: - """ Return a :class:`Directory` file structure representation created from a list of files. + """Return a :class:`Directory` file structure representation created from a list of files. Args: filename (str): The name given to the top-level directory that will be the diff --git a/torch/package/glob_group.py b/torch/package/glob_group.py index 40a41ae820f26c..a8434788d016fd 100644 --- a/torch/package/glob_group.py +++ b/torch/package/glob_group.py @@ -35,6 +35,7 @@ class GlobGroup: named in Python. Another common value for this is "/", which is the Unix path separator. """ + def __init__( self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "." ): diff --git a/torch/package/importer.py b/torch/package/importer.py index d6039accbdda43..37b2000b205604 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -40,6 +40,7 @@ class Importer(ABC): obj2 = getattr(module, obj_name) assert obj1 is obj2 """ + modules: Dict[str, ModuleType] @abstractmethod diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 6239ae3721dd89..22a9ced5221089 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -57,10 +57,13 @@ class PackagingErrorReason(Enum): This enum is used to provide good error messages when :class:`PackagingError` is raised. """ + def __repr__(self): - return '<%s.%s>' % (self.__class__.__name__, self.name) + return "<%s.%s>" % (self.__class__.__name__, self.name) - IS_EXTENSION_MODULE = "Module is a C extension module. torch.package supports Python modules only." + IS_EXTENSION_MODULE = ( + "Module is a C extension module. torch.package supports Python modules only." + ) NO_DUNDER_FILE = "Module had no __file__ defined." SOURCE_FILE_NOT_FOUND = ( "Module had a __file__, but we could not find it in your filesystem." @@ -722,7 +725,11 @@ def _persistent_id(self, obj): "`torch.package.package_exporter._gate_torchscript_serialization` to `False`." ) if self.serialized_reduces.get(id(obj)) is None: - self.serialized_reduces[id(obj)] = ("reduce_package", id(obj), *obj.__reduce_package__(self)) + self.serialized_reduces[id(obj)] = ( + "reduce_package", + id(obj), + *obj.__reduce_package__(self), + ) return self.serialized_reduces[id(obj)] diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index fa56803ab6ce8f..0f3d3bd4ae8691 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -1,11 +1,11 @@ import builtins -from contextlib import contextmanager import importlib import inspect import io import linecache import os.path import types +from contextlib import contextmanager from pathlib import Path from typing import Any, BinaryIO, Callable, Dict, List, Optional, Union from weakref import WeakValueDictionary @@ -13,8 +13,6 @@ import torch from torch.serialization import _get_restore_location, _maybe_decode_ascii -from .file_structure_representation import Directory, _create_directory_from_file_list -from .glob_group import GlobPattern from ._importlib import ( _calc___package__, _normalize_line_endings, @@ -25,6 +23,8 @@ from ._mangling import PackageMangler, demangle from ._mock_zipreader import MockZipReader from ._package_unpickler import PackageUnpickler +from .file_structure_representation import Directory, _create_directory_from_file_list +from .glob_group import GlobPattern from .importer import Importer @@ -286,7 +286,12 @@ def _make_module( self, name: str, filename: Optional[str], is_package: bool, parent: str ): mangled_filename = self._mangler.mangle(filename) if filename else None - spec = importlib.machinery.ModuleSpec(name, self, is_package=is_package) # type: ignore[arg-type] + spec = importlib.machinery.ModuleSpec( + name, + self, # type: ignore[arg-type] + origin="", + is_package=is_package, + ) module = importlib.util.module_from_spec(spec) self.modules[name] = module module.__name__ = self._mangler.mangle(name) diff --git a/torch/quantization/_equalize.py b/torch/quantization/_equalize.py index 51cbe0773e24de..1da025ca7a0d2f 100644 --- a/torch/quantization/_equalize.py +++ b/torch/quantization/_equalize.py @@ -3,6 +3,32 @@ from typing import Dict, Any _supported_types = {torch.nn.Conv2d, torch.nn.Linear} +_supported_intrinsic_types = {torch.nn.intrinsic.ConvReLU2d, torch.nn.intrinsic.LinearReLU} +_all_supported_types = _supported_types.union(_supported_intrinsic_types) + +def set_module_weight(module, weight) -> None: + if type(module) in _supported_types: + module.weight = torch.nn.Parameter(weight) + else: + module[0].weight = torch.nn.Parameter(weight) + +def set_module_bias(module, bias) -> None: + if type(module) in _supported_types: + module.bias = torch.nn.Parameter(bias) + else: + module[0].bias = torch.nn.Parameter(bias) + +def get_module_weight(module): + if type(module) in _supported_types: + return module.weight + else: + return module[0].weight + +def get_module_bias(module): + if type(module) in _supported_types: + return module.bias + else: + return module[0].bias def max_over_ndim(input, axis_list, keepdim=False): ''' Applies 'torch.max' over the given axises @@ -38,16 +64,17 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1): the ranges of the first tensors' output channel are equal to the ranges of the second tensors' input channel ''' - if type(module1) not in _supported_types or type(module2) not in _supported_types: + if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types: raise ValueError("module type not supported:", type(module1), " ", type(module2)) - if module1.weight.size(output_axis) != module2.weight.size(input_axis): + weight1 = get_module_weight(module1) + weight2 = get_module_weight(module2) + + if weight1.size(output_axis) != weight2.size(input_axis): raise TypeError("Number of output channels of first arg do not match \ number input channels of second arg") - weight1 = module1.weight - weight2 = module2.weight - bias = module1.bias + bias = get_module_bias(module1) weight1_range = channel_range(weight1, output_axis) weight2_range = channel_range(weight2, input_axis) @@ -72,9 +99,9 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1): weight1 = weight1 * inverse_scaling_factors weight2 = weight2 * scaling_factors - module1.weight = torch.nn.Parameter(weight1) - module1.bias = torch.nn.Parameter(bias) - module2.weight = torch.nn.Parameter(weight2) + set_module_weight(module1, weight1) + set_module_bias(module1, bias) + set_module_weight(module2, weight2) def equalize(model, paired_modules_list, threshold=1e-4, inplace=True): ''' Given a list of adjacent modules within a model, equalization will @@ -131,6 +158,9 @@ def converged(curr_modules, prev_modules, threshold=1e-4): if None in prev_modules.values(): return False for name in curr_modules.keys(): - difference = curr_modules[name].weight.sub(prev_modules[name].weight) + curr_weight = get_module_weight(curr_modules[name]) + prev_weight = get_module_weight(prev_modules[name]) + + difference = curr_weight.sub(prev_weight) summed_norms += torch.norm(difference) return bool(summed_norms < threshold) diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index be9d98164a8c4b..58ea56dbe0b761 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -50,7 +50,8 @@ def fuse_conv_bn_relu(conv, bn, relu): >>> m1 = nn.Conv2d(10, 20, 3) >>> b1 = nn.BatchNorm2d(20) - >>> m2 = fuse_conv_bn(m1, b1) + >>> r1 = nn.ReLU(inplace=False) + >>> m2 = fuse_conv_bn_relu(m1, b1, r1) """ assert(conv.training == bn.training == relu.training),\ "Conv and BN both must be in the same mode (train or eval)." @@ -93,7 +94,7 @@ def fuse_linear_bn(linear, bn): >>> m1 = nn.Linear(20, 10) >>> b1 = nn.BatchNorm1d(10) - >>> m2 = fuse_conv_bn(m1, b1) + >>> m2 = fuse_linear_bn(m1, b1) """ assert(linear.training == bn.training),\ "Linear and BN both must be in the same mode (train or eval)." diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py index 2b2938cdd0a739..1b528d7677d3cb 100644 --- a/torch/quantization/fx/qconfig_utils.py +++ b/torch/quantization/fx/qconfig_utils.py @@ -1,13 +1,19 @@ import torch from collections import OrderedDict -from typing import Union, Callable, Any, Dict +from typing import Union, Callable, Any, Dict, Tuple, Set import re +from torch.fx.graph import ( + Graph, +) + from .utils import _parent_name +# TODO: move this to quantization_types.py? QConfigAny = Union[torch.quantization.QConfig, torch.quantization.QConfigDynamic, None] + def get_flattened_qconfig_dict(qconfig_dict): """ flatten the global, object_type and module_name qconfig to the same qconfig_dict so that it can be used by @@ -45,6 +51,7 @@ def flatten_key(key): flatten_key('module_name') return flattened + def convert_dict_to_ordered_dict(qconfig_dict: Any) -> Dict[str, Dict[Any, Any]]: """ Convert dict in qconfig_dict to ordered dict """ @@ -57,6 +64,7 @@ def _convert_to_ordered_dict(key, qconfig_dict): _convert_to_ordered_dict('module_name', qconfig_dict) return qconfig_dict + def get_object_type_qconfig( qconfig_dict: Any, object_type: Union[Callable, str], @@ -68,6 +76,7 @@ def get_object_type_qconfig( return qconfig_dict['object_type'].get( object_type, fallback_qconfig) + def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig): for regex_pattern, qconfig in \ qconfig_dict['module_name_regex'].items(): @@ -76,6 +85,7 @@ def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig): return qconfig return fallback_qconfig + def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): if module_name == '': # module name qconfig not found @@ -89,6 +99,8 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): # get qconfig for module_name, # fallback to module_name_regex_qconfig, module_type_qconfig, # global_qconfig if necessary + + def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig): module_type_qconfig = get_object_type_qconfig( qconfig_dict, module_type, global_qconfig) @@ -97,3 +109,125 @@ def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig): module_name_qconfig = get_module_name_qconfig( qconfig_dict, module_name, module_name_regex_qconfig) return module_name_qconfig + + +def generate_qconfig_map( + root: torch.nn.Module, + modules: Dict[str, torch.nn.Module], + input_graph: Graph, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]: + global_qconfig = qconfig_dict.get("", None) + qconfig_map = dict() + for node in input_graph.nodes: + qconfig = None + if node.op == "get_attr": + module_name, _ = _parent_name(node.target) + qconfig = get_qconfig( + qconfig_dict, type(modules[module_name]), module_name, global_qconfig) + elif node.op == "call_function": + # precedence: module_name_qconfig + # > function_qconfig > global_qconfig + # module_name takes precedence over function qconfig + function_qconfig = get_object_type_qconfig( + qconfig_dict, node.target, global_qconfig) + module_path, module_type = node_name_to_scope[node.name] + qconfig = get_qconfig( + qconfig_dict, module_type, module_path, function_qconfig) + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # use the qconfig of the module that the node belongs to + qconfig = get_qconfig( + qconfig_dict, module_type, module_path, global_qconfig) + elif node.op == 'call_module': + qconfig = get_qconfig( + qconfig_dict, type(modules[node.target]), node.target, global_qconfig) + # regex is not supported eager mode propagate_qconfig_, we'll + # need to set the qconfig explicitly here in case regex + # is used + modules[node.target].qconfig = qconfig + qconfig_map[node.name] = qconfig + return qconfig_map + + +def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None: + r""" Checks if the given config_dict has the correct keys + + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict.keys(): + if k not in allowed_keys: + raise ValueError( + 'Expected ' + dict_name + ' to have the following keys: ' + + str(allowed_keys) + '. But found \'' + k + + '\' instead.') + + +def check_is_valid_qconfig_dict(qconfig_dict: Any) -> None: + r""" Checks if the given qconfig_dict has the correct keys + + Args: + `qconfig_dict`: dictionary whose keys we want to check + """ + + qconfig_dict_allowed_keys = {"", "object_type", "module_name_regex", "module_name"} + check_is_valid_config_dict(qconfig_dict, qconfig_dict_allowed_keys, "qconfig_dict") + + +def check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict: Dict[str, Any] = None) -> None: + r""" Checks if the given prepare_custom_config_dict has the correct keys + + Args: + `prepare_custom_config_dict`: customization configuration dictionary for + quantization tool + """ + if not prepare_custom_config_dict: + return + + prepare_custom_config_dict_allowed_keys = {"standalone_module_name", + "standalone_module_class", + "float_to_observed_custom_module_class", + "non_traceable_module_name", + "non_traceable_module_class", + "additional_fuser_method_mapping", + "additional_qat__module_mapping", + "additional_fusion_pattern", + "additional_quant_pattern", + "input_quantized_idxs", + "output_quantized_idxs", + "preserved_attributes"} + check_is_valid_config_dict(prepare_custom_config_dict, + prepare_custom_config_dict_allowed_keys, "prepare_custom_config_dict") + + +def check_is_valid_convert_custom_config_dict(convert_custom_config_dict: Dict[str, Any] = None) -> None: + r""" Checks if the given convert_custom_config_dict has the correct keys + + Args: + `convert_custom_config_dict`: dictionary for custom configurations for + convert function + """ + if not convert_custom_config_dict: + return + + convert_custom_config_dict_allowed_keys = {"additional_object_mapping", + "observed_to_quantized_custom_module_class", + "preserved_attributes"} + check_is_valid_config_dict(convert_custom_config_dict, + convert_custom_config_dict_allowed_keys, "convert_custom_config_dict") + + +def check_is_valid_fuse_custom_config_dict(fuse_custom_config_dict: Dict[str, Any] = None) -> None: + r""" Checks if the given fuse_custom_config_dict has the correct keys + + Args: + `fuse_custom_config_dict`: dictionary for custom configurations for fuse_fx + """ + if not fuse_custom_config_dict: + return + + fuse_custom_config_dict_allowed_keys = {"additional_fuser_method_mapping", + "preserved_attributes"} + check_is_valid_config_dict(fuse_custom_config_dict, fuse_custom_config_dict_allowed_keys, "fuse_custom_config_dict") diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index e0b8cbaeaa830a..1635a81484f703 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -45,6 +45,7 @@ ) from .quantization_types import QuantizerCls +from .qconfig_utils import QConfigAny from abc import ABC, abstractmethod import operator @@ -134,7 +135,7 @@ def get_activation_ctr( @abstractmethod - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: """ Convert the given node to a quantized node and insert @@ -258,7 +259,7 @@ def input_output_observed(self): # for x + y where x and y are scalars, we do not observe anything return self.num_tensor_args > 0 - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: @@ -267,7 +268,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, return quantizer.quantized_graph.node_copy( node, load_arg(quantized=None)) - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) if is_reference and self.binary_op in binary_reference_op_supported_dtypes and \ @@ -315,11 +315,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, scale = float(scale) zero_point = int(zero_point) scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point) - - if self.relu_node is not None: - op = torch.ops.quantized.add_relu - else: - op = torch.ops.quantized.add kwargs = {**self.binary_op_node.kwargs} add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg) op = quantizer.quantized_graph.create_node( @@ -333,10 +328,13 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, relu_args = [op_out] relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:])) relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs) - return quantizer.quantized_graph.create_node( + op_out = quantizer.quantized_graph.create_node( "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs) else: - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.create_node( + "call_method", "to", (op_out, torch.float16,), {} + ) else: # leave the op unquantized if the dtype,reference combination is not supported warnings.warn( @@ -367,7 +365,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, @register_quant_pattern(torch.cat) class CatQuantizeHandler(QuantizeHandler): - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: if not self.all_node_args_are_tensors: @@ -421,7 +419,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): elif node.op == "call_function": self.conv = node.target - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: # Supported combinations are: @@ -434,7 +432,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, ] # TODO: is_reference option for conv module - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: @@ -574,7 +571,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): if node.op == 'call_module': self.linear = quantizer.modules[self.linear_node.target] - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: # Supported combinations are: @@ -590,7 +587,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, # static float16 quantization (torch.float16, torch.float16, None), ] - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: @@ -609,6 +605,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) activation_int8_quantized = activation_is_int8_quantized(qconfig) + activation_statically_quantized = activation_is_statically_quantized(qconfig) weight_dtype = dtypes[1] # TODO: reference_model option for linear module if self.linear_node.op == 'call_module': @@ -661,7 +658,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, op_out = quantizer.quantized_graph.create_node( "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs) - if activation_int8_quantized: + if activation_statically_quantized: # quantize output for statically quantized linear op root_module = quantizer.modules[''] act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name @@ -748,10 +745,12 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, relu_args = [op_out] relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:])) relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs) - return quantizer.quantized_graph.create_node( + op_out = quantizer.quantized_graph.create_node( "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs) else: - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.create_node( + "call_method", "to", (op_out, torch.float16), {}) @register_quant_pattern(torch.nn.BatchNorm2d) @register_quant_pattern(torch.nn.BatchNorm3d) @@ -764,7 +763,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): self.bn_node = node self.bn = quantizer.modules[self.bn_node.target] - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: if convert_custom_config_dict is None: @@ -794,7 +793,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): def input_output_observed(self) -> bool: return False - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: # Supported combinations are: @@ -808,7 +807,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, ] assert node.op == 'call_module' emb_node = node - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: @@ -841,7 +839,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): def input_output_observed(self) -> bool: return False - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: # Supported combinations are: @@ -854,7 +852,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, (torch.float32, torch.float16, None), ] assert node.op == 'call_module' - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: @@ -890,6 +887,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, @register_quant_pattern(torch.nn.InstanceNorm3d) @register_quant_pattern(torch.nn.LayerNorm) @register_quant_pattern(torch.nn.SiLU) +@register_quant_pattern(torch.nn.Mish) # we currently only support reference patterns for these ops so they have been removed # until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig # @register_quant_pattern(torch.nn.GELU) @@ -899,6 +897,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, @register_quant_pattern(torch.nn.functional.layer_norm) @register_quant_pattern(torch.nn.functional.leaky_relu) @register_quant_pattern(torch.nn.functional.silu) +@register_quant_pattern(torch.nn.functional.mish) # we currently only support reference patterns for these ops so they have been removed # until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig # @register_quant_pattern(torch.nn.functional.gelu) @@ -914,7 +913,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): elif node.op == "call_module": self.op = type(quantizer.modules[node.target]) - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: if not self.all_node_args_are_tensors: @@ -946,6 +945,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, torch.nn.InstanceNorm3d: int8_dtypes, torch.nn.LayerNorm: all_dtypes, torch.nn.SiLU: fp16_dtypes, + torch.nn.Mish: fp16_dtypes, torch.nn.GELU: int8_dtypes, torch.nn.Softmax: int8_dtypes, torch.nn.functional.hardswish: int8_dtypes, @@ -953,10 +953,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, torch.nn.functional.layer_norm: all_dtypes, torch.nn.functional.leaky_relu: int8_dtypes, torch.nn.functional.silu: fp16_dtypes, + torch.nn.functional.mish: fp16_dtypes, torch.nn.functional.gelu: int8_dtypes, torch.nn.functional.softmax: int8_dtypes, } - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) if dtypes not in supported_dtypes[self.op]: warnings.warn( @@ -1011,7 +1011,9 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, warnings.warn( "Only reference patterns are currently supported for {dtype} dtype with {op} op" "".format(dtype=dtypes, op=self.op)) - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.create_node( + "call_method", "to", (op_out, torch.float16), {}) else: assert is_reference if dtypes in [(torch.quint8, torch.qint8, None)]: @@ -1027,13 +1029,15 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, node, is_input=False) else: assert dtypes in [(torch.float16, torch.float16, None)] - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.create_node( + "call_method", "to", (op_out, torch.float16), {}) # TODO: elu is using scale/zero_point instead of output_scale, output_zero_point @register_quant_pattern(torch.nn.functional.elu) class ELUQuantizeHandler(QuantizeHandler): - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: activation_post_process = \ @@ -1094,13 +1098,15 @@ def get_activation_ctr(self, qconfig, pattern) -> Optional[Callable]: return get_default_output_activation_post_process_map().get( pattern, None) - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: - qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) if dtypes == (torch.float16, torch.float16, None): - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.create_node( + "call_method", "to", (op_out, torch.float16,), {} + ) else: return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) @@ -1177,13 +1183,13 @@ def should_mark_output_quantized_from_input_quantized_status( ) -> bool: return True - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) class CustomModuleQuantizeHandler(QuantizeHandler): - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: """ Convert a float custom module to quantized custom module @@ -1192,7 +1198,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, assert convert_custom_config_dict is not None custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None) assert custom_module_class_mapping is not None - qconfig = quantizer.qconfig_map[node.name] observed_custom_module = quantizer.modules[node.target] if activation_is_statically_quantized(qconfig): assert node.name in quantizer.activation_post_process_map @@ -1216,11 +1221,10 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler): """ Converts an observed standalone module to quantized standalone module by calling convert_fx on the observed standalone module. """ - def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + def convert(self, quantizer: QuantizerCls, node: Node, qconfig: QConfigAny, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: assert node.op == 'call_module' - qconfig = quantizer.qconfig_map[node.name] convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore[attr-defined] observed_standalone_module = quantizer.modules[node.target] input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist() diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index eac2c0789c91ce..d71863388ad048 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -77,9 +77,8 @@ from .qconfig_utils import ( convert_dict_to_ordered_dict, + generate_qconfig_map, get_flattened_qconfig_dict, - get_object_type_qconfig, - get_qconfig, QConfigAny, ) @@ -94,6 +93,12 @@ # ------------------------ # Helper Functions # ------------------------ +def qat_swap_modules( + root: torch.nn.Module, + additional_qat_module_mapping: Dict[Callable, Callable]) -> None: + all_mappings = get_combined_dict( + get_default_qat_module_mappings(), additional_qat_module_mapping) + convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False) def get_standalone_module_configs( node: Node, @@ -930,9 +935,6 @@ def node_arg_is_bias(node: Node, arg: Any) -> bool: class Quantizer: def __init__(self): - # mapping from node name to qconfig that should be used for that node - # filled out for a model during _generate_qconfig_map - self.qconfig_map: Dict[str, QConfigAny] = {} # mapping from fully qualified module name to module instance # for example, # { @@ -957,54 +959,6 @@ def __init__(self): # mapping from node name to the scope of the module which contains the node. self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} - - def _qat_swap_modules( - self, root: torch.nn.Module, - additional_qat_module_mapping: Dict[Callable, Callable]) -> None: - all_mappings = get_combined_dict( - get_default_qat_module_mappings(), additional_qat_module_mapping) - convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False) - - def _generate_qconfig_map( - self, - root: torch.nn.Module, - input_graph: Graph, - qconfig_dict: Any, - node_name_to_scope: Dict[str, Tuple[str, type]]) -> None: - global_qconfig = qconfig_dict.get("", None) - self.node_name_to_scope = node_name_to_scope - self.qconfig_map = dict() - for node in input_graph.nodes: - if node.op == "get_attr": - module_name, _ = _parent_name(node.target) - self.qconfig_map[node.name] = get_qconfig( - qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig) - elif node.op == "call_function": - # precedence: [TODO] module_name_qconfig (need scope support - # from fx) - # > function_qconfig > global_qconfig - # module_name takes precedence over function qconfig - function_qconfig = get_object_type_qconfig( - qconfig_dict, node.target, global_qconfig) - module_path, module_type = node_name_to_scope[node.name] - qconfig = get_qconfig( - qconfig_dict, module_type, module_path, function_qconfig) - self.qconfig_map[node.name] = qconfig - elif node.op == "call_method": - module_path, module_type = node_name_to_scope[node.name] - # use the qconfig of the module that the node belongs to - qconfig = get_qconfig( - qconfig_dict, module_type, module_path, global_qconfig) - self.qconfig_map[node.name] = qconfig - elif node.op == 'call_module': - module_qconfig = get_qconfig( - qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig) - # regex is not supported eager mode propagate_qconfig_, we'll - # need to set the qconfig explicitly here in case regex - # is used - self.modules[node.target].qconfig = module_qconfig - self.qconfig_map[node.name] = module_qconfig - def _prepare( self, model: GraphModule, @@ -1045,12 +999,13 @@ def _prepare( if model.training: additional_qat_module_mapping = prepare_custom_config_dict.get( "additional_qat_module_mapping", {}) - self._qat_swap_modules(model, additional_qat_module_mapping) + qat_swap_modules(model, additional_qat_module_mapping) self.modules = dict(model.named_modules()) - # fill self.qconfig_map, a map from node name to qconfig, used in _find_matches - self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope) + self.node_name_to_scope = node_name_to_scope + # fill qconfig_map, a map from node name to qconfig, used in _find_matches + qconfig_map = generate_qconfig_map(model, self.modules, model.graph, qconfig_dict, node_name_to_scope) # match the patterns that will get quantized standalone_module_name_configs = prepare_custom_config_dict.get( @@ -1063,7 +1018,7 @@ def _prepare( custom_module_classes = get_custom_module_class_keys( prepare_custom_config_dict, "float_to_observed_custom_module_class") matches = self._find_matches( - model.graph, self.modules, self.patterns, standalone_module_names, + model.graph, self.modules, self.patterns, qconfig_map, standalone_module_names, standalone_module_classes, custom_module_classes) input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get( @@ -1075,11 +1030,11 @@ def _prepare( model, self.modules, matches, prepare_custom_config_dict) result_node = insert_observers_for_model( - model, self.modules, matches, self.qconfig_map, + model, self.modules, matches, qconfig_map, model.graph, prepare_custom_config_dict, input_quantized_idxs, output_quantized_idxs) - self.save_state(model) + self.save_state(model, qconfig_map) preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", [])) model = ObservedGraphModule(model, model.graph, preserved_attributes) if is_standalone_module: @@ -1095,9 +1050,9 @@ def _prepare( model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs) return model - def save_state(self, observed: GraphModule) -> None: + def save_state(self, observed: GraphModule, qconfig_map: Dict[str, QConfigAny]) -> None: observed._patterns = self.patterns # type: ignore[assignment] - observed._qconfig_map = self.qconfig_map # type: ignore[assignment] + observed._qconfig_map = qconfig_map # type: ignore[assignment] observed._prepare_custom_config_dict = \ self.prepare_custom_config_dict # type: ignore[assignment] observed._node_name_to_scope = self.node_name_to_scope # type: ignore[assignment] @@ -1106,7 +1061,6 @@ def restore_state(self, observed: GraphModule) -> None: assert is_observed_module(observed), \ 'incoming model must be produced by prepare_fx' self.patterns = observed._patterns # type: ignore[assignment] - self.qconfig_map = observed._qconfig_map # type: ignore[assignment] self.prepare_custom_config_dict = \ observed._prepare_custom_config_dict # type: ignore[assignment] self.node_name_to_scope = observed._node_name_to_scope # type: ignore[assignment] @@ -1157,6 +1111,7 @@ def _convert(self, model: GraphModule, is_reference: bool = False, if convert_custom_config_dict is None: convert_custom_config_dict = {} self.restore_state(model) + qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment] # always run weight observers in the top level forward method # for dynamic quant ops or weight only quant ops self._run_weight_observers(model) @@ -1170,6 +1125,7 @@ def _convert(self, model: GraphModule, is_reference: bool = False, "observed_to_quantized_custom_module_class") matches = self._find_matches( model.graph, self.modules, self.patterns, + qconfig_map, custom_module_classes=custom_module_classes) self.quantized_graph = Graph() @@ -1302,9 +1258,8 @@ def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool: quantized = node_arg_is_quantized(node.args[0]) # the output is unquantized if the node is not a CopyNode - # and activation is fp16 (since we will output fp32 currently for fp16 - # converter - if not activation_is_int8_quantized(qconfig) or \ + # or the activation is not statically quantized + if not activation_is_statically_quantized(qconfig) or \ not obj.input_output_observed(): quantized = False if node_return_type_is_int(node): @@ -1392,8 +1347,9 @@ def insert_quantize_node(node: Node) -> None: assert len(out_quant_idxs) <= 1, "Currently standalone only support one output" quantized = 0 in out_quant_idxs + qconfig = qconfig_map[node.name] result = obj.convert( - self, node, load_arg, is_reference=is_reference, + self, node, qconfig, load_arg, is_reference=is_reference, convert_custom_config_dict=convert_custom_config_dict) if not is_observed_standalone_module_node: quantized = is_output_quantized(node, obj) @@ -1527,6 +1483,7 @@ def convert(self, model: GraphModule, is_reference: bool = False, def _find_matches( self, graph: Graph, modules: Dict[str, torch.nn.Module], patterns: Dict[Pattern, QuantizeHandler], + qconfig_map: Dict[str, QConfigAny], standalone_module_names: List[str] = None, standalone_module_classes: List[Callable] = None, custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]: @@ -1595,7 +1552,7 @@ def record_match(pattern, node, matched): base_node = node.args[0] this_node_qconfig = \ - self.qconfig_map[base_node.name] + qconfig_map[base_node.name] if this_node_qconfig: dtypes = get_qconfig_dtypes(this_node_qconfig) # TODO(future PR): update the pattern to quantize @@ -1629,7 +1586,7 @@ def record_match(pattern, node, matched): for n in matched: match_map[n.name] = ( node, matched, pattern, value(self, node), # type: ignore[operator] - self.qconfig_map[n.name]) + qconfig_map[n.name]) all_matched.add(n.name) # break after finding the first match break @@ -1639,7 +1596,7 @@ def record_match(pattern, node, matched): for node in graph.nodes: if node.op == 'call_module' and \ type(self.modules[node.target]) in custom_module_classes: - custom_module_qconfig = self.qconfig_map[node.name] + custom_module_qconfig = qconfig_map[node.name] match_map[node.name] = ( node, [node], None, CustomModuleQuantizeHandler(self, node), custom_module_qconfig) @@ -1657,7 +1614,7 @@ def is_standalone_module(node_target): (is_standalone_module(node.target) or is_observed_standalone_module(self.modules[node.target])): # add node to matched nodes - custom_module_qconfig = self.qconfig_map[node.name] + custom_module_qconfig = qconfig_map[node.name] match_map[node.name] = ( node, [node], None, StandaloneModuleQuantizeHandler(self, node), diff --git a/torch/quantization/ns/mappings.py b/torch/quantization/ns/mappings.py index deb8661d5fa396..b963a0b5c30c58 100644 --- a/torch/quantization/ns/mappings.py +++ b/torch/quantization/ns/mappings.py @@ -271,6 +271,11 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: nn.SiLU, F.silu, ]), + # F.mish + set([ + nn.Mish, + F.mish, + ]), # F.tanh set([ nn.Tanh, @@ -409,6 +414,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: F.layer_norm, F.leaky_relu, F.silu, + F.mish, # TODO(future PR): implement shadowing for binary ops and # uncomment below # operator.add, @@ -503,6 +509,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nn.LeakyReLU, nn.ReLU6, nn.SiLU, + nn.Mish, nni.BNReLU2d, nni.BNReLU3d, nni.ConvReLU1d, diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 8873fea280614f..0d6bb075dd75fa 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -7,6 +7,11 @@ from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule +from .fx.qconfig_utils import ( + check_is_valid_convert_custom_config_dict, + check_is_valid_fuse_custom_config_dict, + check_is_valid_prepare_custom_config_dict, + check_is_valid_qconfig_dict) from torch.nn.intrinsic import _FusedModule from typing import Dict, Any, List, Callable, Tuple, Optional, Set @@ -150,6 +155,9 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, if prepare_custom_config_dict is None: prepare_custom_config_dict = {} + check_is_valid_qconfig_dict(qconfig_dict) + check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict) + skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", []) skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", []) @@ -241,6 +249,7 @@ def fuse_fx(model: torch.nn.Module, """ torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") assert not model.training, 'fuse_fx only works on models in eval mode' + check_is_valid_fuse_custom_config_dict(fuse_custom_config_dict) graph_module = torch.fx.symbolic_trace(model) preserved_attributes: Set[str] = set() if fuse_custom_config_dict: @@ -439,6 +448,7 @@ def _convert_fx( convert_custom_config_dict = {} _check_is_graph_module(graph_module) + check_is_valid_convert_custom_config_dict(convert_custom_config_dict) quantizer = Quantizer() quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict, diff --git a/torch/quantization/quantize_jit.py b/torch/quantization/quantize_jit.py index e6df61303a877c..f449efe5a6d344 100644 --- a/torch/quantization/quantize_jit.py +++ b/torch/quantization/quantize_jit.py @@ -77,9 +77,11 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC, model_c = model._c model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type) if not debug: - # Moving model parameters to CPU since quantized operators - # are only supported on CPU right now - model.cpu() + is_xpu = all(p.device.type == 'xpu' for p in model.parameters()) + if not is_xpu: + # Moving model parameters to CPU since quantized operators + # are only supported on CPU and XPU right now + model.cpu() if preserved_attrs is None: preserved_attrs = [] model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs) diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index 1c84a3d8711e00..5283ccecd886d6 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -174,6 +174,9 @@ def __init__(self, dev): ("renorm", mat0_fp16 + (2, 0, 1.0)), ("sum", pointwise0_fp16), ("sum", mat0_fp16 + (1,)), + ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev), + torch.randn((2, 22, 11, 2), dtype=torch.float16, device=dev), + 0, 0, False)), ] self.torch_need_autocast_promote = [ ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)), @@ -234,3 +237,123 @@ def __init__(self, dev): ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32), torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn), ] + +class AutocastCPUTestLists(object): + # Supplies ops and arguments for test_autocast_* in test/test_cpu.py + def __init__(self, dev): + super().__init__() + n = 8 + # Utility arguments, created as one-element tuples + pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) + mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) + + dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n)) + + dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),) + for dimset in dummy_dimsets] + + dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) + conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev), + torch.randn(dimset, dtype=torch.bfloat16, device=dev)) + for dimset in dimsets] + conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), + torch.randn(dimset, dtype=torch.float32, device=dev)) + for dimset in dimsets] + + bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) + element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) + pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) + mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) + + dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),) + for dimset in dummy_dimsets] + # The lists below organize ops that autocast needs to test. + # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py. + # Each op is associated with a tuple of valid arguments. + + # Some ops implement built-in type promotion. These don't need autocasting, + # but autocasting relies on their promotion, so we include tests to double-check. + self.torch_expect_builtin_promote = [ + ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ] + self.methods_expect_builtin_promote = [ + ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool), + ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32), + ] + # The remaining lists organize ops that autocast treats explicitly. + self.torch_bf16 = [ + ("conv1d", conv_args_fp32[0]), + ("conv2d", conv_args_fp32[1]), + ("conv3d", conv_args_fp32[2]), + ("log_softmax", pointwise0_fp32 + (0,)), + ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("mm", mat0_fp32 + mat1_fp32), + ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), + torch.randn((n, n, n), device=dev, dtype=torch.float32))), + ] + self.torch_fp32 = [ + ("conv_transpose3d", conv_args_bf16[2]), + ("batch_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32), + "running_var": torch.rand((n), dtype=torch.float32), "training": False, + "momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}), + ("max_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}), + ("dropout", dummy_bf16[2], {"p": 0.1, "train": False}), + ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), + ("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_bf16), + ("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)), + ("instance_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32), + "running_var": torch.rand((n), dtype=torch.float32), "use_input_stats": False, + "momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}), + ] + self.nn_bf16 = [ + ("linear", mat0_fp32 + mat1_fp32), + ] + self.nn_fp32 = [ + ("adaptive_avg_pool2d", dummy_bf16[2], {"output_size": (3, 2)}), + ("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}), + ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}), + ("gelu", dummy_bf16[3]), + ("upsample_nearest1d", dummy_bf16[2], {"output_size": (n)}), + ("upsample_nearest2d", dummy_bf16[3], {"output_size": (n, n)}), + ("upsample_nearest3d", dummy_bf16[4], {"output_size": (n, n, n)}), + ("upsample_linear1d", dummy_bf16[2], {"output_size": (n), "align_corners": False}), + ("upsample_bilinear2d", dummy_bf16[3], {"output_size": (n, n), "align_corners": False}), + ("upsample_trilinear3d", dummy_bf16[4], {"output_size": (n, n, n), "align_corners": False}), + ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) + + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), + ("smooth_l1_loss", mat0_bf16 + mat1_bf16), + ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}), + ("std", dummy_bf16[2]), + ] + self.torch_need_autocast_promote = [ + ("cat", (pointwise0_bf16 + pointwise1_fp32,)), + ("stack", (pointwise0_bf16 + pointwise1_fp32,)), + ] diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 5427cb455127e0..73b6b59a4e5180 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -769,7 +769,7 @@ def only_fn(self, device, *args, **kwargs): # Specifies per-dtype precision overrides. # Ex. # -# @precisionOverride(torch.half : 1e-2, torch.float : 1e-4) +# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4}) # @dtypes(torch.half, torch.float, torch.double) # def test_X(self, device, dtype): # ... diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b23c8ed9f101cb..b3d18480956a58 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1,5 +1,5 @@ from functools import reduce, wraps, partial -from itertools import product +from itertools import product, chain from operator import mul import collections import operator @@ -81,9 +81,9 @@ def __init__(self, cls_name=None, test_name=None, *, class SampleInput(object): """Represents sample inputs to a function.""" - __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_input'] + __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_input', 'name'] - def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_input=False): + def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_input=False, name=""): # input is the first input to the op and must be either a Tensor or TensorList (Sequence[Tensor]). # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...). # op with TensorList inputs do not support method or inplace variants. @@ -92,6 +92,7 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N self.args = args self.kwargs = kwargs if kwargs is not None else {} self.output_process_fn_grad = output_process_fn_grad + self.name = name # Specifies if `self.input` is broadcasted or not, # given that the operator supports broadcasting. @@ -103,17 +104,56 @@ def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=N # for such inputs (as they will error out otherwise). self.broadcasts_input = broadcasts_input - def __repr__(self): + def _repr_helper(self, formatter): + # Helper function to return the details of the SampleInput as `str` + # It consolidates all the fields of SampleInput and allows, + # formatting the fields like `input`, `args`, etc with `formatter` + # callable to customize the representation. + # Look at `summary` method for example. arguments = [ - 'input=Tensor' if isinstance(self.input, torch.Tensor) else f'input=TensorList[{len(self.input)}]', - f'args={self.args}' if len(self.args) > 0 else None, - f'kwargs={self.kwargs}' if len(self.kwargs) > 0 else None, - (f'output_process_fn_grad={self.output_process_fn_grad}' - if self.output_process_fn_grad is not None else None), - f'broadcasts_input={self.broadcasts_input}'] + f'input={formatter(self.input)}', + f'args={formatter(self.args)}', + f'kwargs={formatter(self.kwargs)}', + f'output_process_fn_grad={self.output_process_fn_grad}', + f'broadcasts_input={self.broadcasts_input}', + f'name={repr(self.name)}'] return f'SampleInput({", ".join(a for a in arguments if a is not None)})' + def __repr__(self): + return self._repr_helper(lambda x: x) + + def summary(self): + # Returns the SampleInput details in a more + # friendly format. + # It formats `Tensor` and `TensorList` + # in a more condensed representation. + def is_iter(arg): + try: + iter(arg) + return True + except TypeError as te: + return False + + def formatter(arg): + # Format any instance of `Tensor` (standalone, in list, or in dict) + # by Tensor[TensorShape] + # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] + if isinstance(arg, torch.Tensor): + shape = str(tuple(arg.shape)).replace('(', '').replace(')', '') + return f"Tensor[{shape}]" + elif isinstance(arg, dict): + return {k: formatter(v) for k, v in arg.items()} + elif is_iterable_of_tensors(arg): + return "TensorList[" + ", ".join(map(formatter, arg)) + "]" + elif is_iter(arg): # Handle list, tuple or any iterable type + return "(" + ",".join(map(formatter, arg)) + ")" + + return repr(arg) + + return self._repr_helper(formatter) + + class AliasInfo(object): """Class holds alias information. For example, torch.abs -> torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ @@ -2003,6 +2043,24 @@ def wrapped_fn(x): return wrapped_fn +def np_unary_ufunc_integer_promotion_wrapper_with_astype(fn): + # Check np_unary_ufunc_integer_promotion_wrapper + def is_integral(dtype): + return dtype in [np.bool_, bool, np.uint8, np.int8, np.int16, np.int32, np.int64] + + @wraps(fn) + def wrapped_fn(x): + # As the default dtype can change, acquire it when function is called. + # NOTE: Promotion in PyTorch is from integer types to the default dtype + np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + + if is_integral(x.dtype): + return fn(x).astype(np_dtype) + return fn(x) + + return wrapped_fn + + # Metadata class for Fast Fourier Transforms in torch.fft. class SpectralFuncInfo(OpInfo): """Operator information for torch.fft transforms. """ @@ -2913,6 +2971,24 @@ def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): return samples + [vec_sample] +def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # Shapes for 2D Tensors + shapes_2d = ((M, M), (3, 5), (5, 3)) + + # Shapes for 3D Tensors + shapes_3d = ((M, M, M),) + + args_2d = ((), (2,), (-2,), (1,)) + args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1)) + + def generator(): + for shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)): + yield SampleInput(make_arg(shape), args=arg) + + return list(generator()) + def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): low, high = op_info.domain @@ -3301,6 +3377,45 @@ def generator(): return list(generator()) + +def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if list_args: + cases = ( # type: ignore[assignment] + ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), + ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2),), + ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], -2),) + ) + else: + cases = ( # type: ignore[assignment] + ((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ) + + def generator(): + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + return list(generator()) + + +def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = (((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), + ((S, S, S), ([int(S / 3), S - int(S / 3), 0],)), + ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], 2)), + ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], -2)), + ) + + def generator(): + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + return list(generator()) + + def sample_inputs_msort(op_info, device, dtype, requires_grad): def apply_grad(t): if dtype in floating_types_and(torch.float16, torch.bfloat16): @@ -3465,6 +3580,26 @@ def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): return samples +def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + def generator(): + yield SampleInput(make_arg((S, M, S))) + yield SampleInput(make_arg(())) + + return list(generator()) + + +def sample_inputs_contiguous(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + def generator(): + yield SampleInput(make_arg((S, S))) + yield SampleInput(make_arg((S, S), noncontiguous=True)) + + return list(generator()) + + def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -3586,6 +3721,24 @@ def generator(): return list(generator()) +def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad) + + test_cases = [ + (_tensor((S, S, S)), (2,)), + (_tensor((S, S, S)), (2, 1,)), + (_tensor((S, S, S)), (2, -1,)), + (_tensor((S, S, S)), (2, 1, True,)), + (_tensor((S, S, S)), (2, -1, True,)), + (_tensor((S,)), (2, 0,)), + (_tensor((S,)), (2, 0, True,)), + (_tensor(()), (1,)), + (_tensor(()), (1, 0,)), + (_tensor(()), (1, 0, True)) + ] + + return [SampleInput(tensor, args=args) for tensor, args in test_cases] foreach_unary_op_db: List[OpInfo] = [ ForeachUnaryFuncInfo('exp'), @@ -3617,7 +3770,7 @@ def generator(): ForeachUnaryFuncInfo('ceil', dtypes=floating_types(), dtypesIfCPU=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half)), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16)), ForeachUnaryFuncInfo('erf', dtypes=floating_types(), @@ -3637,7 +3790,7 @@ def generator(): ForeachUnaryFuncInfo('floor', dtypes=floating_types(), dtypesIfCPU=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half)), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16)), ForeachUnaryFuncInfo('log1p', dtypes=floating_types(), @@ -3647,12 +3800,12 @@ def generator(): ForeachUnaryFuncInfo('round', dtypes=floating_types(), dtypesIfCPU=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half)), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16)), ForeachUnaryFuncInfo('frac', dtypes=floating_types(), dtypesIfCPU=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half)), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16)), ForeachUnaryFuncInfo('reciprocal', dtypes=floating_types(), @@ -3667,7 +3820,7 @@ def generator(): ForeachUnaryFuncInfo('trunc', dtypes=floating_types(), dtypesIfCPU=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half)), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16)), ForeachUnaryFuncInfo('abs', dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), @@ -3811,6 +3964,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # "rsqrt_cpu" not implemented for 'BFloat16' backward_dtypesIfCPU=all_types_and_complex_and(torch.bool), assert_autodiffed=True, + supports_forward_ad=True, decorators=(precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-1, torch.complex64: 1e-2}),), @@ -3839,6 +3993,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): safe_casts_outputs=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), supports_inplace_autograd=False, + supports_forward_ad=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -3889,6 +4044,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), assert_autodiffed=True, supports_inplace_autograd=False, + supports_forward_ad=True, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, sample_inputs_func=sample_inputs_addmm), OpInfo('addmm', @@ -3900,6 +4056,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), assert_autodiffed=True, supports_inplace_autograd=False, + supports_forward_ad=True, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1)), @@ -3910,6 +4067,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): *[torch.bfloat16] if CUDA11OrLater else []), dtypesIfROCM=floating_types_and(torch.half), supports_inplace_autograd=False, + supports_forward_ad=True, skips=( # issue may fix: https://github.com/pytorch/pytorch/issues/55589 # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. @@ -3923,6 +4081,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), dtypesIfROCM=floating_types_and(torch.half), + supports_forward_ad=True, skips=( # addbmm does not correctly warn when resizing out= inputs SkipInfo('TestCommon', 'test_out'), @@ -3988,6 +4147,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): backward_dtypesIfCUDA=all_types_and_complex_and(torch.bool), # Reference: https://github.com/pytorch/pytorch/issues/50747 supports_inplace_autograd=False, + supports_forward_ad=True, skips=( # Reference: https://github.com/pytorch/pytorch/issues/50747 SkipInfo('TestCommon', 'test_variant_consistency_eager', @@ -3998,6 +4158,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex(), dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16), assert_autodiffed=True, + supports_forward_ad=True, supports_inplace_autograd=False, skips=( # TODO: update sample inputs with for_inplace_variant kwarg to support this test @@ -4007,6 +4168,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=floating_and_complex_types(), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), supports_inplace_autograd=False, + supports_forward_ad=True, skips=( # TODO: update sample inputs with for_inplace_variant kwarg to support this test SkipInfo('TestCommon', 'test_variant_consistency_eager'),), @@ -4030,6 +4192,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): ref=np.arcsin, domain=(-1, 1), supports_sparse=True, + supports_forward_ad=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), safe_casts_outputs=True, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), @@ -4060,6 +4223,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): safe_casts_outputs=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), supports_inplace_autograd=False, + supports_forward_ad=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -4073,6 +4237,10 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + # Complex gradcheck tests asinh at points 0 + ix for x > 1 which are points + # where asinh is not differentiable + SkipInfo('TestGradients', 'test_forward_mode_AD', + dtypes=complex_types()) )), UnaryUfuncInfo('atan', aliases=('arctan', ), @@ -4080,6 +4248,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, + supports_forward_ad=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), safe_casts_outputs=True, skips=( @@ -4114,6 +4283,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): safe_casts_outputs=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), supports_inplace_autograd=False, + supports_forward_ad=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -4143,7 +4313,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('ceil', ref=np.ceil, dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), assert_autodiffed=True), OpInfo('cholesky', dtypes=floating_and_complex_types(), @@ -4173,6 +4343,21 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), sample_inputs_func=sample_inputs_chunk, supports_out=False), + OpInfo('clone', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_clone, + supports_forward_ad=True, + supports_out=False), + OpInfo('contiguous', + op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_contiguous, + supports_forward_ad=True, + skips=( + # JIT has issue when op is passed as lambda + SkipInfo('TestCommon', 'test_variant_consistency_jit'), + ), + supports_out=False), OpInfo('symeig', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, @@ -4217,6 +4402,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): ref=np.conj, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + supports_forward_ad=True, skips=( # File "test_unary_ufuncs.py", line 289, in test_reference_numerics # if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype): @@ -4290,7 +4476,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): )), OpInfo('cumsum', dtypesIfCPU=all_types_and_complex_and(torch.bool), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), skips=( # "cumsum_out_{cpu,cuda}" not implemented for 'Bool' SkipInfo('TestOpInfo', 'test_supported_dtypes', @@ -4301,7 +4487,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): sample_inputs_func=sample_inputs_cumulative_ops), OpInfo('cumprod', dtypes=all_types_and_complex_and(torch.bool), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), skips=( # "cumprod_out_{cpu, cuda}" not implemented for 'Bool' SkipInfo('TestOpInfo', 'test_supported_dtypes', @@ -4315,12 +4501,12 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): gradcheck_fast_mode=False), OpInfo('cummax', dtypesIfCPU=all_types_and(torch.bool), - dtypesIfCUDA=all_types_and(torch.bool, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('cummin', dtypesIfCPU=all_types_and(torch.bool), - dtypesIfCUDA=all_types_and(torch.bool, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), UnaryUfuncInfo('deg2rad', @@ -4400,6 +4586,15 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCPU=all_types_and_complex_and(torch.bool), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_diag), + OpInfo('diag_embed', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed), + OpInfo('diagonal', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + sample_inputs_func=sample_inputs_diagonal_diag_embed), OpInfo('eq', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), supports_autograd=False, @@ -4415,7 +4610,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('frac', ref=lambda x: np.modf(x)[0], dtypes=floating_types_and(torch.bfloat16, torch.float16), - dtypesIfCUDA=floating_types_and(torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), assert_autodiffed=True, # Reference for disabling extremals # https://github.com/pytorch/pytorch/issues/51948 @@ -4504,7 +4699,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('floor', ref=np.floor, dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), assert_autodiffed=True), OpInfo('flip', op=torch.flip, @@ -4522,11 +4717,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): sample_inputs_func=sample_inputs_fliplr_flipud, supports_out=False), UnaryUfuncInfo('i0', - ref=np.i0, + ref=np_unary_ufunc_integer_promotion_wrapper_with_astype( + scipy.special.i0) if TEST_SCIPY else _NOTHING, decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 5e-1}),), - dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + safe_casts_outputs=True, supports_autograd=False), UnaryUfuncInfo('special.i0e', aten_name='special_i0e', @@ -4620,6 +4817,10 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # cuda gradchecks are slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)), + OpInfo('kthvalue', + dtypes=all_types(), + dtypesIfCUDA=all_types_and(torch.float16), + sample_inputs_func=sample_inputs_kthvalue), OpInfo('le', aliases=('less_equal',), dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), @@ -4945,7 +5146,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): sample_inputs_func=sample_inputs_masked_select), OpInfo('matrix_exp', dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), sample_inputs_func=sample_inputs_matrix_exp, supports_out=False), OpInfo('matmul', @@ -5240,7 +5441,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('round', ref=np.round, dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), assert_autodiffed=True,), UnaryUfuncInfo('sin', ref=np.sin, @@ -5308,6 +5509,21 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', device_type='cpu', dtypes=[torch.complex64]) )), + OpInfo('split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_out=False, + assert_autodiffed=True), + OpInfo('split', + variant_test_name='list_args', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=partial(sample_inputs_split, list_args=True), + supports_out=False), + OpInfo('split_with_sizes', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_split_with_sizes, + supports_out=False, + assert_autodiffed=True), OpInfo('__radd__', op=torch.Tensor.__radd__, dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), @@ -5509,7 +5725,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): aliases=('fix', ), ref=np.trunc, dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), assert_autodiffed=True), UnaryUfuncInfo('exp2', aliases=('special.exp2', ), @@ -5639,6 +5855,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): decorators=(precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}),), safe_casts_outputs=True, + supports_forward_ad=True, supports_complex_to_float=True), OpInfo('linalg.solve', aten_name='linalg_solve', @@ -6153,12 +6370,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): ), OpInfo('logcumsumexp', dtypes=floating_types_and(), - dtypesIfCUDA=floating_types_and(torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), skips=( # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. SkipInfo('TestCommon', 'test_out', dtypes=(torch.float32,), device_type='cuda'), - # logcumsumexp_backward not implemented for 'Half - SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.float16,), device_type='cuda'), + # logcumsumexp_backward not implemented for Half & BFloat16 + SkipInfo('TestOpInfo', 'test_supported_backward', + dtypes=(torch.float16, torch.bfloat16), device_type='cuda'), ), sample_inputs_func=sample_inputs_logcumsumexp), UnaryUfuncInfo('sigmoid', @@ -6440,14 +6658,6 @@ def method_tests(): ('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'), ('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'), ('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'), - ('kthvalue', (S, S, S), (2,)), - ('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]), - ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]), - ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]), - ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]), - ('kthvalue', (), (1,), 'scalar', (), ()), - ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]), - ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]), ('median', (S, S, S), NO_ARGS), ('median', (S, S, S), (1,), 'dim', (), [0]), ('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), @@ -6510,23 +6720,6 @@ def method_tests(): ('norm', (), (3, 0), '3_dim_scalar', (), [1]), ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', (), [1]), ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', (), [1]), - ('clone', (S, M, S), NO_ARGS), - ('clone', (), NO_ARGS, 'scalar'), - ('contiguous', (S, S), NO_ARGS, '', (True,)), - ('contiguous', torch.randn(S, S).transpose(0, 1), NO_ARGS, 'not_contiguous', (True,)), - ('diag_embed', (S, S), NO_ARGS), - ('diagonal', (M, M), NO_ARGS, '2d'), - ('diagonal', (3, 5), NO_ARGS, '2d_wide'), - ('diagonal', (3, 5), (2,), '2d_wide_pos'), - ('diagonal', (3, 5), (-2,), '2d_wide_neg'), - ('diagonal', (5, 3), NO_ARGS, '2d_tall'), - ('diagonal', (5, 3), (2,), '2d_tall_pos'), - ('diagonal', (5, 3), (-2,), '2d_tall_neg'), - ('diagonal', (M, M), (1,), '2d_1'), - ('diagonal', (M, M), (2,), '2d_2'), - ('diagonal', (M, M, M), (1, 1, 2), '3d_1'), - ('diagonal', (M, M, M), (2, 0, 1), '3d_2'), - ('diagonal', (M, M, M), (-2, 0, 1), '3d_3'), ('tril', (M, M), NO_ARGS), ('tril', (M, M), (2,), 'idx'), ('tril', (S, M, M), NO_ARGS, 'batched'), @@ -6542,15 +6735,6 @@ def method_tests(): ('fill_', (S, S, S), (1,), 'number'), ('fill_', (), (1,), 'number_scalar'), ('fill_', (S, S, S), ((),), 'variable'), - ('split', (S, S, S), (2,), '', (True,)), - ('split', (S, S, S), (S, 1), 'dim', (True,), [1]), - ('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list', - (True, 'aten::split_with_sizes')), - ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', - (True, 'aten::split_with_sizes'), [1]), - ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), '', (True,)), - ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3), 0],), 'size_0', (True, )), - ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'dim', (True, ), [1]), ('tensor_split', (S, S, S), (3,), 'sections', (False,)), ('tensor_split', (S, S, S), (3, 1), 'sections_dim', (False,), [1]), ('tensor_split', (S, S, S), ([2, 4],), 'indices', (False,)), @@ -6566,7 +6750,8 @@ def method_tests(): ('where', (), (bernoulli_scalar(), ()), 'scalar', (True,)), ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask', (True,)), ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask', (True,)), - ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()) + ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()), + ('to_sparse', (S, S), (1,), 'dim', (), (), [], lambda x: x.to_dense()) ] def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None): diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index da43c9f43b31c8..ad4a0390d45850 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -3268,6 +3268,17 @@ def fractional_max_pool3d_test(test_case): input_size=(5, 6, 7), reference_fn=lambda x, *_: x * torch.sigmoid(x), ), + dict( + module_name='Mish', + input_size=(), + desc='scalar', + reference_fn=lambda x, *_: x * torch.tanh(F.softplus(x)), + ), + dict( + module_name='Mish', + input_size=(5, 6, 7), + reference_fn=lambda x, *_: x * torch.tanh(F.softplus(x)), + ), dict( constructor=wrap_functional(F.softmax, dim=-1), cpp_options_args='F::SoftmaxFuncOptions(-1)', diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 550f3eb743e765..4714b19f067268 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -688,7 +688,19 @@ def _test_barrier_timeout(self, group_id, timeout): # Only execute barrier on rank == 0, causing it to timeout if local_rank == 0: expected_time = time.time() + timeout.total_seconds() - with self.assertRaisesRegex(Exception, " (Timed out|closed|timeout) "): + # In debug mode, we execute a monitored_barrier before the + # collective, so assert on that. + if dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL: + exception_ctx = self.assertRaisesRegex( + Exception, + "failed to pass monitoredBarrier" + ) + else: + exception_ctx = self.assertRaisesRegex( + Exception, + " (Timed out|closed|timeout) " + ) + with exception_ctx: dist.barrier(group_id) self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1) else: @@ -1733,13 +1745,18 @@ def call_dist_op( if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS: events = get_profiling_event(profiling_title_postfix, autograd_profiler_ctx) - self.assertEqual(len(events), len(op_calls)) + # DETAIL debug mode can use a pg wrapper that issues more collectives + # under the hood + if dist._get_debug_mode() != dist._DistributedDebugLevel.DETAIL: + self.assertEqual(len(events), len(op_calls)) for e in events: self.assertTrue(e.is_async) self.assertEqual(e.count, 1) self.assertGreaterEqual(e.cpu_time, 0) # Verify tensor shapes if given - if tensor_shapes is not None: + # DETAIL debug mode can use a pg wrapper that issues more collectives + # under the hood + if tensor_shapes is not None and dist._get_debug_mode() != dist._DistributedDebugLevel.DETAIL: self.assertEqual(e.input_shapes, tensor_shapes, f"event shape: {e.input_shapes} vs tensor {tensor_shapes}") # ALL REDUCE @@ -3330,8 +3347,8 @@ def test_DistributedDataParallel_non_default_stream(self): ) @unittest.skipIf( - BACKEND != "nccl", - "Only NCCL backend supports DDP communication hook", + BACKEND != "nccl" and BACKEND != "gloo", + "MPI backend does not support DDP communication hook on CUDA devices", ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) @skip_if_rocm @@ -3389,7 +3406,12 @@ def test_ddp_comm_hook_logging(self): loss.backward() ddp_logging_data = ddp_model._get_ddp_logging_data() - self.assertEqual(ddp_logging_data.get("comm_hook"), None) + # Note: DETAIL debug mode logs DDP logging data to stdout and + # thus accesses std::map, which fills in a default value for the + # type if it didn't exist. + self.assertEqual( + ddp_logging_data.get("comm_hook", ""), "" + ) def _test_ddp_hook_parity(self, state, hook): rank = self.rank @@ -3441,8 +3463,8 @@ def _test_ddp_hook_parity(self, state, hook): ) @unittest.skipIf( - BACKEND != "nccl", - "Only NCCL backend supports DDP communication hook", + BACKEND != "nccl" and BACKEND != "gloo", + "MPI backend does not support DDP communication hook on CUDA devices", ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) @skip_if_rocm @@ -3450,8 +3472,8 @@ def test_ddp_hook_parity_allreduce(self): self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook) @unittest.skipIf( - BACKEND != "nccl", - "Only NCCL backend supports DDP communication hook", + BACKEND != "nccl" and BACKEND != "gloo", + "MPI backend does not support DDP communication hook on CUDA devices", ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) @skip_if_rocm @@ -3463,8 +3485,8 @@ def test_ddp_hook_parity_allreduce_process_group(self): self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook) @unittest.skipIf( - BACKEND != "nccl", - "Only NCCL backend supports DDP communication hook", + BACKEND != "nccl" and BACKEND != "gloo", + "MPI backend does not support DDP communication hook on CUDA devices", ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) @skip_if_rocm @@ -4199,7 +4221,12 @@ def parse_env(var): self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None) self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None) # test runtime logging fields - self.assertEqual(ddp_logging_data.get("unused_parameter_size"), None) + # Note: DETAIL debug mode logs DDP logging data to stdout and + # thus accesses std::map, which fills in a default value for the + # type if it didn't exist. + self.assertEqual( + ddp_logging_data.get("unused_parameter_size", 0), 0 + ) self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1) self.assertEqual(ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)) # It is hard to test accurate latency, but it can test whether the latency is @@ -5816,10 +5843,12 @@ def test_ddp_model_diff_across_ranks(self): net = EmbeddingNet(self.rank) # When running with NCCL backend, we don't expect an error on rank 0, # rather, it will be taken down by NCCL_ASYNC_ERROR_HANDLING. When - # running with Gloo, we expect the error to be caught inline. + # running with Gloo or with debug mode wrapper, we expect the error + # to be caught inline. + is_detail_dbg_mode = dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL rank_0_ctx = ( suppress() - if dist.get_backend() == dist.Backend.NCCL + if dist.get_backend() == dist.Backend.NCCL and not is_detail_dbg_mode # Gloo can raise various exception messages, so just assert # Runtime error here. else self.assertRaises(RuntimeError) @@ -6041,9 +6070,13 @@ def test_monitored_barrier_gloo(self): # Kick off some allreduce work on all ranks for _ in range(10): dist.all_reduce(torch.cat(tensors)) - # Run monitored barrier + # Run monitored barrier and ensure it passees timeout = timedelta(seconds=2) dist.monitored_barrier(timeout=timeout) + # Check monitored_barrier success with wait_all_ranks=True + for _ in range(10): + dist.all_reduce(torch.cat(tensors)) + dist.monitored_barrier(timeout=timeout, wait_all_ranks=True) # All ranks besides 1 call into barrier, rank 0 should report failure # while others report gloo error. failed_rank = 1 @@ -6116,7 +6149,15 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks): # practice, we don't need NCCL_BLOCKING_WAIT, but we use it in this # test to ensure it exits cleanly. if self.rank != 0: - with self.assertRaisesRegex(RuntimeError, "Caught collective operation timeout"): + # Can get different errors here depending on whether gloo-based + # wrapper PG is enabled or not, since with wrapper pg, it will + # fail in a collective synchronization check and not actually + # call into the nccl pg. + if dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL: + err_regex = "Timed out waiting" + else: + err_regex = "Caught collective operation timeout" + with self.assertRaisesRegex(RuntimeError, err_regex): nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1)) else: # Rank 0 should report first (in order) timed out rank or all ranks @@ -6450,7 +6491,10 @@ def test_ddp_sync_bn_training_vs_eval(self): # SyncBN allgathers stats across all ranks, so verify call to # all_gather in profiler. - all_gather_calls = get_profiling_event("all_gather", prof) + if BACKEND == 'nccl': + all_gather_calls = get_profiling_event("_all_gather_base", prof) + else: + all_gather_calls = get_profiling_event("all_gather", prof) self.assertNotEqual([], all_gather_calls) # Only do inference on one rank. If SyncBN did collective stats sync, @@ -6466,7 +6510,10 @@ def test_ddp_sync_bn_training_vs_eval(self): loss.backward() # Ensure sync does not occur in eval() mode. - all_gather_calls = get_profiling_event("all_gather", prof) + if BACKEND == 'nccl': + all_gather_calls = get_profiling_event("_all_gather_base", prof) + else: + all_gather_calls = get_profiling_event("all_gather", prof) self.assertEqual([], all_gather_calls) @skip_if_lt_x_gpu(2) diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index b4236f3224c3f7..3b5a441b7a5000 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -554,7 +554,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest): def test_valid_device(self): if self.rank != 0: return - dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = dist_utils.worker_name(dst_rank) for remote_module in self._create_remote_module_iter( "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR] @@ -565,6 +566,16 @@ def test_valid_device(self): self.assertEqual(device.type, "cuda") self.assertEqual(device.index, 0) + # Test rank works as well. + for remote_module in self._create_remote_module_iter( + "rank:{}/cuda:0".format(dst_rank), modes=[ModuleCreationMode.MODULE_CTOR] + ): + device = rpc.rpc_sync( + dst_worker_name, remote_device, (remote_module.module_rref,) + ) + self.assertEqual(device.type, "cuda") + self.assertEqual(device.index, 0) + @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_invalid_devices(self): @@ -614,7 +625,7 @@ def test_invalid_devices(self): ) with self.assertRaisesRegex( - RuntimeError, + ValueError, r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", ): list( @@ -626,8 +637,8 @@ def test_invalid_devices(self): ) with self.assertRaisesRegex( - RuntimeError, - r"The workername in remote_device '/' cannot be empty. The valid format is '/'", + ValueError, + r"Could not parse remote_device: /. The valid format is '/'", ): list( m.forward() for m in @@ -638,8 +649,8 @@ def test_invalid_devices(self): ) with self.assertRaisesRegex( - RuntimeError, - r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '/'", + ValueError, + r"Could not parse remote_device: /cuda:0. The valid format is '/'", ): list( m.forward() for m in diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py index 45740f8bcfb95e..098f2657a1611e 100644 --- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py @@ -1308,7 +1308,10 @@ def test_async_function_simple(self): @dist_init def test_async_function_wrong_return_type(self): - with self.assertRaisesRegex(RuntimeError, "Expected Future but got Tensor"): + with self.assertRaisesRegex( + RuntimeError, + "Async functions must return an IValue of Future type, but got Tensor", + ): rpc.rpc_sync( worker_name((self.rank + 1) % self.world_size), async_wrong_type ) @@ -1368,5 +1371,8 @@ def test_async_function_wrong_return_type_remote(self): worker_name((self.rank + 1) % self.world_size), async_wrong_type ) - with self.assertRaisesRegex(RuntimeError, "Expected Future but got Tensor"): + with self.assertRaisesRegex( + RuntimeError, + "Async functions must return an IValue of Future type, but got Tensor", + ): rref.to_here() diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 825a4c68ee9d7d..59d425de8a202e 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -4751,6 +4751,10 @@ def _test_rref_get_type_timeout(self, blocking): with self.assertRaisesRegex(RuntimeError, expected_err): fut.wait() + # FIXME We wait until the remote completed creating the OwnerRRef + # because there's currently a race if we shut down RPC before that. + slow_rref.to_here() + def test_rref_get_type_timeout_blocking(self): self._test_rref_get_type_timeout(blocking=True) @@ -4794,6 +4798,10 @@ def _test_rref_proxy_timeout(self, rref_proxy_api): with self.assertRaisesRegex(RuntimeError, expected_error): rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2)) + # FIXME We wait until the remote completed creating the OwnerRRef + # because there's currently a race if we shut down RPC before that. + slow_rref.to_here() + @dist_init def test_rref_proxy_timeout(self): for rpc_api in ["rpc_sync", "rpc_async", "remote"]: diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 50e5ceca4a6d0b..f93f32a9b2321c 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -357,7 +357,8 @@ def traced_fn(*inputs, **kwargs): 'test_nn_fold', # jit doesn't support sparse tensors. - 'test_to_sparse' + 'test_to_sparse', + 'test_to_sparse_dim', } # generates a script function and set of example inputs diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index e0e6e906a35d71..428196eac8b296 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -259,7 +259,7 @@ def timeit(self, number: int = 1000000) -> common.Measurement: """ with common.set_torch_threads(self._task_spec.num_threads): # Warmup - self._timer.timeit(number=max(int(number // 100), 1)) + self._timer.timeit(number=max(int(number // 100), 2)) return common.Measurement( number_per_run=number, @@ -443,7 +443,7 @@ def collect_callgrind( jitter from the Python interpreter.) This makes them ideal for detailed performance analysis. This method runs `stmt` in a separate process so that Valgrind can instrument the program. Performance is severely - degraded due to the instrumentation, howevever this is ameliorated by + degraded due to the instrumentation, however this is ameliorated by the fact that a small number of iterations is generally sufficient to obtain good measurements. diff --git a/torch/utils/data/_typing.py b/torch/utils/data/_typing.py index 27be0884771ebc..7f8f51eb22e7b8 100644 --- a/torch/utils/data/_typing.py +++ b/torch/utils/data/_typing.py @@ -4,7 +4,7 @@ import collections import numbers import sys -from typing import (Any, Dict, Iterator, List, Set, Tuple, TypeVar, Union, +from typing import (Any, Dict, Iterator, Generic, List, Set, Tuple, TypeVar, Union, get_type_hints) from typing import _eval_type, _tp_cache, _type_check, _type_repr # type: ignore[attr-defined] @@ -17,11 +17,12 @@ # Please check [Note: TypeMeta and TypeAlias] try: from typing import GenericMeta # Python 3.6 + _GenericAlias = GenericMeta except ImportError: # Python > 3.6 # In case of metaclass conflict due to ABCMeta or _ProtocolMeta # For Python 3.9, only Protocol in typing uses metaclass from abc import ABCMeta - from typing import _ProtocolMeta # type: ignore[attr-defined] + from typing import _ProtocolMeta, _GenericAlias # type: ignore[attr-defined, no-redef] class GenericMeta(_ProtocolMeta, ABCMeta): # type: ignore[no-redef] pass @@ -65,6 +66,10 @@ def issubtype(left, right, recursive=True): if right is Any or left == right: return True + if isinstance(right, _GenericAlias): + if getattr(right, '__origin__', None) is Generic: + return True + if right == type(None): return False @@ -220,13 +225,16 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, _DataPipeType): - return self.issubtype(other) and other.issubtype(self) + return self.param == other.param return NotImplemented def __hash__(self): return hash(self.param) def issubtype(self, other): + if isinstance(other.param, _GenericAlias): + if getattr(other.param, '__origin__', None) is Generic: + return True if isinstance(other, _DataPipeType): return issubtype(self.param, other.param) if isinstance(other, type): @@ -238,7 +246,8 @@ def issubtype_of_instance(self, other): # Default type for DataPipe without annotation -_DEFAULT_TYPE = _DataPipeType(Any) +T_co = TypeVar('T_co', covariant=True) +_DEFAULT_TYPE = _DataPipeType(Generic[T_co]) class _DataPipeMeta(GenericMeta): @@ -248,44 +257,66 @@ class _DataPipeMeta(GenericMeta): """ type: _DataPipeType - def __new__(cls, name, bases, namespace, **kargs): + def __new__(cls, name, bases, namespace, **kwargs): # For Python > 3.6 cls.__origin__ = None # Need to add _is_protocol for Python 3.7 _ProtocolMeta if '_is_protocol' not in namespace: namespace['_is_protocol'] = True if 'type' in namespace: - return super().__new__(cls, name, bases, namespace) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] namespace['__type_class__'] = False # For plain derived class without annotation for base in bases: if isinstance(base, _DataPipeMeta): - return super().__new__(cls, name, bases, namespace) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] namespace.update({'type': _DEFAULT_TYPE, '__init_subclass__': _dp_init_subclass}) - return super().__new__(cls, name, bases, namespace) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + def __init__(self, name, bases, namespace, **kwargs): + super().__init__(name, bases, namespace, **kwargs) # type: ignore[call-overload] @_tp_cache - def __getitem__(self, param): - if param is None: + def __getitem__(self, params): + if params is None: raise TypeError('{}[t]: t can not be None'.format(self.__name__)) - if isinstance(param, str): - param = ForwardRef(param) - if isinstance(param, tuple): - param = Tuple[param] - _type_check(param, msg="{}[t]: t must be a type".format(self.__name__)) - t = _DataPipeType(param) + if isinstance(params, str): + params = ForwardRef(params) + if not isinstance(params, tuple): + params = (params, ) + + msg = "{}[t]: t must be a type".format(self.__name__) + params = tuple(_type_check(p, msg) for p in params) + + if isinstance(self.type.param, _GenericAlias): + orig = getattr(self.type.param, '__origin__', None) + if isinstance(orig, type) and orig is not Generic: + p = self.type.param[params] # type: ignore[index] + t = _DataPipeType(p) + l = len(str(self.type)) + 2 + name = self.__name__[:-l] + name = name + '[' + str(t) + ']' + bases = (self,) + self.__bases__ + return self.__class__(name, bases, + {'__init_subclass__': _dp_init_subclass, + 'type': t, + '__type_class__': True}) + + if len(params) > 1: + raise TypeError('Too many parameters for {} actual {}, expected 1'.format(self, len(params))) + + t = _DataPipeType(params[0]) if not t.issubtype(self.type): raise TypeError('Can not subclass a DataPipe[{}] from DataPipe[{}]' .format(t, self.type)) # Types are equal, fast path for inheritance - if self.type.issubtype(t): - if _mro_subclass_init(self): - return self + if self.type == t: + return self name = self.__name__ + '[' + str(t) + ']' bases = (self,) + self.__bases__ @@ -307,24 +338,6 @@ def __hash__(self): return hash((self.__name__, self.type)) -def _mro_subclass_init(obj): - r""" - Run through MRO to check if any super class has already built in - the corresponding `__init_subclass__`. If so, no need to add - `__init_subclass__`. - """ - - mro = obj.__mro__ - for b in mro: - if isinstance(b, _DataPipeMeta): - if b.__init_subclass__ == _dp_init_subclass: - return True - if hasattr(b.__init_subclass__, '__func__') and \ - b.__init_subclass__.__func__ == _dp_init_subclass: # type: ignore[attr-defined] - return True - return False - - def _dp_init_subclass(sub_cls, *args, **kwargs): # TODO: # - add global switch for type checking at compile-time