Skip to content

Commit

Permalink
v0.1.1
Browse files Browse the repository at this point in the history
+ add __init__ for Vocab
  • Loading branch information
kduxin committed Oct 12, 2022
1 parent c1c80f9 commit 018aeee
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 49 deletions.
4 changes: 2 additions & 2 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "corpusit-python"
version = "0.1.0"
version = "0.1.1"
authors = ["duxin.ac@gmail.com"]
edition = "2021"

Expand All @@ -19,7 +19,7 @@ tqdm = "0.4.2"
ndarray = "0.15.6"

[dependencies.corpusit]
version = "*"
version = "0.1.1"
path = "../../corpusit"

[build-dependencies]
Expand Down
39 changes: 34 additions & 5 deletions bindings/python/corpusit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ class Vocab:
""" mapping from word string to count"""
counts: Mapping[str, int]

def __init__(
    self,
    i2s: Mapping[int, str],
    i2count: Mapping[int, int],
    unk: Optional[str] = None,
    other_special_name2str: Optional[Mapping[str, str]] = None,
) -> None:
    """Create a vocabulary.

    Args:
        i2s (Mapping[int, str]): mapping from word index to string.
        i2count (Mapping[int, int]): mapping from word index to count.
        unk (str, optional): The `unknown` word. Defaults to None.
        other_special_name2str (Mapping[str, str], optional):
            Special words in addition to `unk`, a mapping from
            name (`eos`, etc.) to word string (e.g., `<eos>`).
            Defaults to None.
    """
