Make Regex, RegexTokenizer, Vocab, Vectors, SentencePiece pickle-able
mthrok authored Dec 22, 2020
1 parent 9562f80 commit aa7176b
Showing 13 changed files with 234 additions and 65 deletions.
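
(Each class now round-trips through two serialization paths: plain pickling of the eager, pybind11-backed object via torch.save/torch.load, and pickling of the TorchScript-compatible object produced by .to_ivalue(), which goes through TorchBind's __getstate__/__setstate__ hooks. The tests below exercise the two paths as 'pybind' and 'torchscript' subtests.)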
61 changes: 49 additions & 12 deletions test/data/test_functional.py
@@ -107,13 +107,24 @@ def test_BasicEnglishNormalize(self):
        self.assertEqual(eager_tokens, ref_results)
        self.assertEqual(experimental_eager_tokens, ref_results)

    def test_basicEnglishNormalize_load_and_save(self):
        test_sample = '\'".<br />,()!?;: Basic English Normalization for a Line of Text \'".<br />,()!?;:'
        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'basic', 'english', 'normalization',
                       'for', 'a', 'line', 'of', 'text', "'", '.', ',', '(', ')', '!', '?']

        with self.subTest('pybind'):
            save_path = os.path.join(self.test_dir, 'ben_pybind.pt')
            ben = basic_english_normalize()
            torch.save(ben, save_path)
            loaded_ben = torch.load(save_path)
            self.assertEqual(loaded_ben(test_sample), ref_results)

        with self.subTest('torchscript'):
            save_path = os.path.join(self.test_dir, 'ben_torchscript.pt')
            ben = basic_english_normalize().to_ivalue()
            torch.save(ben, save_path)
            loaded_ben = torch.load(save_path)
            self.assertEqual(loaded_ben(test_sample), ref_results)

    # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
@@ -147,13 +147,39 @@ def test_RegexTokenizer(self):
        self.assertEqual(eager_tokens, ref_results)
        self.assertEqual(jit_tokens, ref_results)

    def test_load_and_save(self):
        test_sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
                       'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
        patterns_list = [
            (r'\'', ' \' '),
            (r'\"', ''),
            (r'\.', ' . '),
            (r'<br \/>', ' '),
            (r',', ' , '),
            (r'\(', ' ( '),
            (r'\)', ' ) '),
            (r'\!', ' ! '),
            (r'\?', ' ? '),
            (r'\;', ' '),
            (r'\:', ' '),
            (r'\s+', ' ')]

        with self.subTest('pybind'):
            save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
            tokenizer = regex_tokenizer(patterns_list)
            torch.save(tokenizer, save_path)
            loaded_tokenizer = torch.load(save_path)
            results = loaded_tokenizer(test_sample)
            self.assertEqual(results, ref_results)

        with self.subTest('torchscript'):
            save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
            tokenizer = regex_tokenizer(patterns_list).to_ivalue()
            torch.save(tokenizer, save_path)
            loaded_tokenizer = torch.load(save_path)
            results = loaded_tokenizer(test_sample)
            self.assertEqual(results, ref_results)

    def test_custom_replace(self):
        custom_replace_transform = custom_replace([(r'S', 's'), (r'\s+', ' ')])
24 changes: 24 additions & 0 deletions test/experimental/test_transforms.py
@@ -54,3 +54,27 @@ def test_vector_transform(self):
                                                     [-0.32423, -0.098845, -0.0073467]])
        self.assertEqual(vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
        self.assertEqual(jit_vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)

    def test_sentencepiece_load_and_save(self):
        model_path = get_asset_path('spm_example.model')
        input = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
        expected = [
            '▁Sent', 'ence', 'P', 'ie', 'ce', '▁is',
            '▁an', '▁un', 'super', 'vis', 'ed', '▁text',
            '▁to', 'ken', 'izer', '▁and',
            '▁de', 'to', 'ken', 'izer',
        ]

        with self.subTest('pybind'):
            save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
            spm = sentencepiece_tokenizer(model_path)
            torch.save(spm, save_path)
            loaded_spm = torch.load(save_path)
            self.assertEqual(expected, loaded_spm(input))

        with self.subTest('torchscript'):
            save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
            spm = sentencepiece_tokenizer(model_path).to_ivalue()
            torch.save(spm, save_path)
            loaded_spm = torch.load(save_path)
            self.assertEqual(expected, loaded_spm(input))
40 changes: 32 additions & 8 deletions test/experimental/test_vectors.py
@@ -111,25 +111,49 @@ def test_vectors_add_item(self):
        self.assertEqual(vectors_obj['b'], tensorB)
        self.assertEqual(vectors_obj['not_in_it'], unk_tensor)

    def test_vectors_update(self):
        tensorA = torch.tensor([1, 0], dtype=torch.float)
        tensorB = torch.tensor([0, 1], dtype=torch.float)
        tensorC = torch.tensor([1, 1], dtype=torch.float)

        expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)

        tokens = ['a', 'b']
        vecs = torch.stack((tensorA, tensorB), 0)
        vectors_obj = build_vectors(tokens, vecs)

        vectors_obj['b'] = tensorC

        self.assertEqual(vectors_obj['a'], tensorA)
        self.assertEqual(vectors_obj['b'], tensorC)
        self.assertEqual(vectors_obj['not_in_it'], expected_unk_tensor)

    def test_vectors_load_and_save(self):
        tensorA = torch.tensor([1, 0], dtype=torch.float)
        tensorB = torch.tensor([0, 1], dtype=torch.float)
        expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)

        tokens = ['a', 'b']
        vecs = torch.stack((tensorA, tensorB), 0)
        vectors_obj = build_vectors(tokens, vecs)

        with self.subTest('pybind'):
            vector_path = os.path.join(self.test_dir, 'vectors_pybind.pt')
            torch.save(vectors_obj, vector_path)
            loaded_vectors_obj = torch.load(vector_path)

            self.assertEqual(loaded_vectors_obj['a'], tensorA)
            self.assertEqual(loaded_vectors_obj['b'], tensorB)
            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)

        with self.subTest('torchscript'):
            vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
            torch.save(vectors_obj.to_ivalue(), vector_path)
            loaded_vectors_obj = torch.load(vector_path)

            self.assertEqual(loaded_vectors_obj['a'], tensorA)
            self.assertEqual(loaded_vectors_obj['b'], tensorB)
            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)

    # we separate out these errors because Windows runs into seg faults when propagating
    # exceptions from C++ using pybind11
19 changes: 13 additions & 6 deletions test/experimental/test_vocab.py
@@ -199,12 +199,19 @@ def test_vocab_load_and_save(self):
        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        with self.subTest('pybind'):
            vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
            torch.save(v, vocab_path)
            loaded_v = torch.load(vocab_path)
            self.assertEqual(loaded_v.get_itos(), expected_itos)
            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)

        with self.subTest('torchscript'):
            vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
            torch.save(v.to_ivalue(), vocab_path)
            loaded_v = torch.load(vocab_path)
            self.assertEqual(loaded_v.get_itos(), expected_itos)
            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)

    def test_build_vocab_iterator(self):
        iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
8 changes: 8 additions & 0 deletions torchtext/csrc/regex.cpp
@@ -11,4 +11,12 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
  return str;
}

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
  return self->re_str_;
}

c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state) {
  return c10::make_intrusive<Regex>(std::move(state));
}

} // namespace torchtext
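
These two helpers give Regex a complete pickle round-trip: the saved state is just the original pattern string kept in re_str_, and deserialization recompiles it. The registration that wires them into TorchBind happens elsewhere in the commit and is not shown in this excerpt; the following is only a minimal sketch of the usual def_pickle pattern, with the registration site assumed rather than taken from the diff:

// Sketch only — assumes a registration translation unit similar to
// torchtext's binding file; the exact site is not shown in this excerpt.
#include <torch/custom_class.h>
#include "regex.h"

namespace torchtext {

static auto regex_class =
    torch::class_<Regex>("torchtext", "Regex")
        .def(torch::init<std::string>())
        .def("Sub", &Regex::Sub)
        .def_pickle(
            // __getstate__: reduce the object to its pattern string.
            [](const c10::intrusive_ptr<Regex> &self) -> std::string {
              return _serialize_regex(self);
            },
            // __setstate__: rebuild (and recompile) from that string.
            [](std::string state) -> c10::intrusive_ptr<Regex> {
              return _deserialize_regex(std::move(state));
            });

} // namespace torchtext

Because the state is a plain std::string, the same pair of functions can also back a pybind11 py::pickle definition for the eager path.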
4 changes: 4 additions & 0 deletions torchtext/csrc/regex.h
@@ -13,4 +13,8 @@ struct Regex : torch::CustomClassHolder {
  Regex(const std::string &re_str);
  std::string Sub(std::string str, const std::string &repl) const;
};

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state);

} // namespace torchtext
11 changes: 11 additions & 0 deletions torchtext/csrc/regex_tokenizer.cpp
@@ -44,4 +44,15 @@ void RegexTokenizer::split_(std::string &str, std::vector<std::string> &tokens,
  }
}

RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self) {
  return std::make_tuple(self->patterns_, self->replacements_, self->to_lower_);
}

c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states) {
  return c10::make_intrusive<RegexTokenizer>(
      std::move(std::get<0>(states)),
      std::move(std::get<1>(states)),
      std::get<2>(states));
}

} // namespace torchtext
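
RegexTokenizer's state is richer — the (patterns, replacements, to_lower) tuple aliased as RegexTokenizerStates in the header below — but the hook-up follows the same shape. Here is a sketch of the eager-side pybind11 binding; it assumes the module object and the PYBIND11_DECLARE_HOLDER_TYPE declaration for c10::intrusive_ptr that torchtext's real binding file provides, neither of which appears in this excerpt:

// Sketch only — bind_regex_tokenizer and its module argument are assumed
// for illustration; they are not part of this diff.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // tuple <-> Python conversion for the state
#include "regex_tokenizer.h"

namespace py = pybind11;

void bind_regex_tokenizer(py::module &m) {
  py::class_<torchtext::RegexTokenizer,
             c10::intrusive_ptr<torchtext::RegexTokenizer>>(m, "RegexTokenizer")
      .def("forward", &torchtext::RegexTokenizer::forward)
      .def(py::pickle(
          // __getstate__: capture (patterns, replacements, to_lower).
          [](const c10::intrusive_ptr<torchtext::RegexTokenizer> &self) {
            return torchtext::_serialize_regex_tokenizer(self);
          },
          // __setstate__: recompile the RE2 patterns from the saved tuple.
          [](torchtext::RegexTokenizerStates states) {
            return torchtext::_deserialize_regex_tokenizer(std::move(states));
          }));
}

This is what lets the tests' 'pybind' subtests call torch.save on the eager object directly, while the .to_ivalue() variant goes through an analogous TorchBind def_pickle registration.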
6 changes: 6 additions & 0 deletions torchtext/csrc/regex_tokenizer.h
@@ -3,6 +3,9 @@

namespace torchtext {

typedef std::tuple<std::vector<std::string>, std::vector<std::string>, bool>
    RegexTokenizerStates;

struct RegexTokenizer : torch::CustomClassHolder {
private:
  std::vector<RE2 *> compiled_patterns_;
@@ -20,4 +23,7 @@ struct RegexTokenizer : torch::CustomClassHolder {
  std::vector<std::string> forward(std::string str) const;
};

RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self);
c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states);

} // namespace torchtext
