Update transformers ops unit tests to use required_torch_version (#…
loadams authored Dec 17, 2024
1 parent a964e43 commit 2f32966
Showing 6 changed files with 5 additions and 18 deletions.
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_geglu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_geglu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -16,8 +16,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_gelu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally and using the tanh approximation
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_relu.py
@@ -15,8 +15,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def run_bias_relu_reference(activations, bias):
     # Expected behavior is that of casting to float32 internally
14 changes: 5 additions & 9 deletions tests/unit/ops/transformer/inference/test_gelu.py
@@ -9,12 +9,11 @@
 from deepspeed.ops.op_builder import InferenceBuilder
 from deepspeed.ops.transformer import DeepSpeedInferenceConfig
 from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
+from deepspeed.utils.torch import required_torch_version
 
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
@@ -23,14 +22,11 @@ def allclose(x, y):
 
 
 def version_appropriate_gelu(activations):
-    global torch_minor_version
-    if torch_minor_version is None:
-        torch_minor_version = int(torch.__version__.split('.')[1])
-    # If torch version = 1.12
-    if torch_minor_version < 12:
-        return torch.nn.functional.gelu(activations)
-    else:
+    # gelu behavior changes (correctly) in torch 1.12
+    if required_torch_version(min_version=1.12):
         return torch.nn.functional.gelu(activations, approximate='tanh')
+    else:
+        return torch.nn.functional.gelu(activations)
 
 
 def run_gelu_reference(activations):
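The commit message does not spell out the motivation, but the deleted gate in test_gelu.py hints at it: int(torch.__version__.split('.')[1]) compares only the minor version number, so any torch 2.x release would be misclassified as older than 1.12. A minimal standalone demonstration (the version strings below are hypothetical examples, not taken from the diff):

for fake_version in ["1.11.0", "1.13.1", "2.0.1+cu118"]:
    minor = int(fake_version.split('.')[1])
    # The retired check routed to the non-approximated gelu when minor < 12.
    print(fake_version, "-> pre-1.12 branch:", minor < 12)
# Prints True, False, True; the last result is wrong, since 2.0.1 is newer than 1.12.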
1 change: 0 additions & 1 deletion tests/unit/ops/transformer/inference/test_matmul.py
@@ -12,7 +12,6 @@
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
 inference_module = None
-torch_minor_version = None
 
 
 def allclose(x, y):
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_softmax.py
@@ -11,8 +11,6 @@
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("Inference ops are not available on this system", allow_module_level=True)
 
-torch_minor_version = None
-
 
 def allclose(x, y):
     assert x.dtype == y.dtype
