From c8679b8f370067e8ea534adde25db263eac8984f Mon Sep 17 00:00:00 2001
From: wangjiangben-hw
Date: Tue, 11 Oct 2022 15:53:36 +0800
Subject: [PATCH 01/17] init npu

---
 mmengine/device/__init__.py                       |  4 ++--
 mmengine/device/utils.py                          | 15 +++++++++++++--
 mmengine/dist/dist.py                             | 11 ++++++++---
 mmengine/dist/utils.py                            | 15 +++++++++++++--
 mmengine/optim/optimizer/amp_optimizer_wrapper.py |  9 +++++++--
 mmengine/optim/optimizer/builder.py               |  5 +++++
 mmengine/runner/amp.py                            |  7 +++++--
 tests/test_device/test_device.py                  |  6 ++++--
 8 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py
index 0524e6b937..c6b9d0afdf 100644
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available)
+                    is_mlu_available, is_mps_available, is_npu_available)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available'
 ]
diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index c1819cbaae..1a31af549b 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -32,6 +32,15 @@ def is_cuda_available() -> bool:
     return torch.cuda.is_available()
 
 
+def is_npu_available() -> bool:
+    """Returns True if Ascend PyTorch and npu devices exist."""
+    try:
+        import torch_npu  # noqa: F401
+    except Exception:
+        return False
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
 def is_mlu_available() -> bool:
     """Returns True if Cambricon PyTorch and mlu devices exist."""
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
@@ -49,9 +58,11 @@ def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
-        str: cuda | mlu | mps | cpu.
+        str: cuda | npu | mlu | mps | cpu.
     """
-    if is_cuda_available():
+    if is_npu_available():
+        return 'npu'
+    elif is_cuda_available():
         return 'cuda'
     elif is_mlu_available():
         return 'mlu'
diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index bc63dd2726..af1e14aa49 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -20,6 +20,7 @@
                     get_comm_device, cast_data_device)
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.device import is_npu_available
 
 
 def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
@@ -411,7 +412,10 @@ def _broadcast_object_list(object_list: List[Any],
     group_backend = get_backend(group)
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
-    if is_nccl_backend:
+    if is_npu_available():
+        current_device = torch.npu.current_device()
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
         # not necessarily true.
@@ -430,7 +434,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )
 
-    if is_nccl_backend:
+    if is_nccl_backend or is_npu_available():
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
@@ -504,7 +508,8 @@ def broadcast_object_list(data: List[Any],
 
     if group is None:
         group = get_default_group()
-    if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+    if digit_version(TORCH_VERSION) >= digit_version(
+            '1.8.0') and not is_npu_available():
         torch_dist.broadcast_object_list(data, src, group)
     else:
         _broadcast_object_list(data, src, group)
diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index 285c0f371b..f5ec3cf51d 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -10,7 +10,7 @@
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
 
-from mmengine.device import is_mlu_available
+from mmengine.device import is_mlu_available, is_npu_available
 
 from collections.abc import Iterable, Mapping
@@ -80,6 +80,14 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
+    elif is_npu_available():
+        import torch_npu  # noqa: F401
+        torch.npu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='hccl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
     else:
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
@@ -437,7 +445,10 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
         torch.device: The device of backend.
     """
     backend = get_backend(group)
-    if backend == torch_dist.Backend.NCCL:
+    if backend == 'hccl':
+        import torch_npu  # noqa: F401
+        return torch.device('npu', torch.npu.current_device())
+    elif backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
     elif backend == 'cncl':
         import torch_mlu  # noqa: F401
diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 730c28344d..db135799ce 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -3,13 +3,18 @@
 
 import torch
 import torch.nn as nn
-from torch.cuda.amp import GradScaler
 
+from mmengine.device import is_npu_available
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 from .optimizer_wrapper import OptimWrapper
 
+if is_npu_available():
+    from torch.npu.amp import GradScaler
+else:
+    from torch.cuda.amp import GradScaler
+
 
 @OPTIM_WRAPPERS.register_module()
 class AmpOptimWrapper(OptimWrapper):
@@ -44,7 +49,7 @@ class AmpOptimWrapper(OptimWrapper):
     def __init__(self, loss_scale='dynamic', **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert torch.cuda.is_available(), (
+        assert torch.cuda.is_available() or is_npu_available(), (
             '``AmpOptimizerWrapper`` is only available training on gpu')
         super().__init__(**kwargs)
         self._scale_update_param = None
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 01f26a990d..898161ee01 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 
 from mmengine.config import Config, ConfigDict
+from mmengine.device import is_npu_available
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from .optimizer_wrapper import OptimWrapper
 
@@ -53,6 +54,10 @@ def build_optim_wrapper(model: nn.Module,
     constructor_type = optim_wrapper_cfg.pop('constructor',
                                              'DefaultOptimWrapperConstructor')
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
+
+    if is_npu_available():
+        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
+
     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
         dict(
             type=constructor_type,
diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index f09437b1e4..d8f2aca638 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from mmengine.device import get_device
+from mmengine.device import get_device, is_npu_available
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -86,7 +86,10 @@ def autocast(device_type: Optional[str] = None,
             logger='current',
             level=logging.WARNING)
 
-    if torch.cuda.is_available():
+    if is_npu_available():
+        with torch.npu.amp.autocast(enabled=enabled):
+            yield
+    elif torch.cuda.is_available():
         with torch.cuda.amp.autocast(enabled=enabled):
             yield
     else:
diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py
index 1c41721bf4..19bd1f7f19 100644
--- a/tests/test_device/test_device.py
+++ b/tests/test_device/test_device.py
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
-                             is_mps_available)
+                             is_mps_available, is_npu_available)
 
 
 def test_get_device():
     device = get_device()
-    if is_cuda_available():
+    if is_npu_available():
+        assert device == 'npu'
+    elif is_cuda_available():
         assert device == 'cuda'
     elif is_mlu_available():
         assert device == 'mlu'
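Patch 01 above establishes the pattern the rest of the series builds on: probe for the vendor extension with a guarded import, then resolve the device type in a fixed priority order (npu > cuda > mlu > mps > cpu). A minimal sketch of device-agnostic user code on top of these helpers — the model and tensors are placeholders, the snippet assumes mmengine with this patch applied, and on a host without torch_npu the probe simply returns False:

    import torch

    from mmengine.device import get_device, is_npu_available

    # get_device() resolves to the first available backend in priority
    # order, so the same script runs unchanged on NPU, GPU, or CPU hosts.
    device = get_device()
    model = torch.nn.Linear(8, 2).to(device)
    outputs = model(torch.randn(4, 8).to(device))

    # The guarded `import torch_npu` inside is_npu_available() makes the
    # probe safe on machines without the Ascend extension installed.
    print(device, is_npu_available())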
From 0affc3930487cbc51087cd744c27608b5b6e0de5 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Wed, 12 Oct 2022 16:24:45 +0800
Subject: [PATCH 02/17] Update mmengine/optim/optimizer/amp_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/optim/optimizer/amp_optimizer_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index db135799ce..acb1be6d05 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 
-from mmengine.device import is_npu_available
+from mmengine.device import is_npu_available, is_cuda_available
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

From 2d574011d54f5c20274a2b572235670a039e935a Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Wed, 12 Oct 2022 16:24:52 +0800
Subject: [PATCH 03/17] Update mmengine/dist/dist.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/dist/dist.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index af1e14aa49..95945987c7 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -434,7 +434,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )
 
-    if is_nccl_backend or is_npu_available():
+    if is_nccl_backend or is_hccl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
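Note that patch 03 already references is_hccl_backend, which is only defined by patch 04 below, so the tree is consistent only once both are applied. The rule the two patches converge on — derive the communication device from the process group's backend rather than from bare availability checks — can be sketched standalone. This is a hypothetical mirror of get_comm_device(), not the patched function itself, and the npu/mlu branches assume the torch_npu / torch_mlu extensions are importable:

    import torch
    from torch import distributed as torch_dist

    def comm_device_for(backend: str) -> torch.device:
        # Each distributed backend pins communication tensors to its own
        # device type; gloo and mpi stay on the CPU.
        if backend == 'hccl':  # Ascend NPU, requires torch_npu
            return torch.device('npu', torch.npu.current_device())
        elif backend == torch_dist.Backend.NCCL:  # NVIDIA GPU
            return torch.device('cuda', torch.cuda.current_device())
        elif backend == 'cncl':  # Cambricon MLU, requires torch_mlu
            return torch.device('mlu', torch.mlu.current_device())
        return torch.device('cpu')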
From 296e331b7b41b97cae34977df88163deab6e05ea Mon Sep 17 00:00:00 2001
From: wangjiangben-hw
Date: Wed, 12 Oct 2022 16:30:34 +0800
Subject: [PATCH 04/17] change to is_hccl_backend

---
 mmengine/dist/dist.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index 95945987c7..3e320e3848 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -412,7 +412,8 @@ def _broadcast_object_list(object_list: List[Any],
     group_backend = get_backend(group)
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
-    if is_npu_available():
+    is_hccl_backend = group_backend == 'hccl'
+    if is_hccl_backend:
         current_device = torch.npu.current_device()
         object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_nccl_backend:

From f78c0ddc5e4ea83024a9ff03c7441160b18832d0 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Wed, 12 Oct 2022 16:32:32 +0800
Subject: [PATCH 05/17] Update mmengine/optim/optimizer/amp_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/optim/optimizer/amp_optimizer_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index acb1be6d05..ceb78971e9 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -49,7 +49,7 @@ class AmpOptimWrapper(OptimWrapper):
     def __init__(self, loss_scale='dynamic', **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert torch.cuda.is_available(), (
+        assert is_cuda_available() or is_npu_available(), (
             '``AmpOptimizerWrapper`` is only available training on gpu')
         super().__init__(**kwargs)
         self._scale_update_param = None

From 10d2655cb1c9be0c0c44005bd30c5d7d92608832 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw
Date: Wed, 12 Oct 2022 19:19:11 +0800
Subject: [PATCH 06/17] add comment with AmpOptimWrapper

---
 mmengine/optim/optimizer/builder.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 898161ee01..8eae0e5429 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -55,6 +55,9 @@ def build_optim_wrapper(model: nn.Module,
                                              'DefaultOptimWrapperConstructor')
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
 
+    # Since the current generation of NPU(Ascend 910) only supports
+    # mixed precision training, here we turn on mixed precision by default
+    # on the NPU to make the training normal
     if is_npu_available():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
 
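The comment added in patch 06 carries the key design decision of the series: the Ascend 910 only supports mixed-precision training, so the builder upgrades whatever wrapper the user configured. The effect on a config, sketched with hypothetical optimizer settings:

    from mmengine.device import is_npu_available

    # A typical user-side optimizer wrapper config.
    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9))

    # What build_optim_wrapper() now does before constructing the wrapper:
    # on an NPU host the type is forced to AmpOptimWrapper, which in turn
    # picks up torch.npu.amp.GradScaler via the conditional import from
    # patch 01.
    if is_npu_available():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'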
From 980dbd722fe2c5051f5390b4b9a0af6ce90cc6b3 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Wed, 12 Oct 2022 19:22:55 +0800
Subject: [PATCH 07/17] Update mmengine/runner/amp.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/runner/amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index d8f2aca638..2ccc30734c 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from mmengine.device import get_device, is_npu_available
+from mmengine.device import get_device, is_npu_available, is_cuda_available
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

From 9787ee22aec3f102534ea170efb498973c6cfbec Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Wed, 12 Oct 2022 19:23:10 +0800
Subject: [PATCH 08/17] Update mmengine/runner/amp.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/runner/amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 2ccc30734c..2e0837b204 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -89,7 +89,7 @@ def autocast(device_type: Optional[str] = None,
     if is_npu_available():
         with torch.npu.amp.autocast(enabled=enabled):
             yield
-    elif torch.cuda.is_available():
+    elif is_cuda_available():
         with torch.cuda.amp.autocast(enabled=enabled):
             yield
     else:

From 2a9d98399009e68335003c6073cf6d1733a8d54f Mon Sep 17 00:00:00 2001
From: wangjiangben-hw
Date: Wed, 12 Oct 2022 20:32:16 +0800
Subject: [PATCH 09/17] add npu fn in base_model

---
 mmengine/model/base_model/base_model.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index c2f9a5b9ce..8899d28dfe 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -210,6 +210,25 @@ def cuda(
         self._set_device(torch.device(device))
         return super().cuda(device)
 
+    def npu(
+        self,
+        device: Optional[Union[int, str, torch.device]] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.npu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+
+        Note:
+            This generation of NPU(Ascend910) does not support
+            the use of multiple cards in a single process,
+            so the index here needs to be consistent with the default device
+        """
+        device = torch.npu.current_device()
+        self._set_device(device)
+        return super().npu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
         additionally.
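With patches 07–09 applied, both the autocast context manager and BaseModel are NPU-aware, so training code needs no device-specific branches. A usage sketch — the model is a placeholder, and running it on an Ascend host assumes torch_npu is installed (torch_npu is also what gives nn.Module the npu() method that super().npu() resolves to):

    import torch
    import torch.nn as nn

    from mmengine.runner import autocast

    model = nn.Linear(8, 2)

    # On an NPU host this dispatches to torch.npu.amp.autocast, on a CUDA
    # host to torch.cuda.amp.autocast; other devices fall through to the
    # unchanged else branch of the wrapper.
    with autocast(enabled=True):
        out = model(torch.randn(4, 8))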
From dfd83bfe4683514f2b23c5284ca5832300757501 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Thu, 13 Oct 2022 14:54:43 +0800
Subject: [PATCH 10/17] Update mmengine/optim/optimizer/amp_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/optim/optimizer/amp_optimizer_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index ceb78971e9..65e29dbe7e 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -50,7 +50,7 @@ def __init__(self, loss_scale='dynamic', **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
         assert is_cuda_available() or is_npu_available(), (
-            '``AmpOptimizerWrapper`` is only available training on gpu')
+            '``AmpOptimizerWrapper`` is only available training on gpu or npu')
         super().__init__(**kwargs)
         self._scale_update_param = None
         if loss_scale == 'dynamic':

From 9207e0cef1fdfb3b7e720e53d07f12b7b7b26385 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw
Date: Thu, 13 Oct 2022 15:06:49 +0800
Subject: [PATCH 11/17] clean lint

---
 mmengine/runner/amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 2e0837b204..8f99890944 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from mmengine.device import get_device, is_npu_available, is_cuda_available
+from mmengine.device import get_device, is_cuda_available, is_npu_available
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

From 5b23b3074bf6474ba456d9dc5a41e5b36b40e9af Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Thu, 13 Oct 2022 15:44:48 +0800
Subject: [PATCH 12/17] Update mmengine/optim/optimizer/amp_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/optim/optimizer/amp_optimizer_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 65e29dbe7e..66f1255643 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 
-from mmengine.device import is_npu_available, is_cuda_available
+from mmengine.device import is_cuda_available, is_npu_available
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

From 2138438f355c2eb7ea3b52d923aac93e286bd512 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Mon, 17 Oct 2022 10:47:43 +0800
Subject: [PATCH 13/17] Update mmengine/model/base_model/base_model.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
---
 mmengine/model/base_model/base_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index 8899d28dfe..525a2f0664 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -212,7 +212,7 @@ def cuda(
 
     def npu(
         self,
-        device: Optional[Union[int, str, torch.device]] = None,
+        device: Union[int, str, torch.device, None] = None,
     ) -> nn.Module:
         """Overrides this method to call :meth:`BaseDataPreprocessor.npu`
         additionally.
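Patch 13 is purely stylistic at the type level: Optional[X] is defined as Union[X, None], so the two annotations denote the same type, and spelling out None merely avoids suggesting that device has a meaningful default. A quick check of the equivalence:

    from typing import Optional, Union

    import torch

    # Unions flatten and compare by their argument sets, so the spellings
    # from patches 09 and 13 are the same annotation.
    assert (Optional[Union[int, str, torch.device]] ==
            Union[int, str, torch.device, None])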
From 64d6d39e3e40c94552772af77dac0b400acdd620 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw
Date: Mon, 17 Oct 2022 10:58:35 +0800
Subject: [PATCH 14/17] add is_npu_available

---
 docs/en/api/device.rst    | 1 +
 docs/zh_cn/api/device.rst | 1 +
 2 files changed, 2 insertions(+)

diff --git a/docs/en/api/device.rst b/docs/en/api/device.rst
index a694f823fc..4a16c73837 100644
--- a/docs/en/api/device.rst
+++ b/docs/en/api/device.rst
@@ -13,5 +13,6 @@ mmengine.device
    get_device
    get_max_cuda_memory
    is_cuda_available
+   is_npu_available
    is_mlu_available
    is_mps_available
diff --git a/docs/zh_cn/api/device.rst b/docs/zh_cn/api/device.rst
index a694f823fc..4a16c73837 100644
--- a/docs/zh_cn/api/device.rst
+++ b/docs/zh_cn/api/device.rst
@@ -13,5 +13,6 @@ mmengine.device
    get_device
    get_max_cuda_memory
    is_cuda_available
+   is_npu_available
    is_mlu_available
    is_mps_available

From de6826236a6b0701f8d570d0a1b8edd361c6f514 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Mon, 17 Oct 2022 17:49:16 +0800
Subject: [PATCH 15/17] try to fix

---
 .github/workflows/merge_stage_test.yml | 2 +-
 .github/workflows/pr_stage_test.yml    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml
index 9957e2f9b1..b4cf7c0f8d 100644
--- a/.github/workflows/merge_stage_test.yml
+++ b/.github/workflows/merge_stage_test.yml
@@ -159,7 +159,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Upgrade pip
-        run: pip install pip --upgrade
+        run: python -m pip install pip --upgrade
       - name: Install PyTorch
         run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
       - name: Build MMEngine from source
diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml
index a747df2234..e89cca1c25 100644
--- a/.github/workflows/pr_stage_test.yml
+++ b/.github/workflows/pr_stage_test.yml
@@ -108,7 +108,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Upgrade pip
-        run: pip install pip --upgrade
+        run: python -m pip install pip --upgrade
       - name: Install PyTorch
         run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
       - name: Build MMEngine from source

From 40ebd52546295ff4a48138e2123812f4f45c775a Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Tue, 18 Oct 2022 00:43:36 +0800
Subject: [PATCH 16/17] Add comments

---
 .github/workflows/merge_stage_test.yml | 1 +
 .github/workflows/pr_stage_test.yml    | 1 +
 2 files changed, 2 insertions(+)

diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml
index b4cf7c0f8d..71411b2a6d 100644
--- a/.github/workflows/merge_stage_test.yml
+++ b/.github/workflows/merge_stage_test.yml
@@ -159,6 +159,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Upgrade pip
+        # Windows CI could fail directly call `pip install pip --upgrade`
         run: python -m pip install pip --upgrade
       - name: Install PyTorch
         run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml
index e89cca1c25..1557f41590 100644
--- a/.github/workflows/pr_stage_test.yml
+++ b/.github/workflows/pr_stage_test.yml
@@ -108,6 +108,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Upgrade pip
+        # Windows CI could fail directly call `pip install pip --upgrade
        run: python -m pip install pip --upgrade
       - name: Install PyTorch
         run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
From 915d00c6813ca7009cbba856aa6fd226a46196a7 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Tue, 18 Oct 2022 00:45:39 +0800
Subject: [PATCH 17/17] Refine grammar

---
 .github/workflows/merge_stage_test.yml | 2 +-
 .github/workflows/pr_stage_test.yml    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml
index 71411b2a6d..ef62b1a738 100644
--- a/.github/workflows/merge_stage_test.yml
+++ b/.github/workflows/merge_stage_test.yml
@@ -159,7 +159,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Upgrade pip
-        # Windows CI could fail directly call `pip install pip --upgrade`
+        # Windows CI could fail If we call `pip install pip --upgrade` directly.
         run: python -m pip install pip --upgrade
       - name: Install PyTorch
         run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml
index 1557f41590..332d206840 100644
--- a/.github/workflows/pr_stage_test.yml
+++ b/.github/workflows/pr_stage_test.yml
@@ -108,7 +108,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Upgrade pip
-        # Windows CI could fail directly call `pip install pip --upgrade
+        # Windows CI could fail If we call `pip install pip --upgrade` directly.
         run: python -m pip install pip --upgrade
       - name: Install PyTorch
         run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html