Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[非互換] コストの計算を利用時にするように変更。 #275

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
コストの計算を利用時にするように変更。
今後の #252 をやりやすくするための施策。

モデルのフォーマットが変わるので注意。
  • Loading branch information
tokuhirom committed Jan 31, 2023
commit f820160833e09b9a08eafe1b6f256cc3464effc8
7 changes: 2 additions & 5 deletions akaza-data/src/subcmd/learn_corpus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ impl LearningService {
}
let system_unigram_lm = Rc::new(OnMemorySystemUnigramLM::new(
Rc::new(RefCell::new(unigram_map)),
src_system_unigram_lm.get_default_cost(),
src_system_unigram_lm.get_default_cost_for_short(),
src_system_unigram_lm.total_words,
src_system_unigram_lm.unique_words,
));
Expand Down Expand Up @@ -190,9 +188,8 @@ impl LearningService {
}
// ↓本来なら現在のデータで再調整すべきだが、一旦元のものを使う。
// TODO あとで整理する
unigram_builder.set_default_cost(self.system_unigram_lm.get_default_cost());
unigram_builder
.set_default_cost_for_short(self.system_unigram_lm.get_default_cost_for_short());
unigram_builder.set_unique_words(self.system_unigram_lm.unique_words);
unigram_builder.set_total_words(self.system_unigram_lm.total_words);
info!("Save unigram to {}", dst_unigram);
unigram_builder.save(dst_unigram)?;
Ok(())
Expand Down
19 changes: 4 additions & 15 deletions akaza-data/src/wordcnt/wordcnt_unigram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ impl WordcntUnigramBuilder {

pub struct WordcntUnigram {
marisa: Marisa,
default_cost: f32,
default_cost_for_short: f32,
pub(crate) total_words: u32,
pub(crate) unique_words: u32,
}
Expand Down Expand Up @@ -88,26 +86,17 @@ impl WordcntUnigram {
// 単語の種類数
let unique_words = map.keys().count() as u32;

let default_cost = calc_cost(0, total_words, unique_words);
let default_cost_for_short = calc_cost(1, total_words, unique_words);

Ok(WordcntUnigram {
marisa,
default_cost,
default_cost_for_short,
total_words,
unique_words,
})
}
}