@staticmethod
def from_json(
path_to_json, min_count: int = None, max_size: int = None, unk: str = None
Expand Down Expand Up @@ -149,7 +170,7 @@ class SkipGramDataset:
- n_neg (int, optional): how many negative samples for one positive sample.
Defaults to 1.
"""
def positive_sampler(self, batch_size: int , seed: uint = 0, num_threads: uint = 4):
def positive_sampler(self, batch_size: int, seed: uint = 0, num_threads: uint = 4):
"""Create an iterable (maybe multi-thread) sampler for generating "positive" samples.
The sampler returns a numpy 2-d array of shape (batch_size, 2) in each iteration,
where each row is a pair of word indices.
Expand Down Expand Up @@ -184,16 +205,24 @@ class SkipGramIter:

class S2I:
"""A Mapping from word string to index."""
def get(self, s: str, default) -> int: pass

def get(self, s: str, default) -> int:
pass

class I2S:
"""A Mapping from word index to string."""
def get(self, i: int, default) -> str: pass

def get(self, i: int, default) -> str:
pass

class I2Count:
"""A Mapping from word index to count."""
def get(self, i: int, default) -> int: pass

def get(self, i: int, default) -> int:
pass

class Counts:
"""A Mapping from word str to count."""
def get(self, s: str, default) -> int: pass

def get(self, s: str, default) -> int:
pass
133 changes: 98 additions & 35 deletions bindings/python/src/vocab.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
use bincode;
use pyo3::{prelude::*, types::PyBytes};
use serde_json;
use std::{fs, io};
use std::collections::{
HashMap,
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use pyo3::{
prelude::*,
types::PyBytes,
};
use std::{fs, io};

extern crate corpusit;
use corpusit::Vocab;
Expand Down Expand Up @@ -37,11 +32,25 @@ impl From<Vocab> for PyVocab {
impl From<Arc<RwLock<Vocab>>> for PyVocab {
fn from(vocab: Arc<RwLock<Vocab>>) -> Self {
let vocab = vocab;
let s2i = S2I {vocab: Arc::clone(&vocab)};
let i2s = I2S {vocab: Arc::clone(&vocab)};
let counts = Counts {vocab: Arc::clone(&vocab)};
let i2count = I2Count {vocab: Arc::clone(&vocab)};
Self { vocab: Arc::clone(&vocab), s2i: s2i, i2s: i2s, counts: counts, i2count: i2count}
let s2i = S2I {
vocab: Arc::clone(&vocab),
};
let i2s = I2S {
vocab: Arc::clone(&vocab),
};
let counts = Counts {
vocab: Arc::clone(&vocab),
};
let i2count = I2Count {
vocab: Arc::clone(&vocab),
};
Self {
vocab: Arc::clone(&vocab),
s2i: s2i,
i2s: i2s,
counts: counts,
i2count: i2count,
}
}
}

Expand All @@ -58,6 +67,31 @@ impl PyVocab {

#[pymethods]
impl PyVocab {
/// Create a vocabulary from an index->word map and an index->count map.
///
/// `unk`, when given, is merged into `other_special_name2str` under the
/// name `"unk"`; specifying two different tokens as `unk` (positionally
/// and via the map) is rejected with a panic.
#[args(unk = "None", other_special_name2str = "None")]
#[new]
pub fn new(
    i2s: HashMap<usize, String>,
    i2count: HashMap<usize, u64>,
    unk: Option<String>,
    other_special_name2str: Option<HashMap<String, String>>,
) -> Self {
    let mut other_special_name2str = other_special_name2str.unwrap_or_default();
    if let Some(unk_str) = unk {
        other_special_name2str
            .entry("unk".to_string())
            .and_modify(|unk_s| {
                // `unk` was also supplied inside the map: both must agree.
                if *unk_s != unk_str {
                    panic!(
                        "You specified two different tokens (`{}` and `{}`) as {{unk}}.",
                        &unk_str, unk_s
                    );
                }
            })
            .or_insert(unk_str);
    }
    Self::from(Vocab::new(i2s, i2count, other_special_name2str))
}

/// Read a Vocab stored in a json file at `path_to_json`
/// Parameters
/// - min_count: set a new count threshold. All words with smaller
Expand Down Expand Up @@ -115,10 +149,8 @@ impl PyVocab {
/// - path_to_save_bin: if not specified, will save at
/// ${path_to_corpus}.vocab.bin
#[args(min_count = "5", max_size = "None", unk = "\"<unk>\"")]
#[pyo3(
text_signature = "(path_to_corpus, min_count=None, max_size=None, \
unk=None, path_to_save_json=None, path_to_save_bin=None)"
)]
#[pyo3(text_signature = "(path_to_corpus, min_count=None, max_size=None, \
unk=None, path_to_save_json=None, path_to_save_bin=None)")]
#[staticmethod]
pub fn build(
path_to_corpus: &str,
Expand Down Expand Up @@ -175,15 +207,21 @@ impl PyVocab {

pub fn counts_dict(slf: PyRef<Self>) -> HashMap<String, u64> {
let vocab = slf.vocab.read().unwrap();
vocab.i2count.iter().map(|(i, c)| {
(vocab.i2s[i].clone(), *c)
}).collect()
vocab
.i2count
.iter()
.map(|(i, c)| (vocab.i2s[i].clone(), *c))
.collect()
}

pub fn get_special_tokens(slf: PyRef<Self>) -> HashMap<String, String> {
slf.vocab.read().unwrap().special_name2i.iter().map(|(name, id)| {
(name.to_string(), slf.vocab.read().unwrap().i2s[id].clone())
}).collect()
slf.vocab
.read()
.unwrap()
.special_name2i
.iter()
.map(|(name, id)| (name.to_string(), slf.vocab.read().unwrap().i2s[id].clone()))
.collect()
}

pub fn __contains__(slf: PyRef<Self>, s: &str) -> bool {
Expand All @@ -194,7 +232,13 @@ impl PyVocab {
}

pub fn keys(slf: PyRef<Self>) -> Vec<String> {
slf.vocab.read().unwrap().s2i.keys().map(|s| s.to_string()).collect()
slf.vocab
.read()
.unwrap()
.s2i
.keys()
.map(|s| s.to_string())
.collect()
}

pub fn __repr__(slf: PyRef<Self>) -> String {
Expand All @@ -205,12 +249,20 @@ impl PyVocab {
}
special_tokens_str.pop();
special_tokens_str.pop();
format!("<Vocab(size={}, special_tokens={{{}}})", vocab.i2s.len(), special_tokens_str)
format!(
"<Vocab(size={}, special_tokens={{{}}})",
vocab.i2s.len(),
special_tokens_str
)
}

#[getter]
pub fn unk(slf: PyRef<Self>) -> Option<String> {
slf.vocab.read().unwrap().unk_str().and_then(|s| Some(s.to_string()))
slf.vocab
.read()
.unwrap()
.unk_str()
.and_then(|s| Some(s.to_string()))
}

#[getter]
Expand All @@ -237,10 +289,8 @@ impl PyVocab {
pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.vocab).unwrap()).to_object(py))
}

}


#[derive(Clone)]
#[pyclass]
pub struct S2I {
Expand Down Expand Up @@ -268,7 +318,12 @@ pub struct Counts {
#[pymethods]
impl S2I {
pub fn __getitem__(slf: PyRef<Self>, s: &str) -> Option<usize> {
slf.vocab.read().unwrap().s2i.get(s).and_then(|id| Some(*id))
slf.vocab
.read()
.unwrap()
.s2i
.get(s)
.and_then(|id| Some(*id))
}

pub fn get(slf: PyRef<Self>, s: &str, default: usize) -> usize {
Expand All @@ -286,7 +341,12 @@ impl S2I {
#[pymethods]
impl I2S {
pub fn __getitem__(slf: PyRef<Self>, id: usize) -> Option<String> {
slf.vocab.read().unwrap().i2s.get(&id).and_then(|s| Some(s.to_string()))
slf.vocab
.read()
.unwrap()
.i2s
.get(&id)
.and_then(|s| Some(s.to_string()))
}

pub fn get(slf: PyRef<Self>, id: usize, default: String) -> String {
Expand All @@ -304,7 +364,12 @@ impl I2S {
#[pymethods]
impl I2Count {
pub fn __getitem__(slf: PyRef<Self>, id: usize) -> Option<u64> {
slf.vocab.read().unwrap().i2count.get(&id).and_then(|c| Some(*c))
slf.vocab
.read()
.unwrap()
.i2count
.get(&id)
.and_then(|c| Some(*c))
}

pub fn get(slf: PyRef<Self>, id: usize, default: u64) -> u64 {
Expand All @@ -323,9 +388,7 @@ impl I2Count {
impl Counts {
pub fn __getitem__(slf: PyRef<Self>, s: &str) -> Option<u64> {
let vocab = slf.vocab.read().unwrap();
vocab.s2i.get(s).and_then(|id| {
Some(vocab.i2count[id])
})
vocab.s2i.get(s).and_then(|id| Some(vocab.i2count[id]))
}

pub fn get(slf: PyRef<Self>, s: &str, default: u64) -> u64 {
Expand All @@ -338,4 +401,4 @@ impl Counts {
pub fn __len__(slf: PyRef<Self>) -> usize {
slf.vocab.read().unwrap().i2s.len()
}
}
}
2 changes: 1 addition & 1 deletion corpusit/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "corpusit"
version = "0.1.0"
version = "0.1.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
35 changes: 29 additions & 6 deletions corpusit/src/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ impl VocabBuilder {
}
counts
})
.reduce(
|| HashMap::<String, u64>::new(),
_merge_counts,
);
.reduce(|| HashMap::<String, u64>::new(), _merge_counts);
wordcounts = _merge_counts(wordcounts, chunk_counts);
}

Expand Down Expand Up @@ -148,6 +145,34 @@ pub struct Vocab {
}

impl Vocab {
/// Build a `Vocab` from an index->word map, an index->count map, and a
/// name->word map of special tokens (e.g. `"unk"` -> `"<unk>"`).
///
/// # Panics
/// Panics if a special-token string in `special_name2str` does not occur
/// in `i2s`.
pub fn new(
    i2s: HashMap<usize, String>,
    i2count: HashMap<usize, u64>,
    special_name2str: HashMap<String, String>,
) -> Self {
    // Invert i2s to obtain the word -> index map.
    let s2i: HashMap<String, usize> = i2s.iter().map(|(id, s)| (s.clone(), *id)).collect();
    let totalcount: u64 = i2count.values().sum();
    // Resolve each special token's string to its index; a missing token is
    // a caller error, surfaced as a panic naming the offending string.
    let special_name2i: HashMap<String, usize> = special_name2str
        .iter()
        .map(|(name, s)| {
            let id = *s2i
                .get(s)
                .unwrap_or_else(|| panic!("Special token `{}` not found in vocabulary.", s));
            (name.clone(), id)
        })
        .collect();
    Self {
        i2s,
        s2i,
        i2count,
        totalcount,
        special_name2i,
    }
}

pub fn unk(&self) -> Option<(&str, usize)> {
self.special_name2i
.get("unk")
Expand Down Expand Up @@ -236,7 +261,6 @@ impl Vocab {
}

pub fn truncate(&mut self, min_count: u64, max_size: usize) {

let mut trun_i2count: HashMap<usize, u64> = HashMap::new();
self.special_name2i.iter().for_each(|(_, id)| {
trun_i2count.insert(*id, self.i2count[id]);
Expand All @@ -248,7 +272,6 @@ impl Vocab {
let mut sorted_i2count: Vec<(usize, u64)> = self.i2count.clone().into_iter().collect();
sorted_i2count.sort_by(|(_, c1), (_, c2)| (u64::MAX - *c1).cmp(&(u64::MAX - *c2)));
for (id, c) in sorted_i2count.iter() {

if special_ids.contains(id) {
continue;
}
Expand Down

0 comments on commit 018aeee

Please sign in to comment.