forked from apachecn/ailearning
Commit 41cc04a (1 parent: 30fd5c7), showing 7 changed files with 329 additions and 0 deletions.
@@ -0,0 +1,126 @@
import pickle
import numpy as np
import platform
from collections import Counter

from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
from keras.preprocessing.sequence import pad_sequences

EMBED_DIM = 200
BiRNN_UNITS = 200


def load_data():
    train = _parse_data(open('zh-NER/data/train_data.data', 'rb'))
    test = _parse_data(open('zh-NER/data/test_data.data', 'rb'))

    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    vocab = [w for w, f in word_counts.items() if f >= 2]
    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']

    # save initial config data
    with open('zh-NER/model/config.pkl', 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)


def _parse_data(fh):
    # On Windows the boundary between samples is '\r\n\r\n' and a line break
    # is '\r\n'; on other systems they are '\n\n' and '\n', so choose the
    # corresponding split text for the current platform.

    if platform.system() == 'Windows':
        split_text = '\r\n'
    else:
        split_text = '\n'

    string = fh.read().decode('utf-8')
    data = [[row.split() for row in sample.split(split_text)]
            for sample in string.strip().split(split_text + split_text)]
    fh.close()
    return data


def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]  # set to <unk> (index 1) if not in vocab

    y_chunk = [[chunk_tags.index(w[1]) for w in s] for s in data]

    x = pad_sequences(x, maxlen)  # left padding

    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)  # pad labels with -1 so padding does not collide with tag index 0 ('O')

    if onehot:
        y_chunk = np.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        y_chunk = np.expand_dims(y_chunk, 2)
    return x, y_chunk


def process_data(data, vocab, maxlen=100):
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [word2idx.get(w[0].lower(), 1) for w in data]
    length = len(x)
    x = pad_sequences([x], maxlen)  # left padding
    return x, length


def create_model(train=True):
    if train:
        (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
    else:
        with open('zh-NER/model/config.pkl', 'rb') as inp:
            (vocab, chunk_tags) = pickle.load(inp)
    model = Sequential()
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # random embedding; index 0 is the padding mask
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))  # half the units per direction
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    if train:
        return model, (train_x, train_y), (test_x, test_y)
    else:
        return model, (vocab, chunk_tags)


def train():
    EPOCHS = 10
    model, (train_x, train_y), (test_x, test_y) = create_model()
    # train model
    model.fit(train_x, train_y, batch_size=16, epochs=EPOCHS, validation_data=(test_x, test_y))
    model.save('zh-NER/model/crf.h5')

def test():
    model, (vocab, chunk_tags) = create_model(train=False)
    predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
    x, length = process_data(predict_text, vocab)
    model.load_weights('zh-NER/model/crf.h5')
    raw = model.predict(x)[0][-length:]
    result = [np.argmax(row) for row in raw]
    result_tags = [chunk_tags[i] for i in result]

    per, loc, org = '', '', ''

    for s, t in zip(predict_text, result_tags):
        if t in ('B-PER', 'I-PER'):
            per += ' ' + s if (t == 'B-PER') else s
        if t in ('B-ORG', 'I-ORG'):
            org += ' ' + s if (t == 'B-ORG') else s
        if t in ('B-LOC', 'I-LOC'):
            loc += ' ' + s if (t == 'B-LOC') else s

    print(['person:' + per, 'location:' + loc, 'organization:' + org])


if __name__ == "__main__":
    train()
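For reference, a minimal sketch of the corpus layout that `_parse_data` expects: one character and its BIO tag per line, with a blank line between sentences. The sample rows below are made up for illustration, not taken from the actual dataset:

```python
# Hypothetical sample in the same layout as train_data.data / test_data.data.
raw = (
    "中 B-ORG\n"
    "国 I-ORG\n"
    "很 O\n"
    "大 O\n"
    "\n"
    "句 O\n"
    "子 O\n"
)

# Same parsing as _parse_data on a non-Windows system (split_text = '\n').
data = [[row.split() for row in sample.split('\n')]
        for sample in raw.strip().split('\n\n')]
print(data)
# [[['中', 'B-ORG'], ['国', 'I-ORG'], ['很', 'O'], ['大', 'O']],
#  [['句', 'O'], ['子', 'O']]]
```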
@@ -0,0 +1,45 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/
.idea
@@ -0,0 +1,35 @@ | ||
# zh-NER-keras | ||
> this project is a sample for Chinese Named Entity Recognition(NER) | ||
by Keras 2.1.4 | ||
|
||
## requirements | ||
* keras=>2.1.4 | ||
* keras contribute 2.0.8 (https://github.com/keras-team/keras-contrib) | ||
* h5py | ||
* pickle | ||
|
||
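keras-contrib is usually installed straight from its GitHub repository, for example with `pip install git+https://github.com/keras-team/keras-contrib.git` (an assumed command, not stated in this repo; check the linked project for current instructions).
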
## demo

```bash
python val.py
```

input:
```text
中华人民共和国国务院总理周恩来在外交部长陈毅,
副部长王东的陪同下,
连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚
```
output:
```python
['person: 周恩来 陈毅, 王东', 'location: 埃塞俄比亚 非洲 阿尔巴尼亚', 'organization: 中华人民共和国国务院 外交部']
```
src/py3.x/tensorflow2.x/zh-NER-keras-master/bilsm_crf_model.py (27 additions & 0 deletions)
@@ -0,0 +1,27 @@
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
import process_data
import pickle

EMBED_DIM = 200
BiRNN_UNITS = 200


def create_model(train=True):
    if train:
        (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = process_data.load_data()
    else:
        with open('model/config.pkl', 'rb') as inp:
            (vocab, chunk_tags) = pickle.load(inp)
    model = Sequential()
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # random embedding; index 0 is the padding mask
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))  # half the units per direction
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    if train:
        return model, (train_x, train_y), (test_x, test_y)
    else:
        return model, (vocab, chunk_tags)
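With its default merge mode, `Bidirectional` concatenates a forward and a backward LSTM of `BiRNN_UNITS // 2` units each, so every timestep reaches the CRF with `BiRNN_UNITS` features; the CRF layer then scores whole tag sequences instead of classifying each position independently.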
src/py3.x/tensorflow2.x/zh-NER-keras-master/process_data.py (66 additions & 0 deletions)
@@ -0,0 +1,66 @@
import numpy
from collections import Counter
from keras.preprocessing.sequence import pad_sequences
import pickle
import platform


def load_data():
    train = _parse_data(open('data/train_data.data', 'rb'))
    test = _parse_data(open('data/test_data.data', 'rb'))

    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    vocab = [w for w, f in word_counts.items() if f >= 2]  # keep characters that occur at least twice
    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']

    # save initial config data
    with open('model/config.pkl', 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)


def _parse_data(fh):
    # On Windows the boundary between samples is '\r\n\r\n' and a line break
    # is '\r\n'; on other systems they are '\n\n' and '\n', so choose the
    # corresponding split text for the current platform.

    if platform.system() == 'Windows':
        split_text = '\r\n'
    else:
        split_text = '\n'

    string = fh.read().decode('utf-8')
    data = [[row.split() for row in sample.split(split_text)]
            for sample in string.strip().split(split_text + split_text)]
    fh.close()
    return data


def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]  # set to <unk> (index 1) if not in vocab

    y_chunk = [[chunk_tags.index(w[1]) for w in s] for s in data]

    x = pad_sequences(x, maxlen)  # left padding

    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)  # pad labels with -1 so padding does not collide with tag index 0 ('O')

    if onehot:
        y_chunk = numpy.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        y_chunk = numpy.expand_dims(y_chunk, 2)
    return x, y_chunk


def process_data(data, vocab, maxlen=100):
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [word2idx.get(w[0].lower(), 1) for w in data]
    length = len(x)
    x = pad_sequences([x], maxlen)  # left padding
    return x, length
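For intuition, a small sketch (with made-up index values) of the padding behavior relied on above: `pad_sequences` prepends the pad value by default, which is why the comments describe it as left padding:

```python
from keras.preprocessing.sequence import pad_sequences

x = [[5, 9, 2], [7]]
print(pad_sequences(x, maxlen=4))
# [[0 5 9 2]
#  [0 0 0 7]]   zeros are prepended, matching the mask_zero=True embedding

print(pad_sequences([[1, 2]], maxlen=4, value=-1))
# [[-1 -1  1  2]]   labels padded with -1, away from tag index 0 ('O')
```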
@@ -0,0 +1,7 @@
import bilsm_crf_model

EPOCHS = 10
model, (train_x, train_y), (test_x, test_y) = bilsm_crf_model.create_model()
# train model
model.fit(train_x, train_y, batch_size=16, epochs=EPOCHS, validation_data=(test_x, test_y))
model.save('model/crf.h5')
@@ -0,0 +1,23 @@
import bilsm_crf_model
import process_data
import numpy as np

model, (vocab, chunk_tags) = bilsm_crf_model.create_model(train=False)
predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
x, length = process_data.process_data(predict_text, vocab)
model.load_weights('model/crf.h5')
raw = model.predict(x)[0][-length:]
result = [np.argmax(row) for row in raw]
result_tags = [chunk_tags[i] for i in result]

per, loc, org = '', '', ''

for s, t in zip(predict_text, result_tags):
    if t in ('B-PER', 'I-PER'):
        per += ' ' + s if (t == 'B-PER') else s
    if t in ('B-ORG', 'I-ORG'):
        org += ' ' + s if (t == 'B-ORG') else s
    if t in ('B-LOC', 'I-LOC'):
        loc += ' ' + s if (t == 'B-LOC') else s

print(['person:' + per, 'location:' + loc, 'organization:' + org])
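Note that the script rebuilds the network with `create_model(train=False)` and restores only the weights via `load_weights`, instead of calling `keras.models.load_model('model/crf.h5')`. That is likely deliberate: the keras-contrib `CRF` layer, its loss function, and its accuracy metric are custom objects, and a full-model load would require registering them through `custom_objects`.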