impl SystemUnigramLM for WordcntUnigram {
fn get_default_cost(&self) -> f32 {
self.default_cost
}

fn get_default_cost_for_short(&self) -> f32 {
self.default_cost_for_short
fn get_cost(&self, wordcnt: u32) -> f32 {
calc_cost(wordcnt, self.total_words, self.unique_words)
}

/// @return (word_id, score)。
Expand Down Expand Up @@ -181,8 +170,8 @@ mod tests {
);
assert_eq!(wordcnt.total_words, 45); // 単語発生数
assert_eq!(wordcnt.unique_words, 2); // ユニーク単語数
assert_eq!(wordcnt.get_default_cost(), 6.672098);
assert_eq!(wordcnt.get_default_cost_for_short(), 1.6720936);
assert_eq!(wordcnt.get_cost(0), 6.672098);
assert_eq!(wordcnt.get_cost(1), 1.6720936);

assert_eq!(wordcnt.find("私/わたし"), Some((1_i32, 1.1949753)));
assert_eq!(wordcnt.find("彼/かれ"), Some((0_i32, 0.048848562)));
Expand Down
12 changes: 6 additions & 6 deletions libakaza/src/graph/graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ mod tests {
Arc::new(Mutex::new(UserData::default())),
Rc::new(
MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build(),
),
Rc::new(
Expand Down Expand Up @@ -205,8 +205,8 @@ mod tests {
Arc::new(Mutex::new(UserData::default())),
Rc::new(
MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build(),
),
Rc::new(
Expand Down Expand Up @@ -238,8 +238,8 @@ mod tests {
Arc::new(Mutex::new(UserData::default())),
Rc::new(
MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build(),
),
Rc::new(
Expand Down
12 changes: 6 additions & 6 deletions libakaza/src/graph/graph_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ mod tests {
// -1 0 1 2
// BOS a b c
let system_unigram_lm = MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build();
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
.set_default_edge_cost(20_f32)
Expand Down Expand Up @@ -371,8 +371,8 @@ mod tests {

let mut system_unigram_lm_builder = MarisaSystemUnigramLMBuilder::default();
let system_unigram_lm = system_unigram_lm_builder
.set_default_cost(19_f32)
.set_default_cost_for_short(20_f32)
.set_unique_words(19)
.set_total_words(20)
.build();
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
.set_default_edge_cost(20_f32)
Expand Down Expand Up @@ -452,8 +452,8 @@ mod tests {

let mut system_unigram_lm_builder = MarisaSystemUnigramLMBuilder::default();
let system_unigram_lm = system_unigram_lm_builder
.set_default_cost(19_f32)
.set_default_cost_for_short(20_f32)
.set_unique_words(19)
.set_total_words(20)
.build();
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
.set_default_edge_cost(20_f32)
Expand Down
4 changes: 2 additions & 2 deletions libakaza/src/graph/lattice_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ impl<U: SystemUnigramLM, B: SystemBigramLM> LatticeGraph<U, B> {
// 労働者災害補償保険法 のように、システム辞書には wikipedia から採録されているが,
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
// つまり、変換後のほうが短くなるもののほうをコストを安くしておく。
self.system_unigram_lm.get_default_cost_for_short()
self.system_unigram_lm.get_cost(1)
} else {
self.system_unigram_lm.get_default_cost()
self.system_unigram_lm.get_cost(0)
};
}

Expand Down
3 changes: 1 addition & 2 deletions libakaza/src/lm/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ pub trait SystemBigramLM {
}

pub trait SystemUnigramLM {
fn get_default_cost(&self) -> f32;
fn get_default_cost_for_short(&self) -> f32;
fn get_cost(&self, wordcnt: u32) -> f32;

fn find(&self, word: &str) -> Option<(i32, f32)>;
fn as_hash_map(&self) -> HashMap<String, (i32, f32)>;
Expand Down
22 changes: 6 additions & 16 deletions libakaza/src/lm/on_memory/on_memory_system_unigram_lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,20 @@ use crate::lm::base::SystemUnigramLM;
pub struct OnMemorySystemUnigramLM {
// word -> (word_id, cost)
map: Rc<RefCell<HashMap<String, (i32, u32)>>>,
pub default_cost: f32,
pub default_cost_for_short: f32,
pub total_words: u32,
pub unique_words: u32,
}

impl OnMemorySystemUnigramLM {
pub fn new(
map: Rc<RefCell<HashMap<String, (i32, u32)>>>,
default_cost: f32,
default_cost_for_short: f32,
c: u32,
v: u32,
total_words: u32,
unique_words: u32,
) -> Self {
OnMemorySystemUnigramLM {
map,
default_cost,
default_cost_for_short,
total_words: c,
unique_words: v,
total_words,
unique_words,
}
}

Expand Down Expand Up @@ -57,12 +51,8 @@ impl OnMemorySystemUnigramLM {
}

impl SystemUnigramLM for OnMemorySystemUnigramLM {
fn get_default_cost(&self) -> f32 {
self.default_cost
}

fn get_default_cost_for_short(&self) -> f32 {
self.default_cost_for_short
fn get_cost(&self, wordcnt: u32) -> f32 {
calc_cost(wordcnt, self.total_words, self.unique_words)
}

fn find(&self, word: &str) -> Option<(i32, f32)> {
Expand Down
53 changes: 25 additions & 28 deletions libakaza/src/lm/system_unigram_lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use log::info;

use marisa_sys::{Keyset, Marisa};

use crate::cost::calc_cost;
use crate::lm::base::SystemUnigramLM;

/*
Expand All @@ -14,8 +15,8 @@ use crate::lm::base::SystemUnigramLM;
packed float # score: 4 bytes
*/

const DEFAULT_COST_FOR_SHORT_KEY: &str = "__DEFAULT_COST_FOR_SHORT__";
const DEFAULT_COST_KEY: &str = "__DEFAULT_COST__";
const UNIQUE_WORDS_KEY: &str = "__UNIQUE_WORDS__";
const TOTAL_WORDS_KEY: &str = "__TOTAL_WORDS__";

/**
* unigram 言語モデル。
Expand Down Expand Up @@ -48,13 +49,13 @@ impl MarisaSystemUnigramLMBuilder {
keyset
}

pub fn set_default_cost_for_short(&mut self, cost: f32) -> &mut Self {
self.add(DEFAULT_COST_FOR_SHORT_KEY, cost);
pub fn set_total_words(&mut self, total_words: u32) -> &mut Self {
self.add(TOTAL_WORDS_KEY, total_words as f32);
self
}

pub fn set_default_cost(&mut self, cost: f32) -> &mut Self {
self.add(DEFAULT_COST_KEY, cost);
pub fn set_unique_words(&mut self, unique_words: u32) -> &mut Self {
self.add(UNIQUE_WORDS_KEY, unique_words as f32);
self
}

Expand All @@ -68,22 +69,22 @@ impl MarisaSystemUnigramLMBuilder {
pub fn build(&self) -> MarisaSystemUnigramLM {
let mut marisa = Marisa::default();
marisa.build(&self.keyset());
let (_, default_cost_for_short) =
MarisaSystemUnigramLM::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY).unwrap();
let (_, default_cost) =
MarisaSystemUnigramLM::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY).unwrap();
let (_, total_words) =
MarisaSystemUnigramLM::find_from_trie(&marisa, TOTAL_WORDS_KEY).unwrap();
let (_, unique_words) =
MarisaSystemUnigramLM::find_from_trie(&marisa, UNIQUE_WORDS_KEY).unwrap();
MarisaSystemUnigramLM {
marisa,
default_cost_for_short,
default_cost,
total_words: total_words as u32,
unique_words: unique_words as u32,
}
}
}

pub struct MarisaSystemUnigramLM {
marisa: Marisa,
default_cost_for_short: f32,
default_cost: f32,
total_words: u32,
unique_words: u32,
}

impl MarisaSystemUnigramLM {
Expand All @@ -95,16 +96,16 @@ impl MarisaSystemUnigramLM {
info!("Reading {}", fname);
let mut marisa = Marisa::default();
marisa.load(fname)?;
let Some((_, default_cost_for_short)) = Self::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY) else {
bail!("Missing key for {}", DEFAULT_COST_FOR_SHORT_KEY);
let Some((_, total_words)) = Self::find_from_trie(&marisa, TOTAL_WORDS_KEY) else {
bail!("Missing key for {}", TOTAL_WORDS_KEY);
};
let Some((_, default_cost)) = Self::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY) else {
bail!("Missing key for {}", DEFAULT_COST_KEY);
let Some((_, unique_words)) = Self::find_from_trie(&marisa, UNIQUE_WORDS_KEY) else {
bail!("Missing key for {}", UNIQUE_WORDS_KEY);
};
Ok(MarisaSystemUnigramLM {
marisa,
default_cost_for_short,
default_cost,
total_words: total_words as u32,
unique_words: unique_words as u32,
})
}

Expand All @@ -131,12 +132,8 @@ impl MarisaSystemUnigramLM {
}

impl SystemUnigramLM for MarisaSystemUnigramLM {
fn get_default_cost(&self) -> f32 {
self.default_cost
}

fn get_default_cost_for_short(&self) -> f32 {
self.default_cost_for_short
fn get_cost(&self, wordcnt: u32) -> f32 {
calc_cost(wordcnt, self.total_words, self.unique_words)
}

/// @return (word_id, score)。
Expand Down Expand Up @@ -173,8 +170,8 @@ mod tests {
builder.add("hello", 0.4);
builder.add("world", 0.2);
builder
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_total_words(2)
.set_unique_words(2)
.save(&tmpfile)
.unwrap();

Expand Down