fix error and add dtype (mindspore-lab#564)
WarruzuEndo authored Jun 7, 2023
1 parent b4c1cec commit 6f8782f
Showing 2 changed files with 29 additions and 13 deletions.
21 changes: 13 additions & 8 deletions mindnlp/parallel/layers.py
@@ -17,6 +17,7 @@

from typing import Optional, Union

+ import mindspore
from mindspore import nn, ops
from mindspore import Parameter, Tensor
from mindspore.common.initializer import Initializer, initializer
@@ -50,6 +51,7 @@ def __init__(
embedding_size: int,
padding_idx: Optional[int] = None,
init_method: Union[str, Initializer] = "normal",
+ dtype: mindspore.dtype = mindspore.float32,
) -> None:
super().__init__()
# Keep the input dimensions.
@@ -68,7 +70,7 @@ def __init__(
# Allocate weights.
self.weight = Parameter(
initializer(
- init_method, (self.vocab_size_per_partition, self.embedding_size)
+ init_method, (self.vocab_size_per_partition, self.embedding_size), dtype
),
"weight",
)
@@ -84,7 +86,7 @@ def construct(self, input_: Tensor) -> Tensor: # type: ignore
# Get the embeddings.
ori_shape = masked_input.shape
output_parallel = ops.gather(self.weight, masked_input.view(-1), 0).view(
- ori_shape + (self.embedding_size)
+ ori_shape + (self.embedding_size, )
)
# Mask the output embedding.
output_parallel = ops.masked_fill(
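A note on the view fix above: `(self.embedding_size)` is just a parenthesized integer, so adding it to the tuple `ori_shape` raises a TypeError, while the trailing comma makes it a one-element tuple and the concatenation yields the intended output shape. A minimal, MindSpore-free sketch with made-up shapes:

```python
ori_shape = (4, 16)   # hypothetical (batch, seq_len) shape of the input ids
embedding_size = 128

# ori_shape + (embedding_size)  -> TypeError: can only concatenate tuple to tuple
# ori_shape + (embedding_size,) -> tuple concatenation, the intended behavior
target_shape = ori_shape + (embedding_size,)
print(target_shape)   # (4, 16, 128)
```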
@@ -112,6 +114,7 @@ def __init__(
embedding_size: int,
padding_idx: Optional[int] = None,
init_method: Union[str, Initializer] = "normal",
+ dtype: mindspore.dtype = mindspore.float32,
) -> None:
super().__init__()
# Keep the input dimensions.
@@ -127,7 +130,7 @@
# Allocate weights.
self.weight = Parameter(
initializer(
- init_method, (self.vocab_size, self.embedding_size_per_partition)
+ init_method, (self.vocab_size, self.embedding_size_per_partition), dtype
),
"weight",
)
@@ -136,7 +139,7 @@ def construct(self, input_: Tensor) -> Tensor: # type: ignore
input_parallel = copy_to_model_parallel_region(input_)
ori_shape = input_parallel.shape
output_parallel = ops.gather(self.weight, input_parallel.view(-1), 0).view(
- ori_shape + (self.embedding_size)
+ ori_shape + (self.embedding_size_per_partition, )
)
output = gather_from_model_parallel_region(output_parallel)
return output
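The second `construct` fix also corrects which attribute is used: this class shards the embedding dimension, so the local gather produces only `embedding_size_per_partition` columns, and `gather_from_model_parallel_region` concatenates the shards back to the full width. Rough shape bookkeeping, assuming a model-parallel group of 4 purely for illustration:

```python
embedding_size = 128
rank_size = 4                                                # assumed model-parallel group size
embedding_size_per_partition = embedding_size // rank_size   # 32 columns held by each rank

ori_shape = (4, 16)                                          # hypothetical input-id shape
local_shape = ori_shape + (embedding_size_per_partition,)    # (4, 16, 32) on every rank
full_shape = ori_shape + (embedding_size,)                   # (4, 16, 128) after the all-gather
print(local_shape, full_shape)
```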
@@ -170,6 +173,7 @@ def __init__(
bias: bool = True,
gather_output: bool = True,
init_method: Union[str, Initializer] = "normal",
+ dtype: mindspore.dtype = mindspore.float32,
stride: int = 1,
keep_master_weight_for_test: bool = False,
) -> None:
@@ -188,14 +192,14 @@ def __init__(
# Parameters.
self.weight = Parameter(
initializer(
- init_method, (self.in_features, self.output_size_per_partition)
+ init_method, (self.in_features, self.output_size_per_partition), dtype
),
"weight",
)
if bias:
# Always initialize bias to zero.
self.bias = Parameter(
initializer("zeros", (self.output_size_per_partition,)), "bias"
initializer("zeros", (self.output_size_per_partition,), dtype), "bias"
)
else:
self.bias = None
@@ -253,6 +257,7 @@ def __init__(
bias: bool = True,
input_is_parallel: bool = False,
init_method: Union[str, Initializer] = "normal",
+ dtype: mindspore.dtype = mindspore.float32,
stride: int = 1,
keep_master_weight_for_test: bool = False,
):
@@ -272,13 +277,13 @@
# we allocate the transpose.
self.weight = Parameter(
initializer(
- init_method, (self.input_size_per_partition, self.out_features)
+ init_method, (self.input_size_per_partition, self.out_features), dtype
),
"weight",
)
if bias:
# Always initialize bias to zero.
- self.bias = Parameter(initializer("zeros", (self.out_features,)), "bias")
+ self.bias = Parameter(initializer("zeros", (self.out_features,), dtype), "bias")
else:
self.bias = None

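The recurring change in layers.py is forwarding the new `dtype` argument to `initializer(...)`, so weights and biases are allocated in the requested dtype instead of always defaulting to float32. A minimal sketch of the same pattern outside the parallel layers (shapes and init strings are arbitrary):

```python
import mindspore
from mindspore import Parameter
from mindspore.common.initializer import initializer

dtype = mindspore.float16   # e.g. build the weights directly in half precision

# Same pattern as the edits above: pass dtype as the third argument to initializer().
weight = Parameter(initializer("normal", (1024, 4096), dtype), "weight")
bias = Parameter(initializer("zeros", (4096,), dtype), "bias")
print(weight.dtype, bias.dtype)   # Float16 Float16
```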
21 changes: 16 additions & 5 deletions mindnlp/parallel/mappings.py
@@ -20,15 +20,26 @@
from mindspore import ops

from mindspore.communication import get_rank, get_group_size
+ from mindspore.ops import constexpr

from mindnlp._legacy.ops import AllGather
from .utils import divide_and_check_no_remainder, split_tensor_along_last_dim


+ @constexpr
+ def _get_rank():
+     return get_rank()
+
+
+ @constexpr
+ def _get_group_size():
+     return get_group_size()


def _reduce(input_: mindspore.Tensor) -> mindspore.Tensor:
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
- if get_group_size() == 1:
+ if _get_group_size() == 1:
return input_

# All-reduce.
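The new `_get_rank` / `_get_group_size` wrappers are decorated with `@constexpr`, so they run while the graph is being compiled and their return values are folded into the graph as constants; that lets the mapping functions below be used inside graph-mode `construct` methods, where calling the communication helpers directly would otherwise fail (presumably the error the commit title refers to). A self-contained sketch of the mechanism, using a toy helper because the real ones need an initialized communication group:

```python
import numpy as np
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import constexpr

@constexpr
def _scale(dim: int) -> float:
    # Runs during graph compilation; the returned value is embedded as a constant.
    return 1.0 / dim

class ScaleByWidth(nn.Cell):
    def construct(self, x: Tensor) -> Tensor:
        # Legal in graph mode because _scale is @constexpr and x.shape[-1] is static.
        return x * _scale(x.shape[-1])

out = ScaleByWidth()(Tensor(np.ones((2, 4), np.float32)))
print(out)   # every element is 0.25
```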
@@ -41,22 +52,22 @@ def _split(input_: mindspore.Tensor) -> mindspore.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""
# Bypass the function if we are using only 1 GPU.
- if get_group_size() == 1:
+ if _get_group_size() == 1:
return input_

# Split along last dimension.
- rank_size = get_group_size()
+ rank_size = _get_group_size()
input_list = split_tensor_along_last_dim(input_, rank_size)

- rank = get_rank()
+ rank = _get_rank()
output = input_list[rank]

return output

def _gather(input_: mindspore.Tensor) -> mindspore.Tensor:
"""Gather tensors and concatinate along the last dimension."""
# Bypass the function if we are using only 1 GPU.
- rank_size = get_group_size()
+ rank_size = _get_group_size()
if rank_size == 1:
return input_

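For reference, `_split` keeps only the current rank's slice of the last dimension and `_gather` is its inverse. A single-process NumPy sketch of the data movement (rank and group size are hard-coded here; in the real code they come from the `@constexpr` helpers and the initialized communication group):

```python
import numpy as np

rank, rank_size = 1, 4                      # assumed values for illustration
x = np.arange(2 * 8, dtype=np.float32).reshape(2, 8)

# _split: divide the last dimension into rank_size chunks and keep our own slice.
chunks = np.split(x, rank_size, axis=-1)
local = chunks[rank]
print(local.shape)                          # (2, 2)

# _gather: every rank contributes its slice; concatenation restores the full tensor.
full = np.concatenate(chunks, axis=-1)
print(np.array_equal(full, x))              # True
```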
