[BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: #105436
Approved by: https://github.com/malfet, https://github.com/albanD
justinchuby authored and pytorchmergebot committed Jul 21, 2023
1 parent 322dff4 commit 79c5e33
Showing 89 changed files with 296 additions and 320 deletions.
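Note: ruff's "UP" selector enables its pyupgrade-derived rules, which mechanically rewrite Python 2-era idioms into their modern Python 3 equivalents (in ruff's configuration this typically means adding "UP" to the select list). A self-contained sketch of the main rewrites this commit applies; all names below are illustrative, not from the PR:

import functools

# str.format -> f-string (behavior-preserving)
assert 'conv{}d'.format(2) == f'conv{2}d'

# super(Class, self) -> super(); class Foo(object) -> class Foo
class Old(object):
    def __init__(self):
        super(Old, self).__init__()

class New:
    def __init__(self):
        super().__init__()

# open(path, "r") -> open(path): "r" is the default mode
# @functools.lru_cache() -> @functools.lru_cache (Python 3.8+)
@functools.lru_cache
def cached(x):
    return x * 2

assert cached(3) == 6 and isinstance(New(), New)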
8 changes: 4 additions & 4 deletions test/nn/test_convolution.py
@@ -1539,13 +1539,13 @@ def test_conv3d_valid_padding_backward(self, device, dtype):
gradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_forward_ad=check_forward_ad)
gradgradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_fwd_over_rev=check_forward_ad)

-@parametrize_test("N", range(2, 4), name_fn=lambda N: 'ConvTranspose{}d'.format(N))
+@parametrize_test("N", range(2, 4), name_fn=lambda N: f'ConvTranspose{N}d')
def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
# For inputs with no batch dim, verify output is the correct shape when output_size is set.
# See https://github.com/pytorch/pytorch/issues/75889
inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
-ConvTransposeNd = getattr(nn, 'ConvTranspose{}d'.format(N))
+ConvTransposeNd = getattr(nn, f'ConvTranspose{N}d')
m = ConvTransposeNd(1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device)
output = m(inp, output_size=output_size)
self.assertEqual(output.shape, output_size)
@@ -1892,9 +1892,9 @@ def test_conv_noncontig_weights(self, device):
w = w.expand([nc, int(nc / groups)] + list(w.shape))
w = w.detach().requires_grad_()
x = torch.randn([1, nc] + ([5] * dim), device=device, requires_grad=True)
-y = getattr(F, 'conv{}d'.format(dim))(x, w, groups=groups)
+y = getattr(F, f'conv{dim}d')(x, w, groups=groups)
y.sum().backward()
-y = getattr(F, 'conv_transpose{}d'.format(dim))(x, w, groups=groups)
+y = getattr(F, f'conv_transpose{dim}d')(x, w, groups=groups)
y.sum().backward()

def test_conv_noncontig_weights_and_bias(self, device):
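The name_fn rewrite above is behavior-preserving: str.format and the f-string build identical strings. A quick illustrative check:

for N in range(2, 4):
    assert 'ConvTranspose{}d'.format(N) == f'ConvTranspose{N}d'
assert f'ConvTranspose{2}d' == 'ConvTranspose2d'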
2 changes: 1 addition & 1 deletion test/nn/test_module_hooks.py
@@ -165,7 +165,7 @@ def kwarg_forward_hook(
return out


-class DummyContextManager():
+class DummyContextManager:
def __init__(self, inp):
self.input = inp

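The empty parentheses removed above are redundant in Python 3; both spellings define the same kind of class. Illustrative:

class WithParens(): pass
class WithoutParens: pass

assert WithParens.__bases__ == WithoutParens.__bases__ == (object,)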
2 changes: 1 addition & 1 deletion test/nn/test_multihead_attention.py
@@ -118,7 +118,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
saved_kv=False, same_embed_dim=False,
average_attn_weights=average_attn_weights):
for _ in range(100):
-batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
+batch_sz, seq_len = (random.randint(2, 10) for r in range(2))
d_head = random.randint(3, 10)
nheads = random.randint(2, 5) * 2
d_model = d_head * nheads
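Tuple unpacking consumes a generator expression just as it does a list comprehension, so the rewrite above drops a throwaway intermediate list. Illustrative:

import random

batch_sz, seq_len = [random.randint(2, 10) for _ in range(2)]  # old: builds a list first
batch_sz, seq_len = (random.randint(2, 10) for _ in range(2))  # new: no intermediate list
assert 2 <= batch_sz <= 10 and 2 <= seq_len <= 10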
16 changes: 8 additions & 8 deletions test/nn/test_pooling.py
@@ -138,7 +138,7 @@ class TestPoolingNN(NNTestCase):
def test_adaptive_pooling_size_none(self):
for numel in (2, 3):
for pool_type in ('Max', 'Avg'):
-cls_name = 'Adaptive{}Pool{}d'.format(pool_type, numel)
+cls_name = f'Adaptive{pool_type}Pool{numel}d'
module_cls = getattr(nn, cls_name)
output_size = (2,) * (numel - 1) + (None,)
module = module_cls(output_size)
@@ -757,7 +757,7 @@ def test_adaptive_pooling_no_suppot_input(self, device, dtype):
# adapative_avg_pool2d for int is implemented
if numel == 2 and pool_type == 'Avg':
continue
-cls_name = 'Adaptive{}Pool{}d'.format(pool_type, numel)
+cls_name = f'Adaptive{pool_type}Pool{numel}d'
module_cls = getattr(nn, cls_name)
output_size = (2,) * numel
module = module_cls(output_size)
@@ -1323,7 +1323,7 @@ def test_maxpool_indices_no_batch_dim(self, device, dtype):
def test_max_pool_nan_inf(self, device, dtype):
for adaptive in ['', 'adaptive_']:
for num_dim in [1, 2, 3]:
-fn_name = '{}max_pool{}d'.format(adaptive, num_dim)
+fn_name = f'{adaptive}max_pool{num_dim}d'
fn = getattr(F, fn_name)

x = torch.full([1, 1] + num_dim * [3], nan, device=device, dtype=dtype, requires_grad=True)
@@ -1423,7 +1423,7 @@ def func(x):
@onlyNativeDeviceTypes # TODO: Fails on XLA
def test_fractional_max_pool_nan_inf(self, device, dtype):
for num_dim in [2, 3]:
-fn_name = 'FractionalMaxPool{}d'.format(num_dim)
+fn_name = f'FractionalMaxPool{num_dim}d'
fn = getattr(nn, fn_name)(kernel_size=2, output_size=1)
x = torch.full([1, 1] + num_dim * [3], nan, device=device, dtype=dtype, requires_grad=True)
res = fn(x)
@@ -1439,13 +1439,13 @@ def test_fractional_max_pool_nan_inf(self, device, dtype):
def test_pooling_zero_stride(self, device):
for op in ('max', 'avg'):
for num_dim in [1, 2, 3]:
-fn_name = '{}_pool{}d'.format(op, num_dim)
+fn_name = f'{op}_pool{num_dim}d'
fn = getattr(F, fn_name)
x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
lambda: fn(x, kernel_size=2, stride=0))

-fn_module_name = '{}Pool{}d'.format(op.title(), num_dim)
+fn_module_name = f'{op.title()}Pool{num_dim}d'
fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
lambda: fn_module(x))
@@ -1456,7 +1456,7 @@ def test_pooling_zero_stride(self, device):
def test_pool_large_size(self, device, dtype):
for op in ('max', 'avg'):
for num_dim in [1, 2, 3]:
-fn_name = '{}_pool{}d'.format(op, num_dim)
+fn_name = f'{op}_pool{num_dim}d'
fn = getattr(F, fn_name)
# 16777217 is the smallest integer not expressible in float32
x = torch.ones([1, 1, 16777217] + (num_dim - 1) * [1],
@@ -1486,7 +1486,7 @@ def helper(pool):
def test_pool_invalid_size(self, device, dtype):
for op in ('max', 'avg'):
for num_dim in [1, 2, 3]:
-fn_name = '{}_pool{}d'.format(op, num_dim)
+fn_name = f'{op}_pool{num_dim}d'
if op == 'max':
# New implementation without indices supports empty tensors
# TODO(Heitor) change once with_indices code is updated
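The pooling tests build attribute names such as max_pool2d dynamically; f-strings keep that pattern readable even when the interpolated piece is an expression like op.title(). A short sketch (assumes torch is installed; the shapes are illustrative):

import torch
import torch.nn.functional as F

op, num_dim = 'max', 2
fn = getattr(F, f'{op}_pool{num_dim}d')  # resolves to F.max_pool2d
out = fn(torch.ones(1, 1, 4, 4), kernel_size=2)
assert out.shape == (1, 1, 2, 2)
assert f'{op.title()}Pool{num_dim}d' == 'MaxPool2d'  # expressions are allowed in f-strings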
18 changes: 9 additions & 9 deletions test/onnx/dynamo/test_exporter_api.py
@@ -29,11 +29,11 @@ def forward(self, x):
class TestExportOptionsAPI(common_utils.TestCase):
def test_opset_version_default(self):
options = ResolvedExportOptions(None)
-self.assertEquals(options.opset_version, _DEFAULT_OPSET_VERSION)
+self.assertEqual(options.opset_version, _DEFAULT_OPSET_VERSION)

def test_opset_version_explicit(self):
options = ResolvedExportOptions(ExportOptions(opset_version=3000))
-self.assertEquals(options.opset_version, 3000)
+self.assertEqual(options.opset_version, 3000)

def test_raise_on_invalid_argument_type(self):
expected_exception_type = roar.BeartypeException
@@ -60,12 +60,12 @@ def test_dynamic_shapes_explicit(self):

def test_logger_default(self):
options = ResolvedExportOptions(None)
-self.assertEquals(options.logger, logging.getLogger().getChild("torch.onnx"))
+self.assertEqual(options.logger, logging.getLogger().getChild("torch.onnx"))

def test_logger_explicit(self):
options = ResolvedExportOptions(ExportOptions(logger=logging.getLogger()))
-self.assertEquals(options.logger, logging.getLogger())
-self.assertNotEquals(options.logger, logging.getLogger().getChild("torch.onnx"))
+self.assertEqual(options.logger, logging.getLogger())
+self.assertNotEqual(options.logger, logging.getLogger().getChild("torch.onnx"))


class TestDynamoExportAPI(common_utils.TestCase):
@@ -111,8 +111,8 @@ def serialize(
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
with open(path, "r") as fp:
self.assertEquals(fp.read(), expected_buffer)
with open(path) as fp:
self.assertEqual(fp.read(), expected_buffer)

def test_save_to_file_using_specified_serializer_without_inheritance(self):
expected_buffer = "I am not actually ONNX"
@@ -130,8 +130,8 @@ def serialize(
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
with open(path, "r") as fp:
self.assertEquals(fp.read(), expected_buffer)
with open(path) as fp:
self.assertEqual(fp.read(), expected_buffer)

def test_save_sarif_log_to_file_with_successful_export(self):
with common_utils.TemporaryFileName() as path:
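Two independent cleanups land in this file: assertEquals/assertNotEquals are deprecated unittest aliases of assertEqual/assertNotEqual, and "r" is already open()'s default mode. A minimal hypothetical test, not taken from the PR:

import unittest

class ExampleTest(unittest.TestCase):
    def test_default_open_mode(self):
        with open(__file__) as fp:  # identical to open(__file__, "r")
            self.assertEqual(fp.read(0), "")  # the non-deprecated spelling

if __name__ == "__main__":
    unittest.main()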
6 changes: 3 additions & 3 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -749,7 +749,7 @@ def forward(
def test_logit(self):
class Logit(torch.nn.Module):
def __init__(self, eps):
-super(Logit, self).__init__()
+super().__init__()
self.eps = eps

def forward(self, x):
@@ -4126,7 +4126,7 @@ def forward(self, input, indices, src):
def test_scatter_reduce(self):
class Model(torch.nn.Module):
def __init__(self):
-super(Model, self).__init__()
+super().__init__()

def forward(self, x, index, input):
y_max = input.scatter_reduce(0, index, x, reduce="amax")
@@ -4148,7 +4148,7 @@ def forward(self, x, index, input):
def test_scatter_reduce_self_rank_zero(self):
class Model(torch.nn.Module):
def __init__(self):
-super(Model, self).__init__()
+super().__init__()

def forward(self, x, index, input):
y_max = input.scatter_reduce(0, index, x, reduce="amax")
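Zero-argument super() is equivalent to the explicit super(Class, self) form inside a method body; the compiler fills in the class cell. A minimal sketch (assumes torch is installed; this Logit is a hypothetical stand-in, not the test's full module):

import torch

class Logit(torch.nn.Module):
    def __init__(self, eps):
        super().__init__()  # same effect as super(Logit, self).__init__()
        self.eps = eps

    def forward(self, x):
        return torch.logit(x, self.eps)

assert isinstance(Logit(1e-6), torch.nn.Module)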
4 changes: 1 addition & 3 deletions torch/__init__.py
@@ -16,8 +16,6 @@
import textwrap
import ctypes
import inspect
-if sys.version_info < (3,):
-    raise Exception("Python 2 has reached end-of-life and is no longer supported by PyTorch.")

# multipy/deploy is setting this import before importing torch, this is the most
# reliable way we have to detect if we're running within deploy.
@@ -1739,7 +1737,7 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
if not _running_with_deploy():
from torch import compiler as compiler

-class _TritonLibrary(object):
+class _TritonLibrary:
lib = torch.library.Library("triton", "DEF")
ops_table: Dict[Tuple[str, str], Callable] = {}

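Both deletions here are pyupgrade staples: a Python 2 guard that can never fire on a supported interpreter, and an explicit object base that Python 3 supplies implicitly. Illustrative:

import sys

assert sys.version_info >= (3,)  # the removed guard was dead code on Python 3

class WithObjectBase(object): pass
class Bare: pass

assert WithObjectBase.__mro__[1:] == Bare.__mro__[1:] == (object,)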
2 changes: 1 addition & 1 deletion torch/_custom_op/impl.py
@@ -183,7 +183,7 @@ class CustomOp:
"""

def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
-super(CustomOp, self).__init__()
+super().__init__()
if not _private_access:
raise RuntimeError(
"The CustomOp constructor is private and we do not guarantee "
3 changes: 1 addition & 2 deletions torch/_inductor/codegen/wrapper.py
@@ -891,8 +891,7 @@ def write_constant(self, name, hashed):
def write_header(self):
if V.graph.aot_mode:
with open(
-os.path.join(os.path.dirname(__file__), "aot_inductor_interface.cpp"),
-"r",
+os.path.join(os.path.dirname(__file__), "aot_inductor_interface.cpp")
) as f:
self.header.splice(f.read())
else:
2 changes: 1 addition & 1 deletion torch/_linalg_utils.py
@@ -15,7 +15,7 @@ def is_sparse(A):

error_str = "expected Tensor"
if not torch.jit.is_scripting():
error_str += " but got {}".format(type(A))
error_str += f" but got {type(A)}"
raise TypeError(error_str)


28 changes: 14 additions & 14 deletions torch/_lobpcg.py
@@ -728,18 +728,18 @@ def __init__(

def __str__(self):
lines = ["LOPBCG:"]
lines += [" iparams={}".format(self.iparams)]
lines += [" fparams={}".format(self.fparams)]
lines += [" bparams={}".format(self.bparams)]
lines += [" ivars={}".format(self.ivars)]
lines += [" fvars={}".format(self.fvars)]
lines += [" bvars={}".format(self.bvars)]
lines += [" tvars={}".format(self.tvars)]
lines += [" A={}".format(self.A)]
lines += [" B={}".format(self.B)]
lines += [" iK={}".format(self.iK)]
lines += [" X={}".format(self.X)]
lines += [" E={}".format(self.E)]
lines += [f" iparams={self.iparams}"]
lines += [f" fparams={self.fparams}"]
lines += [f" bparams={self.bparams}"]
lines += [f" ivars={self.ivars}"]
lines += [f" fvars={self.fvars}"]
lines += [f" bvars={self.bvars}"]
lines += [f" tvars={self.tvars}"]
lines += [f" A={self.A}"]
lines += [f" B={self.B}"]
lines += [f" iK={self.iK}"]
lines += [f" X={self.X}"]
lines += [f" E={self.E}"]
r = ""
for line in lines:
r += line + "\n"
@@ -1133,15 +1133,15 @@ def _get_ortho(self, U, V):
R_norm = torch.norm(R)
# https://github.com/pytorch/pytorch/issues/33810 workaround:
rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
vkey = "ortho_UBUmI_rerr[{}, {}]".format(i, j)
vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
self.fvars[vkey] = rerr
if rerr < tau_ortho:
break
VBU = mm(_utils.transpose(V), BU)
VBU_norm = torch.norm(VBU)
U_norm = torch.norm(U)
rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
vkey = "ortho_VBU_rerr[{}]".format(i)
vkey = f"ortho_VBU_rerr[{i}]"
self.fvars[vkey] = rerr
if rerr < tau_ortho:
break
2 changes: 1 addition & 1 deletion torch/_logging/_internal.py
@@ -448,7 +448,7 @@ def _invalid_settings_err_msg(settings):
return msg


-@functools.lru_cache()
+@functools.lru_cache
def _parse_log_settings(settings):
if settings == "":
return dict()
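Since Python 3.8, functools.lru_cache may be applied as a bare decorator; it behaves exactly like @functools.lru_cache() with default arguments. A small sketch (the parse function is illustrative):

import functools

@functools.lru_cache  # equivalent to @functools.lru_cache() on Python 3.8+
def parse(settings: str) -> int:
    return len(settings)

parse("a,b")
parse("a,b")
assert parse.cache_info().hits == 1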
2 changes: 1 addition & 1 deletion torch/_lowrank.py
@@ -267,7 +267,7 @@ def pca_lowrank(
" and not greater than min(m, n)={}".format(q, min(m, n))
)
if not (niter >= 0):
raise ValueError("niter(={}) must be non-negative integer".format(niter))
raise ValueError(f"niter(={niter}) must be non-negative integer")

dtype = _utils.get_floating_dtype(A)

2 changes: 1 addition & 1 deletion torch/_ops.py
@@ -745,7 +745,7 @@ def __getattr__(self, op_name):
# Get the op `my_namespace::my_op` if available. This will also check
# for overloads and raise an exception if there are more than one.
namespace_name = self.name
qualified_op_name = "{}::{}".format(namespace_name, op_name)
qualified_op_name = f"{namespace_name}::{op_name}"
try:
op, overload_names = torch._C._jit_get_operation(qualified_op_name)
except RuntimeError as e:
6 changes: 3 additions & 3 deletions torch/_python_dispatcher.py
@@ -116,7 +116,7 @@ def register(self, dispatchKeys):
"""

def _format_line(self, key, kernel):
return "{:<15} {}\n".format(key, kernel)
return f"{key:<15} {kernel}\n"

"""
Helper function to print a table header.
@@ -136,15 +136,15 @@ def _format_header(self, header):
"""

def rawRegistrations(self):
return C._dispatch_dump("{}::{}".format(self.namespace, self.name)) # type: ignore[attr-defined]
return C._dispatch_dump(f"{self.namespace}::{self.name}") # type: ignore[attr-defined]

"""
Returns raw output of computed dispatch table for debugging only.
Use dispatchTable() for a simplified version.
"""

def rawDispatchTable(self):
-return C._dispatch_dump_table("{}::{}".format(self.namespace, self.name))  # type: ignore[attr-defined]
+return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

"""
Returns a table(str) including all the registrations from users.
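f-strings accept the same format-spec mini-language as str.format, so alignment such as {key:<15} transfers unchanged. Illustrative:

key, kernel = "CPU", "fn_cpu [kernel]"
assert f"{key:<15} {kernel}\n" == "{:<15} {}\n".format(key, kernel)
assert len(f"{key:<15}") == 15 and f"{key:<15}".rstrip() == "CPU"  # left-aligned, padded to width 15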
4 changes: 1 addition & 3 deletions torch/_tensor.py
@@ -1380,9 +1380,7 @@ def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
elif self.device.type == "xpu":
device_type = DLDeviceType.kDLOneAPI
else:
-raise ValueError(
-    "Unknown device type {} for Dlpack".format(torch_device_type)
-)
+raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
return (device_type, idx)

__module__ = "torch"
(Diff truncated; the remaining changed files are not shown.)
