diff --git a/test/data/test_functional.py b/test/data/test_functional.py
index 66fda21154..509a0e9fc3 100644
--- a/test/data/test_functional.py
+++ b/test/data/test_functional.py
@@ -107,13 +107,24 @@ def test_BasicEnglishNormalize(self):
         self.assertEqual(eager_tokens, ref_results)
         self.assertEqual(experimental_eager_tokens, ref_results)

-        # test load and save
-        save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt')
-        torch.save(basic_eng_norm.to_ivalue(), save_path)
-        loaded_basic_eng_norm = torch.load(save_path)
+    def test_basicEnglishNormalize_load_and_save(self):
+        test_sample = '\'".<br />,()!?;: Basic English Normalization for a Line of Text \'".<br />,()!?;:'
+        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'basic', 'english', 'normalization',
+                       'for', 'a', 'line', 'of', 'text', "'", '.', ',', '(', ')', '!', '?']

-        loaded_eager_tokens = loaded_basic_eng_norm(test_sample)
-        self.assertEqual(loaded_eager_tokens, ref_results)
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'ben_pybind.pt')
+            ben = basic_english_normalize()
+            torch.save(ben, save_path)
+            loaded_ben = torch.load(save_path)
+            self.assertEqual(loaded_ben(test_sample), ref_results)
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'ben_torchscript.pt')
+            ben = basic_english_normalize().to_ivalue()
+            torch.save(ben, save_path)
+            loaded_ben = torch.load(save_path)
+            self.assertEqual(loaded_ben(test_sample), ref_results)

     # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
     @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
@@ -147,13 +158,39 @@ def test_RegexTokenizer(self):
         self.assertEqual(eager_tokens, ref_results)
         self.assertEqual(jit_tokens, ref_results)

-        # test load and save
-        save_path = os.path.join(self.test_dir, 'regex.pt')
-        torch.save(r_tokenizer.to_ivalue(), save_path)
-        loaded_r_tokenizer = torch.load(save_path)
+    def test_load_and_save(self):
+        test_sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
+        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
+                       'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
+        patterns_list = [
+            (r'\'', ' \' '),
+            (r'\"', ''),
+            (r'\.', ' . '),
+            (r'<br \/>', ' '),
+            (r',', ' , '),
+            (r'\(', ' ( '),
+            (r'\)', ' ) '),
+            (r'\!', ' ! '),
+            (r'\?', ' ? '),
+            (r'\;', ' '),
+            (r'\:', ' '),
+            (r'\s+', ' ')]

-        loaded_eager_tokens = loaded_r_tokenizer(test_sample)
-        self.assertEqual(loaded_eager_tokens, ref_results)
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
+            tokenizer = regex_tokenizer(patterns_list)
+            torch.save(tokenizer, save_path)
+            loaded_tokenizer = torch.load(save_path)
+            results = loaded_tokenizer(test_sample)
+            self.assertEqual(results, ref_results)
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
+            tokenizer = regex_tokenizer(patterns_list).to_ivalue()
+            torch.save(tokenizer, save_path)
+            loaded_tokenizer = torch.load(save_path)
+            results = loaded_tokenizer(test_sample)
+            self.assertEqual(results, ref_results)

     def test_custom_replace(self):
         custom_replace_transform = custom_replace([(r'S', 's'), (r'\s+', ' ')])
diff --git a/test/experimental/test_transforms.py b/test/experimental/test_transforms.py
index 69816dd6e7..1994c3c396 100644
--- a/test/experimental/test_transforms.py
+++ b/test/experimental/test_transforms.py
@@ -54,3 +54,27 @@ def test_vector_transform(self):
                                                     [-0.32423, -0.098845, -0.0073467]])
         self.assertEqual(vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
         self.assertEqual(jit_vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
+
+    def test_sentencepiece_load_and_save(self):
+        model_path = get_asset_path('spm_example.model')
+        input = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
+        expected = [
+            '▁Sent', 'ence', 'P', 'ie', 'ce', '▁is',
+            '▁an', '▁un', 'super', 'vis', 'ed', '▁text',
+            '▁to', 'ken', 'izer', '▁and',
+            '▁de', 'to', 'ken', 'izer',
+        ]
+
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
+            spm = sentencepiece_tokenizer(model_path)
+            torch.save(spm, save_path)
+            loaded_spm = torch.load(save_path)
+            self.assertEqual(expected, loaded_spm(input))
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
+            spm = sentencepiece_tokenizer(model_path).to_ivalue()
+            torch.save(spm, save_path)
+            loaded_spm = torch.load(save_path)
+            self.assertEqual(expected, loaded_spm(input))
diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py
index 946c436e95..ff5293b402 100644
--- a/test/experimental/test_vectors.py
+++ b/test/experimental/test_vectors.py
@@ -111,25 +111,49 @@ def test_vectors_add_item(self):
         self.assertEqual(vectors_obj['b'], tensorB)
         self.assertEqual(vectors_obj['not_in_it'], unk_tensor)

-    def test_vectors_load_and_save(self):
+    def test_vectors_update(self):
         tensorA = torch.tensor([1, 0], dtype=torch.float)
         tensorB = torch.tensor([0, 1], dtype=torch.float)
+        tensorC = torch.tensor([1, 1], dtype=torch.float)
+        expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)

         tokens = ['a', 'b']
         vecs = torch.stack((tensorA, tensorB), 0)
         vectors_obj = build_vectors(tokens, vecs)

-        tensorC = torch.tensor([1, 1], dtype=torch.float)
         vectors_obj['b'] = tensorC

-        vector_path = os.path.join(self.test_dir, 'vectors.pt')
-        torch.save(vectors_obj.to_ivalue(), vector_path)
-        loaded_vectors_obj = torch.load(vector_path)
+        self.assertEqual(vectors_obj['a'], tensorA)
+        self.assertEqual(vectors_obj['b'], tensorC)
+        self.assertEqual(vectors_obj['not_in_it'], expected_unk_tensor)
+
+    def test_vectors_load_and_save(self):
+        tensorA = torch.tensor([1, 0], dtype=torch.float)
+        tensorB = torch.tensor([0, 1], dtype=torch.float)
+        expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
+
+        tokens = ['a', 'b']
+        vecs = torch.stack((tensorA, tensorB), 0)
+        vectors_obj = build_vectors(tokens, vecs)
+
+        with self.subTest('pybind'):
+            vector_path = os.path.join(self.test_dir, 'vectors_pybind.pt')
+            torch.save(vectors_obj, vector_path)
+            loaded_vectors_obj = torch.load(vector_path)
+
+            self.assertEqual(loaded_vectors_obj['a'], tensorA)
+            self.assertEqual(loaded_vectors_obj['b'], tensorB)
+            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+
+        with self.subTest('torchscript'):
+            vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
+            torch.save(vectors_obj.to_ivalue(), vector_path)
+            loaded_vectors_obj = torch.load(vector_path)

-        self.assertEqual(loaded_vectors_obj['a'], tensorA)
-        self.assertEqual(loaded_vectors_obj['b'], tensorC)
-        self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+            self.assertEqual(loaded_vectors_obj['a'], tensorA)
+            self.assertEqual(loaded_vectors_obj['b'], tensorB)
+            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)

 # we separate out these errors because Windows runs into seg faults when propagating
 # exceptions from C++ using pybind11
diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index 626db2e726..662aa6667a 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -199,12 +199,19 @@ def test_vocab_load_and_save(self):
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)

-        vocab_path = os.path.join(self.test_dir, 'vocab.pt')
-        torch.save(v.to_ivalue(), vocab_path)
-        loaded_v = torch.load(vocab_path)
-
-        self.assertEqual(v.get_itos(), expected_itos)
-        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+        with self.subTest('pybind'):
+            vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
+            torch.save(v, vocab_path)
+            loaded_v = torch.load(vocab_path)
+            self.assertEqual(loaded_v.get_itos(), expected_itos)
+            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+
+        with self.subTest('torchscript'):
+            vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
+            torch.save(v.to_ivalue(), vocab_path)
+            loaded_v = torch.load(vocab_path)
+            self.assertEqual(loaded_v.get_itos(), expected_itos)
+            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)

     def test_build_vocab_iterator(self):
         iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
diff --git a/torchtext/csrc/regex.cpp b/torchtext/csrc/regex.cpp
index f2051e0de0..32da33d2cf 100644
--- a/torchtext/csrc/regex.cpp
+++ b/torchtext/csrc/regex.cpp
@@ -11,4 +11,12 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
   return str;
 }

+std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
+  return self->re_str_;
+}
+
+c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state) {
+  return c10::make_intrusive<Regex>(std::move(state));
+}
+
 } // namespace torchtext
diff --git a/torchtext/csrc/regex.h b/torchtext/csrc/regex.h
index c5d871479c..4e5dfbfee4 100644
--- a/torchtext/csrc/regex.h
+++ b/torchtext/csrc/regex.h
@@ -13,4 +13,8 @@ struct Regex : torch::CustomClassHolder {
   Regex(const std::string &re_str);
   std::string Sub(std::string str, const std::string &repl) const;
 };
+
+std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
+c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state);
+
 } // namespace torchtext
diff --git a/torchtext/csrc/regex_tokenizer.cpp b/torchtext/csrc/regex_tokenizer.cpp
index c31ca36226..9ad20df9a7 100644
--- a/torchtext/csrc/regex_tokenizer.cpp
+++ b/torchtext/csrc/regex_tokenizer.cpp
@@ -44,4 +44,15 @@ void RegexTokenizer::split_(std::string &str, std::vector<std::string> &tokens,
   }
 }

+RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self) {
+  return std::make_tuple(self->patterns_, self->replacements_, self->to_lower_);
+}
+
+c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states) {
+  return c10::make_intrusive<RegexTokenizer>(
+      std::move(std::get<0>(states)),
+      std::move(std::get<1>(states)),
+      std::get<2>(states));
+}
+
 } // namespace torchtext
diff --git a/torchtext/csrc/regex_tokenizer.h b/torchtext/csrc/regex_tokenizer.h
index d0d9cfbb62..02d898eb4a 100644
--- a/torchtext/csrc/regex_tokenizer.h
+++ b/torchtext/csrc/regex_tokenizer.h
@@ -3,6 +3,9 @@

 namespace torchtext {

+typedef std::tuple<std::vector<std::string>, std::vector<std::string>, bool>
+    RegexTokenizerStates;
+
 struct RegexTokenizer : torch::CustomClassHolder {
 private:
   std::vector<RE2 *> compiled_patterns_;
@@ -20,4 +23,7 @@ struct RegexTokenizer : torch::CustomClassHolder {
   std::vector<std::string> forward(std::string str) const;
 };

+RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self);
+c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states);
+
 } // namespace torchtext
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index de55d27e7d..4c3ef76399 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -27,18 +27,36 @@ Vocab build_vocab_from_text_file(const std::string &file_path,
 // Registers our custom classes with pybind11.
 PYBIND11_MODULE(_torchtext, m) {
   // Classes
-  py::class_<Regex>(m, "Regex")
+  py::class_<Regex, c10::intrusive_ptr<Regex>>(m, "Regex")
       .def(py::init<std::string>())
-      .def("Sub", &Regex::Sub);
+      .def("Sub", &Regex::Sub)
+      .def(py::pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Regex> &self) -> std::string {
+            return _serialize_regex(self);
+          },
+          // __setstate__
+          [](std::string state) -> c10::intrusive_ptr<Regex> {
+            return _deserialize_regex(std::move(state));
+          }));

-  py::class_<RegexTokenizer>(m, "RegexTokenizer")
+  py::class_<RegexTokenizer, c10::intrusive_ptr<RegexTokenizer>>(m, "RegexTokenizer")
       .def_readonly("patterns_", &RegexTokenizer::patterns_)
       .def_readonly("replacements_", &RegexTokenizer::replacements_)
       .def_readonly("to_lower_", &RegexTokenizer::to_lower_)
       .def(py::init<std::vector<std::string>, std::vector<std::string>, bool>())
-      .def("forward", &RegexTokenizer::forward);
+      .def("forward", &RegexTokenizer::forward)
+      .def(py::pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
+            return _serialize_regex_tokenizer(self);
+          },
+          // __setstate__
+          [](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
+            return _deserialize_regex_tokenizer(std::move(states));
+          }));

-  py::class_<SentencePiece>(m, "SentencePiece")
+  py::class_<SentencePiece, c10::intrusive_ptr<SentencePiece>>(m, "SentencePiece")
       .def(py::init<std::string>())
       .def("_return_content",
            [](const SentencePiece &self) { return py::bytes(self.content_); })
@@ -50,9 +68,18 @@ PYBIND11_MODULE(_torchtext, m) {
       .def("GetPieceSize", &SentencePiece::GetPieceSize)
       .def("unk_id", &SentencePiece::unk_id)
       .def("PieceToId", &SentencePiece::PieceToId)
-      .def("IdToPiece", &SentencePiece::IdToPiece);
+      .def("IdToPiece", &SentencePiece::IdToPiece)
+      .def(py::pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<SentencePiece> &self) -> py::bytes {
+            return py::bytes(self->content_);
+          },
+          // __setstate__
+          [](py::bytes state) -> c10::intrusive_ptr<SentencePiece> {
+            return c10::make_intrusive<SentencePiece>(std::string(state));
+          }));

-  py::class_<Vectors>(m, "Vectors")
+  py::class_<Vectors, c10::intrusive_ptr<Vectors>>(m, "Vectors")
       .def(py::init<std::vector<std::string>, std::vector<int64_t>,
                     torch::Tensor, torch::Tensor>())
       .def_readonly("vectors_", &Vectors::vectors_)
@@ -61,9 +88,18 @@ PYBIND11_MODULE(_torchtext, m) {
       .def("__getitem__", &Vectors::__getitem__)
       .def("lookup_vectors", &Vectors::lookup_vectors)
       .def("__setitem__", &Vectors::__setitem__)
-      .def("__len__", &Vectors::__len__);
+      .def("__len__", &Vectors::__len__)
+      .def(py::pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
+            return _serialize_vectors(self);
+          },
+          // __setstate__
+          [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
+            return _deserialize_vectors(states);
+          }));

-  py::class_<Vocab>(m, "Vocab")
+  py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
       .def(py::init<std::vector<std::string>, std::string>())
       .def_readonly("itos_", &Vocab::itos_)
       .def_readonly("unk_token_", &Vocab::unk_token_)
@@ -75,7 +111,16 @@ PYBIND11_MODULE(_torchtext, m) {
       .def("lookup_tokens", &Vocab::lookup_tokens)
       .def("lookup_indices", &Vocab::lookup_indices)
       .def("get_stoi", &Vocab::get_stoi)
-      .def("get_itos", &Vocab::get_itos);
+      .def("get_itos", &Vocab::get_itos)
+      .def(py::pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
+            return _serialize_vocab(self);
+          },
+          // __setstate__
+          [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
+            return _deserialize_vocab(states);
+          }));

   // Functions
   m.def("_load_token_and_vectors_from_file",
@@ -91,32 +136,26 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def_pickle(
           // __getstate__
           [](const c10::intrusive_ptr<Regex> &self) -> std::string {
-            return self->re_str_;
+            return _serialize_regex(self);
           },
           // __setstate__
           [](std::string state) -> c10::intrusive_ptr<Regex> {
-            return c10::make_intrusive<Regex>(std::move(state));
+            return _deserialize_regex(std::move(state));
           });

-  using RegexTokenizerState = std::tuple<std::vector<std::string>, std::vector<std::string>, bool>;
   m.class_<RegexTokenizer>("RegexTokenizer")
       .def(torch::init<std::vector<std::string>, std::vector<std::string>, bool>())
       .def("forward", &RegexTokenizer::forward)
       .def_pickle(
           // __getstate__
-          [](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerState {
-            return std::make_tuple(
-                self->patterns_,
-                self->replacements_,
-                self->to_lower_);
+          [](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
+            return _serialize_regex_tokenizer(self);
           },
           // __setstate__
-          [](RegexTokenizerState state) -> c10::intrusive_ptr<RegexTokenizer> {
-            return c10::make_intrusive<RegexTokenizer>(
-                std::move(std::get<0>(state)),
-                std::move(std::get<1>(state)),
-                std::get<2>(state));
+          [](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
+            return _deserialize_regex_tokenizer(std::move(states));
           });
+
   m.class_<SentencePiece>("SentencePiece")
       .def(torch::init<std::string>())
       .def("Encode", &SentencePiece::Encode)
@@ -129,13 +168,21 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def("PieceToId", &SentencePiece::PieceToId)
       .def("IdToPiece", &SentencePiece::IdToPiece)
       .def_pickle(
+          // The underlying content of SentencePiece is a byte string, and
+          // returning it as std::string causes a UTF-8 decoding error. Since
+          // TorchScript does not support byte strings, we use a byte Tensor
+          // to pass around the data.
          // __getstate__
-          [](const c10::intrusive_ptr<SentencePiece> &self) -> std::string {
-            return self->content_;
+          [](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
+            auto *data = static_cast<void *>(
+                const_cast<char *>(self->content_.data()));
+            auto numel = static_cast<int64_t>(self->content_.size());
+            return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
           },
           // __setstate__
-          [](std::string state) -> c10::intrusive_ptr<SentencePiece> {
-            return c10::make_intrusive<SentencePiece>(std::move(state));
+          [](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
+            auto *data = static_cast<char *>(state.data_ptr());
+            auto numel = state.size(0);
+            return c10::make_intrusive<SentencePiece>(std::string(data, numel));
           });

   m.class_<Vectors>("Vectors")
@@ -147,11 +194,11 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def_pickle(
           // __getstate__
           [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
-            return _set_vectors_states(self);
+            return _serialize_vectors(self);
           },
           // __setstate__
           [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
-            return _get_vectors_from_states(states);
+            return _deserialize_vectors(states);
           });

   m.class_<Vocab>("Vocab")
@@ -168,11 +215,11 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def_pickle(
           // __getstate__
           [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
-            return _set_vocab_states(self);
+            return _serialize_vocab(self);
           },
           // __setstate__
           [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
-            return _get_vocab_from_states(states);
+            return _deserialize_vocab(states);
           });

   m.def("torchtext::generate_sp_model", &generate_sp_model);
diff --git a/torchtext/csrc/vectors.cpp b/torchtext/csrc/vectors.cpp
index 029b3abb0d..c5410b1496 100644
--- a/torchtext/csrc/vectors.cpp
+++ b/torchtext/csrc/vectors.cpp
@@ -275,7 +275,7 @@ std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
   return result;
 }

-VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self) {
+VectorsStates _serialize_vectors(const c10::intrusive_ptr<Vectors> &self) {
   std::vector<std::string> tokens;
   std::vector<int64_t> indices;
   tokens.reserve(self->stoi_.size());
@@ -299,7 +299,7 @@ VectorsStates _serialize_vectors(const c10::intrusive_ptr<Vectors> &self) {
   return states;
 }

-c10::intrusive_ptr<Vectors> _get_vectors_from_states(VectorsStates states) {
+c10::intrusive_ptr<Vectors> _deserialize_vectors(VectorsStates states) {
   auto state_size = std::tuple_size<decltype(states)>::value;
   if (state_size != 4) {
     throw std::runtime_error(
diff --git a/torchtext/csrc/vectors.h b/torchtext/csrc/vectors.h
index fde1f257f0..09716258ca 100644
--- a/torchtext/csrc/vectors.h
+++ b/torchtext/csrc/vectors.h
@@ -32,8 +32,8 @@ struct Vectors : torch::CustomClassHolder {
   int64_t __len__();
 };

-c10::intrusive_ptr<Vectors> _get_vectors_from_states(VectorsStates states);
-VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self);
+VectorsStates _serialize_vectors(const c10::intrusive_ptr<Vectors> &self);
+c10::intrusive_ptr<Vectors> _deserialize_vectors(VectorsStates states);

 std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
     const std::string &file_path, const std::string delimiter_str,
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 458f8c3e47..0e324dbbf5 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -379,7 +379,7 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
 }

-VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self) {
+VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   std::vector<int64_t> integers;
   StringList strings = self->itos_;
   strings.push_back(self->unk_token_);
@@ -390,7 +390,7 @@ VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   return states;
 }

-c10::intrusive_ptr<Vocab> _get_vocab_from_states(VocabStates states) {
+c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size<decltype(states)>::value;
   if (state_size != 4) {
 #ifdef _MSC_VER
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 195d6de1a0..0da660a633 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -35,8 +35,9 @@ struct Vocab : torch::CustomClassHolder {
   std::vector<std::string> get_itos() const;
 };

-c10::intrusive_ptr<Vocab> _get_vocab_from_states(VocabStates states);
-VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self);
+VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
+c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
+
 Vocab _load_vocab_from_file(const std::string &file_path,
                             const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus);
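
Note: for readers of this patch, the round trip it enables looks roughly like the sketch below. This is a minimal illustration, not part of the diff; it assumes the experimental import path of this era of torchtext (torchtext.experimental.transforms.basic_english_normalize) and uses scratch file names of your choosing.

    # Sketch: saving/loading a pybind-backed transform, plus the
    # torchbind-backed variant obtained via to_ivalue().
    import torch
    from torchtext.experimental.transforms import basic_english_normalize

    tokenizer = basic_english_normalize()

    # pybind route: the py::pickle bindings added above make the eager
    # object picklable, so plain torch.save/torch.load works.
    torch.save(tokenizer, 'tokenizer_pybind.pt')
    loaded = torch.load('tokenizer_pybind.pt')

    # torchbind route: to_ivalue() returns the TorchScript-compatible
    # variant, serialized through the def_pickle registration.
    torch.save(tokenizer.to_ivalue(), 'tokenizer_torchscript.pt')
    loaded_jit = torch.load('tokenizer_torchscript.pt')

    print(loaded('Basic English Normalization for a Line of Text'))

The design point the patch makes is that both backends now share one serialization path: the free functions _serialize_*/_deserialize_* are called from the py::pickle lambdas and from the def_pickle lambdas alike, so the pybind and torchbind states cannot drift apart. SentencePiece is the one exception, since TorchScript lacks a byte-string type and falls back to a uint8 Tensor.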