diff --git a/test/data/test_functional.py b/test/data/test_functional.py
index 66fda21154..509a0e9fc3 100644
--- a/test/data/test_functional.py
+++ b/test/data/test_functional.py
@@ -107,13 +107,24 @@ def test_BasicEnglishNormalize(self):
self.assertEqual(eager_tokens, ref_results)
self.assertEqual(experimental_eager_tokens, ref_results)
- # test load and save
- save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt')
- torch.save(basic_eng_norm.to_ivalue(), save_path)
- loaded_basic_eng_norm = torch.load(save_path)
+ def test_basicEnglishNormalize_load_and_save(self):
+ test_sample = '\'".<br />,()!?;: Basic English Normalization for a Line of Text \'".<br />,()!?;:'
+ ref_results = ["'", '.', ',', '(', ')', '!', '?', 'basic', 'english', 'normalization',
+ 'for', 'a', 'line', 'of', 'text', "'", '.', ',', '(', ')', '!', '?']
- loaded_eager_tokens = loaded_basic_eng_norm(test_sample)
- self.assertEqual(loaded_eager_tokens, ref_results)
+ with self.subTest('pybind'):
+ save_path = os.path.join(self.test_dir, 'ben_pybind.pt')
+ ben = basic_english_normalize()
+ torch.save(ben, save_path)
+ loaded_ben = torch.load(save_path)
+ self.assertEqual(loaded_ben(test_sample), ref_results)
+
+ with self.subTest('torchscript'):
+ save_path = os.path.join(self.test_dir, 'ben_torchscript.pt')
+ ben = basic_english_normalize().to_ivalue()
+ torch.save(ben, save_path)
+ loaded_ben = torch.load(save_path)
+ self.assertEqual(loaded_ben(test_sample), ref_results)
# TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
@@ -147,13 +158,39 @@ def test_RegexTokenizer(self):
self.assertEqual(eager_tokens, ref_results)
self.assertEqual(jit_tokens, ref_results)
- # test load and save
- save_path = os.path.join(self.test_dir, 'regex.pt')
- torch.save(r_tokenizer.to_ivalue(), save_path)
- loaded_r_tokenizer = torch.load(save_path)
+ def test_load_and_save(self):
+ test_sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
+ ref_results = ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
+ 'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
+ patterns_list = [
+ (r'\'', ' \' '),
+ (r'\"', ''),
+ (r'\.', ' . '),
+ (r'<br \/>', ' '),
+ (r',', ' , '),
+ (r'\(', ' ( '),
+ (r'\)', ' ) '),
+ (r'\!', ' ! '),
+ (r'\?', ' ? '),
+ (r'\;', ' '),
+ (r'\:', ' '),
+ (r'\s+', ' ')]
- loaded_eager_tokens = loaded_r_tokenizer(test_sample)
- self.assertEqual(loaded_eager_tokens, ref_results)
+ with self.subTest('pybind'):
+ save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
+ tokenizer = regex_tokenizer(patterns_list)
+ torch.save(tokenizer, save_path)
+ loaded_tokenizer = torch.load(save_path)
+ results = loaded_tokenizer(test_sample)
+ self.assertEqual(results, ref_results)
+
+ with self.subTest('torchscript'):
+ save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
+ tokenizer = regex_tokenizer(patterns_list).to_ivalue()
+ torch.save(tokenizer, save_path)
+ loaded_tokenizer = torch.load(save_path)
+ results = loaded_tokenizer(test_sample)
+ self.assertEqual(results, ref_results)
def test_custom_replace(self):
custom_replace_transform = custom_replace([(r'S', 's'), (r'\s+', ' ')])
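Note: the tests above exercise both serialization paths this patch wires up. A minimal sketch of the pattern, assuming the factory functions are importable from `torchtext.experimental.transforms` as in these test files:

```python
import torch
from torchtext.experimental.transforms import regex_tokenizer

patterns_list = [(r'\'', ' \' '), (r'\s+', ' ')]  # abbreviated pattern set

# pybind path: the Python-bound object now pickles directly (py::pickle)
tokenizer = regex_tokenizer(patterns_list)
torch.save(tokenizer, 'regex_pybind.pt')

# TorchScript path: to_ivalue() swaps in the torchbind-backed object,
# which serializes through def_pickle instead
jit_tokenizer = regex_tokenizer(patterns_list).to_ivalue()
torch.save(jit_tokenizer, 'regex_torchscript.pt')

# both load back to callable tokenizers with identical behavior
assert torch.load('regex_pybind.pt')("don't") == \
    torch.load('regex_torchscript.pt')("don't")
```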
diff --git a/test/experimental/test_transforms.py b/test/experimental/test_transforms.py
index 69816dd6e7..1994c3c396 100644
--- a/test/experimental/test_transforms.py
+++ b/test/experimental/test_transforms.py
@@ -54,3 +54,27 @@ def test_vector_transform(self):
[-0.32423, -0.098845, -0.0073467]])
self.assertEqual(vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
self.assertEqual(jit_vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
+
+ def test_sentencepiece_load_and_save(self):
+ model_path = get_asset_path('spm_example.model')
+ input = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
+ expected = [
+ '▁Sent', 'ence', 'P', 'ie', 'ce', '▁is',
+ '▁an', '▁un', 'super', 'vis', 'ed', '▁text',
+ '▁to', 'ken', 'izer', '▁and',
+ '▁de', 'to', 'ken', 'izer',
+ ]
+
+ with self.subTest('pybind'):
+ save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
+ spm = sentencepiece_tokenizer(model_path)
+ torch.save(spm, save_path)
+ loaded_spm = torch.load(save_path)
+ self.assertEqual(expected, loaded_spm(input))
+
+ with self.subTest('torchscript'):
+ save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
+ spm = sentencepiece_tokenizer(model_path).to_ivalue()
+ torch.save(spm, save_path)
+ loaded_spm = torch.load(save_path)
+ self.assertEqual(expected, loaded_spm(input))
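The practical payoff of the pybind path: the pickled state carries the SentencePiece model content itself (see `content_` in the bindings below), so a saved tokenizer can be reloaded in a fresh process without the original `.model` asset. A hedged sketch, assuming the same `sentencepiece_tokenizer` factory used above:

```python
import torch
from torchtext.experimental.transforms import sentencepiece_tokenizer

# process 1: build from the .model file and save; the serialized object
# embeds the model bytes rather than the file path
spm = sentencepiece_tokenizer('spm_example.model')
torch.save(spm, 'spm_pybind.pt')

# process 2: the .model file is no longer needed
loaded_spm = torch.load('spm_pybind.pt')
print(loaded_spm('SentencePiece is an unsupervised text tokenizer'))
```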
diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py
index 946c436e95..ff5293b402 100644
--- a/test/experimental/test_vectors.py
+++ b/test/experimental/test_vectors.py
@@ -111,25 +111,49 @@ def test_vectors_add_item(self):
self.assertEqual(vectors_obj['b'], tensorB)
self.assertEqual(vectors_obj['not_in_it'], unk_tensor)
- def test_vectors_load_and_save(self):
+ def test_vectors_update(self):
tensorA = torch.tensor([1, 0], dtype=torch.float)
tensorB = torch.tensor([0, 1], dtype=torch.float)
+ tensorC = torch.tensor([1, 1], dtype=torch.float)
+
expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
tokens = ['a', 'b']
vecs = torch.stack((tensorA, tensorB), 0)
vectors_obj = build_vectors(tokens, vecs)
- tensorC = torch.tensor([1, 1], dtype=torch.float)
vectors_obj['b'] = tensorC
- vector_path = os.path.join(self.test_dir, 'vectors.pt')
- torch.save(vectors_obj.to_ivalue(), vector_path)
- loaded_vectors_obj = torch.load(vector_path)
+ self.assertEqual(vectors_obj['a'], tensorA)
+ self.assertEqual(vectors_obj['b'], tensorC)
+ self.assertEqual(vectors_obj['not_in_it'], expected_unk_tensor)
+
+ def test_vectors_load_and_save(self):
+ tensorA = torch.tensor([1, 0], dtype=torch.float)
+ tensorB = torch.tensor([0, 1], dtype=torch.float)
+ expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
+
+ tokens = ['a', 'b']
+ vecs = torch.stack((tensorA, tensorB), 0)
+ vectors_obj = build_vectors(tokens, vecs)
+
+ with self.subTest('pybind'):
+ vector_path = os.path.join(self.test_dir, 'vectors_pybind.pt')
+ torch.save(vectors_obj, vector_path)
+ loaded_vectors_obj = torch.load(vector_path)
+
+ self.assertEqual(loaded_vectors_obj['a'], tensorA)
+ self.assertEqual(loaded_vectors_obj['b'], tensorB)
+ self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+
+ with self.subTest('torchscript'):
+ vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
+ torch.save(vectors_obj.to_ivalue(), vector_path)
+ loaded_vectors_obj = torch.load(vector_path)
- self.assertEqual(loaded_vectors_obj['a'], tensorA)
- self.assertEqual(loaded_vectors_obj['b'], tensorC)
- self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+ self.assertEqual(loaded_vectors_obj['a'], tensorA)
+ self.assertEqual(loaded_vectors_obj['b'], tensorB)
+ self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
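Splitting `test_vectors_update` out of `test_vectors_load_and_save` keeps mutation and serialization concerns separate; a save taken after a `__setitem__` update would capture the updated vector. A small sketch, assuming `build_vectors` is importable from `torchtext.experimental.vectors`:

```python
import torch
from torchtext.experimental.vectors import build_vectors

tensorA = torch.tensor([1, 0], dtype=torch.float)
tensorB = torch.tensor([0, 1], dtype=torch.float)
vectors_obj = build_vectors(['a', 'b'], torch.stack((tensorA, tensorB), 0))

# __setitem__ mutates in place; pickling captures the current state
vectors_obj['b'] = torch.tensor([1, 1], dtype=torch.float)
torch.save(vectors_obj, 'vectors_pybind.pt')
assert torch.load('vectors_pybind.pt')['b'].tolist() == [1.0, 1.0]
```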
diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index 626db2e726..662aa6667a 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -199,12 +199,19 @@ def test_vocab_load_and_save(self):
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)
- vocab_path = os.path.join(self.test_dir, 'vocab.pt')
- torch.save(v.to_ivalue(), vocab_path)
- loaded_v = torch.load(vocab_path)
-
- self.assertEqual(v.get_itos(), expected_itos)
- self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+ with self.subTest('pybind'):
+ vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
+ torch.save(v, vocab_path)
+ loaded_v = torch.load(vocab_path)
+ self.assertEqual(loaded_v.get_itos(), expected_itos)
+ self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+
+ with self.subTest('torchscript'):
+ vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
+ torch.save(v.to_ivalue(), vocab_path)
+ loaded_v = torch.load(vocab_path)
+ self.assertEqual(loaded_v.get_itos(), expected_itos)
+ self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
def test_build_vocab_iterator(self):
iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
diff --git a/torchtext/csrc/regex.cpp b/torchtext/csrc/regex.cpp
index f2051e0de0..32da33d2cf 100644
--- a/torchtext/csrc/regex.cpp
+++ b/torchtext/csrc/regex.cpp
@@ -11,4 +11,12 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
return str;
}
+std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
+ return self->re_str_;
+}
+
+c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state) {
+ return c10::make_intrusive<Regex>(std::move(state));
+}
+
} // namespace torchtext
diff --git a/torchtext/csrc/regex.h b/torchtext/csrc/regex.h
index c5d871479c..4e5dfbfee4 100644
--- a/torchtext/csrc/regex.h
+++ b/torchtext/csrc/regex.h
@@ -13,4 +13,8 @@ struct Regex : torch::CustomClassHolder {
Regex(const std::string &re_str);
std::string Sub(std::string str, const std::string &repl) const;
};
+
+std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
+c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state);
+
} // namespace torchtext
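Factoring `_serialize_regex`/`_deserialize_regex` into free functions lets the pybind `py::pickle` hooks and the TorchScript `def_pickle` hooks below share one implementation. Because the state is just the pattern string, the bound class also round-trips through plain `pickle`; a sketch, assuming the pybind11 module is importable as `torchtext._torchtext`:

```python
import pickle
from torchtext._torchtext import Regex  # module path is an assumption

r = Regex(r'\s+')
r2 = pickle.loads(pickle.dumps(r))  # __getstate__ returns re_str_
assert r2.Sub('a   b', ' ') == r.Sub('a   b', ' ')  # 'a b'
```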
diff --git a/torchtext/csrc/regex_tokenizer.cpp b/torchtext/csrc/regex_tokenizer.cpp
index c31ca36226..9ad20df9a7 100644
--- a/torchtext/csrc/regex_tokenizer.cpp
+++ b/torchtext/csrc/regex_tokenizer.cpp
@@ -44,4 +44,15 @@ void RegexTokenizer::split_(std::string &str, std::vector<std::string> &tokens,
}
}
+RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self) {
+ return std::make_tuple(self->patterns_, self->replacements_, self->to_lower_);
+}
+
+c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states) {
+ return c10::make_intrusive<RegexTokenizer>(
+ std::move(std::get<0>(states)),
+ std::move(std::get<1>(states)),
+ std::get<2>(states));
+}
+
} // namespace torchtext
diff --git a/torchtext/csrc/regex_tokenizer.h b/torchtext/csrc/regex_tokenizer.h
index d0d9cfbb62..02d898eb4a 100644
--- a/torchtext/csrc/regex_tokenizer.h
+++ b/torchtext/csrc/regex_tokenizer.h
@@ -3,6 +3,9 @@
namespace torchtext {
+typedef std::tuple<std::vector<std::string>, std::vector<std::string>, bool>
+ RegexTokenizerStates;
+
struct RegexTokenizer : torch::CustomClassHolder {
private:
std::vector<Regex *> compiled_patterns_;
@@ -20,4 +23,7 @@ struct RegexTokenizer : torch::CustomClassHolder {
std::vector<std::string> forward(std::string str) const;
};
+RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self);
+c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states);
+
} // namespace torchtext
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index de55d27e7d..4c3ef76399 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -27,18 +27,36 @@ Vocab build_vocab_from_text_file(const std::string &file_path,
// Registers our custom classes with pybind11.
PYBIND11_MODULE(_torchtext, m) {
// Classes
- py::class_(m, "Regex")
+ py::class_>(m, "Regex")
.def(py::init<std::string>())
- .def("Sub", &Regex::Sub);
+ .def("Sub", &Regex::Sub)
+ .def(py::pickle(
+ // __getstate__
+ [](const c10::intrusive_ptr<Regex> &self) -> std::string {
+ return _serialize_regex(self);
+ },
+ // __setstate__
+ [](std::string state) -> c10::intrusive_ptr<Regex> {
+ return _deserialize_regex(std::move(state));
+ }));
- py::class_(m, "RegexTokenizer")
+ py::class_>(m, "RegexTokenizer")
.def_readonly("patterns_", &RegexTokenizer::patterns_)
.def_readonly("replacements_", &RegexTokenizer::replacements_)
.def_readonly("to_lower_", &RegexTokenizer::to_lower_)
.def(py::init<std::vector<std::string>, std::vector<std::string>, bool>())
- .def("forward", &RegexTokenizer::forward);
+ .def("forward", &RegexTokenizer::forward)
+ .def(py::pickle(
+ // __getstate__
+ [](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
+ return _serialize_regex_tokenizer(self);
+ },
+ // __setstate__
+ [](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
+ return _deserialize_regex_tokenizer(std::move(states));
+ }));
- py::class_(m, "SentencePiece")
+ py::class_>(m, "SentencePiece")
.def(py::init<std::string>())
.def("_return_content",
[](const SentencePiece &self) { return py::bytes(self.content_); })
@@ -50,9 +68,18 @@ PYBIND11_MODULE(_torchtext, m) {
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
- .def("IdToPiece", &SentencePiece::IdToPiece);
+ .def("IdToPiece", &SentencePiece::IdToPiece)
+ .def(py::pickle(
+ // __getstate__
+ [](const c10::intrusive_ptr<SentencePiece> &self) -> py::bytes {
+ return py::bytes(self->content_);
+ },
+ // __setstate__
+ [](py::bytes state) -> c10::intrusive_ptr<SentencePiece> {
+ return c10::make_intrusive<SentencePiece>(std::string(state));
+ }));
- py::class_(m, "Vectors")
+ py::class_>(m, "Vectors")
.def(py::init<std::vector<std::string>, std::vector<int64_t>,
torch::Tensor, torch::Tensor>())
.def_readonly("vectors_", &Vectors::vectors_)
@@ -61,9 +88,18 @@ PYBIND11_MODULE(_torchtext, m) {
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
- .def("__len__", &Vectors::__len__);
+ .def("__len__", &Vectors::__len__)
+ .def(py::pickle(
+ // __getstate__
+ [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
+ return _serialize_vectors(self);
+ },
+ // __setstate__
+ [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
+ return _deserialize_vectors(states);
+ }));
- py::class_(m, "Vocab")
+ py::class_>(m, "Vocab")
.def(py::init<std::vector<std::string>, std::string>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
@@ -75,7 +111,16 @@ PYBIND11_MODULE(_torchtext, m) {
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices", &Vocab::lookup_indices)
.def("get_stoi", &Vocab::get_stoi)
- .def("get_itos", &Vocab::get_itos);
+ .def("get_itos", &Vocab::get_itos)
+ .def(py::pickle(
+ // __getstate__
+ [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
+ return _serialize_vocab(self);
+ },
+ // __setstate__
+ [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
+ return _deserialize_vocab(states);
+ }));
// Functions
m.def("_load_token_and_vectors_from_file",
@@ -91,32 +136,26 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
- return self->re_str_;
+ return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
- return c10::make_intrusive<Regex>(std::move(state));
+ return _deserialize_regex(std::move(state));
});
- using RegexTokenizerState = std::tuple<std::vector<std::string>, std::vector<std::string>, bool>;
m.class_("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>, bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
- [](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerState {
- return std::make_tuple(
- self->patterns_,
- self->replacements_,
- self->to_lower_);
+ [](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
+ return _serialize_regex_tokenizer(self);
},
// __setstate__
- [](RegexTokenizerState state) -> c10::intrusive_ptr<RegexTokenizer> {
- return c10::make_intrusive<RegexTokenizer>(
- std::move(std::get<0>(state)),
- std::move(std::get<1>(state)),
- std::get<2>(state));
+ [](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
+ return _deserialize_regex_tokenizer(std::move(states));
});
+
m.class_("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
@@ -129,13 +168,21 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
+ // The underlying content of SentencePiece is a byte string, and
+ // returning it as a std::string causes a UTF-8 decoding error.
+ // Since TorchScript does not support byte strings, we use a byte Tensor
+ // to pass the data around.
// __getstate__
- [](const c10::intrusive_ptr<SentencePiece> &self) -> std::string {
- return self->content_;
+ [](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
+ auto *data = static_cast<void *>(const_cast<char *>(self->content_.data()));
+ auto numel = static_cast<int64_t>(self->content_.size());
+ return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
- [](std::string state) -> c10::intrusive_ptr<SentencePiece> {
- return c10::make_intrusive<SentencePiece>(std::move(state));
+ [](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
+ auto *data = static_cast<char *>(state.data_ptr());
+ auto numel = state.size(0);
+ return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});
m.class_("Vectors")
@@ -147,11 +194,11 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
- return _set_vectors_states(self);
+ return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
- return _get_vectors_from_states(states);
+ return _deserialize_vectors(states);
});
m.class_("Vocab")
@@ -168,11 +215,11 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
- return _set_vocab_states(self);
+ return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
- return _get_vocab_from_states(states);
+ return _deserialize_vocab(states);
});
m.def("torchtext::generate_sp_model", &generate_sp_model);
diff --git a/torchtext/csrc/vectors.cpp b/torchtext/csrc/vectors.cpp
index 029b3abb0d..c5410b1496 100644
--- a/torchtext/csrc/vectors.cpp
+++ b/torchtext/csrc/vectors.cpp
@@ -275,7 +275,7 @@ std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
return result;
}
-VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self) {
+VectorsStates _serialize_vectors(const c10::intrusive_ptr<Vectors> &self) {
std::vector<std::string> tokens;
std::vector<int64_t> indices;
tokens.reserve(self->stoi_.size());
@@ -299,7 +299,7 @@ VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self) {
return states;
}
-c10::intrusive_ptr<Vectors> _get_vectors_from_states(VectorsStates states) {
+c10::intrusive_ptr<Vectors> _deserialize_vectors(VectorsStates states) {
auto state_size = std::tuple_size<decltype(states)>::value;
if (state_size != 4) {
throw std::runtime_error(
diff --git a/torchtext/csrc/vectors.h b/torchtext/csrc/vectors.h
index fde1f257f0..09716258ca 100644
--- a/torchtext/csrc/vectors.h
+++ b/torchtext/csrc/vectors.h
@@ -32,8 +32,8 @@ struct Vectors : torch::CustomClassHolder {
int64_t __len__();
};
-c10::intrusive_ptr<Vectors> _get_vectors_from_states(VectorsStates states);
-VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self);
+VectorsStates _serialize_vectors(const c10::intrusive_ptr<Vectors> &self);
+c10::intrusive_ptr<Vectors> _deserialize_vectors(VectorsStates states);
std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
const std::string &file_path, const std::string delimiter_str,
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 458f8c3e47..0e324dbbf5 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -379,7 +379,7 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
}
-VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self) {
+VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
std::vector<int64_t> integers;
StringList strings = self->itos_;
strings.push_back(self->unk_token_);
@@ -390,7 +390,7 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self) {
return states;
}
-c10::intrusive_ptr<Vocab> _get_vocab_from_states(VocabStates states) {
+c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
auto state_size = std::tuple_size<decltype(states)>::value;
if (state_size != 4) {
#ifdef _MSC_VER
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 195d6de1a0..0da660a633 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -35,8 +35,9 @@ struct Vocab : torch::CustomClassHolder {
std::vector<std::string> get_itos() const;
};
-c10::intrusive_ptr<Vocab> _get_vocab_from_states(VocabStates states);
-VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self);
+VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
+c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
+
Vocab _load_vocab_from_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq, const int64_t num_cpus);
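Beyond `torch.save`, the renamed `_serialize_vocab`/`_deserialize_vocab` pair backs `def_pickle`, which is also what allows a scripted module holding the torchbind class to be saved with `torch.jit.save`. A hedged sketch, assuming the experimental `vocab` factory and an exported `get_itos` method:

```python
from collections import OrderedDict

import torch
from torchtext.experimental.vocab import vocab

v = vocab(OrderedDict([('hello', 4), ('world', 3)]))
jit_v = torch.jit.script(v.to_ivalue())  # torchbind backend, def_pickle path
torch.jit.save(jit_v, 'vocab_jit.pt')
loaded = torch.jit.load('vocab_jit.pt')
assert loaded.get_itos() == v.get_itos()
```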