Commit 6b7d89c
Showing 16 changed files with 197 additions and 188 deletions.
3 changes: 1 addition & 2 deletions aten/src/ATen/autocast_mode.cpp
@@ -541,8 +541,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
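These registrations pin the merged aten::stft schema to fp32 under the AutocastCPU key, as the old stft and stft.center entries were before. A minimal sketch of the intended effect, assuming reduced-precision inputs under torch.cpu.amp.autocast (shapes and dtypes illustrative, not from this commit):

import torch

x = torch.randn(1000).to(torch.bfloat16)
window = torch.hann_window(400).to(torch.bfloat16)
with torch.cpu.amp.autocast():
    # stft is on the fp32 list, so autocast upcasts its tensor arguments
    # before the kernel runs
    spec = torch.stft(x, n_fft=400, window=window, return_complex=True)
print(spec.dtype)  # expected: torch.complex64, not a bfloat16-derived type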
20 changes: 0 additions & 20 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -907,17 +907,6 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
}
}

Tensor stft(
    const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
    const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
    const bool normalized,
    const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
  return at::stft(
      self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
      /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
      return_complexOpt);
}

// Create complex tensor from the old style of real tensor with size=(..., 2)
// This is to support istft in the transition to requiring complex input.
// NOTE: This may return a view of the input tensor, or might clone if necessary
@@ -1111,15 +1100,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
#undef REPR
}

Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
             const optional<int64_t> win_lengthOpt, const Tensor& window,
             const bool center, const bool normalized, const optional<bool> onesidedOpt,
             const optional<int64_t> lengthOpt) {
  return at::native::istft(
      self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
      onesidedOpt, lengthOpt, /*return_complex=*/false);
}

void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
const auto input_sizes = input.sizes();
const auto input_strides = input.strides();
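Both deleted functions above were forward-compatibility shims that forwarded to the full signature with centering disabled. In Python terms, the removed stft overload behaved roughly as follows (a sketch; stft_legacy is a hypothetical name, and since center=False the "constant" pad mode it passed was never actually applied):

import torch
from typing import Optional

def stft_legacy(x, n_fft: int, hop_length: Optional[int] = None,
                win_length: Optional[int] = None, window=None,
                normalized: bool = False, onesided: Optional[bool] = None,
                return_complex: Optional[bool] = None):
    # mirrors the removed overload: no padding, then the usual stft
    return torch.stft(x, n_fft, hop_length, win_length, window,
                      center=False, pad_mode="constant",
                      normalized=normalized, onesided=onesided,
                      return_complex=return_complex)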
7 changes: 1 addition & 6 deletions aten/src/ATen/native/native_functions.yaml
@@ -4297,12 +4297,7 @@

- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)

# Overload without center & pad mode, needed for forward-compatibility
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
  variants: function, method
  cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']

- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
  variants: function, method

- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
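With the two schemas merged, the single aten::stft entry carries center and pad_mode, so the raw op can be called with them directly. A quick sketch (shapes and parameters illustrative):

import torch

x = torch.randn(1000)
window = torch.hann_window(400)
spec = torch.ops.aten.stft(x, 400, hop_length=100, win_length=400,
                           window=window, center=True, pad_mode="reflect",
                           normalized=False, onesided=True, return_complex=True)
print(spec.shape)  # (n_fft // 2 + 1, n_frames) for a 1-D onesided input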
6 changes: 4 additions & 2 deletions caffe2/serialize/versions.h
@@ -12,7 +12,7 @@ namespace serialize {
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;

#if ENABLE_UPGRADERS
constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
constexpr uint64_t kMaxSupportedFileFormatVersion = 11;
#else
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
#endif
@@ -83,7 +83,9 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
// Bump the version number to 10 to update aten::gelu
// and aten::gelu.out to support the new approximate kwarg.
// (see: https://github.com/pytorch/pytorch/pull/61439)
constexpr uint64_t kProducedFileFormatVersion = 0xAL;
// 4) [02/25/2022]
// Bump version number to 11 to update aten::stft to do padding in ATen
constexpr uint64_t kProducedFileFormatVersion = 11L;
#else
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
#endif
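The produced version is recorded inside the serialized archive itself, so the bump is directly observable. A hedged sketch of reading it back (assuming the usual TorchScript zip layout with a version record; the file path is hypothetical):

import zipfile

with zipfile.ZipFile("model.pt") as zf:
    version_entry = next(n for n in zf.namelist() if n.endswith("version"))
    # expected to print "11" for archives produced after this change
    print(zf.read(version_entry).decode().strip())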
@@ -110,6 +110,7 @@
("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
("aten::scatter_reduce.two", datetime.date(2022, 4, 15)),
("aten::stft", datetime.date(2022, 5, 1)),
("aten::_s_where", datetime.date(2022, 9, 30)),
("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),
Binary file added test/jit/fixtures/test_versioned_stft_v10.ptl
Binary file not shown.
8 changes: 8 additions & 0 deletions test/jit/fixtures_srcs/fixtures_src.py
@@ -57,3 +57,11 @@ def __init__(self):
    def forward(self, x):
        out = torch.zeros_like(x)
        return torch._C._nn.gelu(x, out=out)

class TestVersionedStftV10(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, n_fft: int, window):
        # call aten::stft directly instead of torch.functional.stft
        return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True)
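For context, fixtures like this are scripted and saved for the lite interpreter so that an old-bytecode model exists to exercise the upgrader. A hedged sketch of how such a .ptl is produced; note the real fixture had to be generated against the pre-change build, since regenerating it after this commit would emit the new operator version:

import torch

m = torch.jit.script(TestVersionedStftV10())
m._save_for_lite_interpreter("test/jit/fixtures/test_versioned_stft_v10.ptl")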
1 change: 1 addition & 0 deletions test/jit/fixtures_srcs/generate_models.py
@@ -96,6 +96,7 @@ def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
    TestVersionedLogspaceOutV8(): "aten::logspace.out",
    TestVersionedGeluV9(): "aten::gelu",
    TestVersionedGeluOutV9(): "aten::gelu.out",
    TestVersionedStftV10(): "aten::stft",
}

"""
17 changes: 17 additions & 0 deletions test/jit/test_save_load_for_op_version.py
@@ -540,3 +540,20 @@ def forward(self, a: Union[int, float, complex], b: Union[int, float, complex],
        self.assertTrue(output.size(dim=0) == 100)
        # "Upgraded" model should match the new version output
        self.assertEqual(output, output_current)

    def test_versioned_stft_v10(self):
        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
        loaded_model = torch.jit.load(model_path)
        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        v10_mobile_module = _load_for_lite_interpreter(buffer)

        for in_dtype, window_dtype in product(
                [torch.float32, torch.complex64], repeat=2):
            input = torch.rand((100,), dtype=in_dtype)
            window = torch.rand((10,), dtype=window_dtype)
            n_fft = 10
            output = v10_mobile_module(input, n_fft, window)
            output_expected = torch.stft(input, n_fft=n_fft, window=window,
                                         center=False, return_complex=True)
            self.assertEqual(output, output_expected)
1 change: 0 additions & 1 deletion tools/pyi/gen_pyi.py
@@ -109,7 +109,6 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
"block_diag",
"norm",
"chain_matmul",
"stft",
"tensordot",
"split",
"unique_consecutive",
36 changes: 1 addition & 35 deletions torch/_tensor.py
@@ -2,7 +2,7 @@
import enum
import functools
from numbers import Number
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
import warnings
import copyreg
from copy import deepcopy
@@ -542,40 +542,6 @@ def lu(self, pivot=True, get_infos=False):
        else:
            return LU, pivots

    def stft(self, n_fft: int, hop_length: Optional[int] = None,
             win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
             center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
             onesided: Optional[bool] = None, return_complex: Optional[bool] = None):
        r"""See :func:`torch.stft`

        .. warning::
          This function changed signature at version 0.4.1. Calling with
          the previous signature may cause error or return incorrect result.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
                win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
                onesided=onesided, return_complex=return_complex
            )
        return torch.stft(self, n_fft, hop_length, win_length, window, center,
                          pad_mode, normalized, onesided, return_complex=return_complex)

    def istft(self, n_fft: int, hop_length: Optional[int] = None,
              win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
              center: bool = True, normalized: bool = False,
              onesided: Optional[bool] = None, length: Optional[int] = None,
              return_complex: bool = False):
        r"""See :func:`torch.istft`"""
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
                window=window, center=center, normalized=normalized, onesided=onesided, length=length,
                return_complex=return_complex
            )
        return torch.istft(self, n_fft, hop_length, win_length, window, center,
                           normalized, onesided, length, return_complex=return_complex)

    def resize(self, *sizes):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.resize, (self,), self, *sizes)
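The hand-written Tensor.stft and Tensor.istft wrappers can go because `variants: function, method` in native_functions.yaml now generates the methods with the same defaults. Call sites are unchanged; a round-trip sketch:

import torch

x = torch.randn(1600)
w = torch.hann_window(400)
spec = x.stft(400, hop_length=100, window=w, return_complex=True)  # method still works
recon = spec.istft(400, hop_length=100, window=w, length=x.numel())
# recon approximates x up to floating-point error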
11 changes: 8 additions & 3 deletions torch/_tensor_docs.py
@@ -4752,16 +4752,21 @@ def callable(a, b) -> number
""")

add_docstr_all('stft',
"stft(n_fft, hop_length=None, win_length=None, window=None, center=True, "
"pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
r"""
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
See :func:`torch.stft`
.. warning::
This function changed signature at version 0.4.1. Calling with
the previous signature may cause error or return incorrect result.
""")

add_docstr_all('istft',
"istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
"normalized=False, onesided=None, length=None, return_complex=False) -> Tensor"
r"""
istft(n_fft, hop_length=None, win_length=None, window=None,
center=True, normalized=False, onesided=True, length=None) -> Tensor
See :func:`torch.istft`
""")
33 changes: 33 additions & 0 deletions torch/csrc/jit/mobile/upgrader_mobile.cpp
@@ -67,6 +67,10 @@ getOperatorVersionMapForMobile() {
           std::vector<Upgrader>({
               Upgrader({0, 8, "logspace_out_0_8", 10})
           })},
          {std::string("aten::stft"),
           std::vector<Upgrader>({
               Upgrader({0, 10, "stft_0_10", 11})
           })},
      });
  return operatorVersionMapForMobile;
}
@@ -527,6 +531,35 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
OperatorString({"prim::unchecked_cast", "", 1}),
}), // operators list
}),
ByteCodeFunctionWithOperator({
mobile::Function::registerFunc(
"stft_0_10",
std::vector<Instruction>({
Instruction{OpCode::STOREN, 1, 8},
Instruction{OpCode::MOVE, 1, 0},
Instruction{OpCode::MOVE, 2, 0},
Instruction{OpCode::MOVE, 3, 0},
Instruction{OpCode::MOVE, 4, 0},
Instruction{OpCode::MOVE, 5, 0},
Instruction{OpCode::LOADC, 1, 0},
Instruction{OpCode::LOADC, 0, 0},
Instruction{OpCode::MOVE, 6, 0},
Instruction{OpCode::MOVE, 7, 0},
Instruction{OpCode::MOVE, 8, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::RET, 0, 0},
}), // instructions list,
std::vector<c10::IValue>({
c10::IValue("reflect"),
c10::IValue(false),
}), // constants list,
std::vector<c10::TypePtr>(), // types list,
8
),
std::vector<OperatorString>({
OperatorString({"aten::stft", "", 10}),
}), // operators list
}),
});
for (const auto& upgrader_function : upgrader_function_list) {
for (const auto& op : upgrader_function.operators) {
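Reading the stft_0_10 bytecode: STOREN captures the eight old-schema arguments; the first five MOVEs replay self, n_fft, hop_length, win_length and window; LOADC 1 and LOADC 0 splice in the constants false (center) and "reflect" (pad_mode); the remaining MOVEs replay normalized, onesided and return_complex; OP 0 then calls the ten-argument aten::stft. A Python rendering of the same transform (a sketch with a hypothetical name, not generated code):

import torch

def stft_0_10_equivalent(self, n_fft, hop_length, win_length, window,
                         normalized, onesided, return_complex):
    # old 8-argument call -> new 10-argument call, old behavior preserved
    return torch.ops.aten.stft(self, n_fft, hop_length, win_length, window,
                               False, "reflect", normalized, onesided,
                               return_complex)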
11 changes: 11 additions & 0 deletions torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
@@ -15,6 +15,17 @@ namespace torch {
namespace jit {

static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
{"stft_0_10", R"SCRIPT(
def stft_0_10(
self: Tensor, n_fft: int, hop_length: Optional[int] = None,
win_length: Optional[int] = None, window: Optional[Tensor] = None,
normalized: bool = False, onesided: Optional[bool] = None,
return_complex: Optional[bool] = None) -> Tensor:
return torch.stft(
self, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=False, normalized=normalized, onesided=onesided,
return_complex=return_complex)
)SCRIPT"},
{"logspace_0_8", R"SCRIPT(
def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
device: Optional[Device], pin_memory: Optional[bool]):
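The upgrader pins center=False because the old aten::stft never padded, while the new default center=True pads the signal by n_fft // 2 on each side before framing, changing the frame count. A sketch of the behavioral difference:

import torch

x = torch.randn(1000)
a = torch.stft(x, 400, hop_length=100, center=False, return_complex=True)
b = torch.stft(x, 400, hop_length=100, center=True, return_complex=True)
print(a.shape, b.shape)  # e.g. torch.Size([201, 7]) vs torch.Size([201, 11])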
6 changes: 5 additions & 1 deletion torch/csrc/jit/operator_upgraders/version_map.cpp
@@ -16,7 +16,11 @@ static bool isVersionMapSorted = false;
// Note for developers: The list of upgraders should be SORTED
// by the version number where the upgrader is registered.
static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
{{"aten::logspace",
{{"aten::stft",
{{11,
"stft_0_10",
"aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
{"aten::logspace",
{{9,
"logspace_0_8",
"aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},