-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy path test_DNNBase.py
42 lines (35 loc) · 1.23 KB
/
test_DNNBase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from unittest import TestCase
import numpy as np
from dnn_base import DNNBase
# -*- coding: UTF-8 -*-
class TestDNNBase(TestCase):
def setUp(self):
self.dnn_base = DNNBase()
def test_viterbi(self):
score = np.arange(10, 170, 10).reshape(4, 4).T
A = np.array([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0]])
init_A = np.array([1, 1, 0, 0])
labels = np.array([3, 3, 3, 3])
current_path = self.dnn_base.viterbi(score, A, init_A)
print(current_path)
def test_viterbi_new(self):
score = np.arange(10, 170, 10).reshape(4, 4).T
A = np.array([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0]])
init_A = np.array([1, 1, 0, 0])
labels = np.array([3,3,3,3])
current_path = self.dnn_base.viterbi_new(score, A, init_A,labels)
#print(current_path)
#correct_path = np.array([1, 3, 1, 3])
#correct_score = np.array([21, 102, 203, 364])
#self.assertTrue(np.all(current_path == correct_path))
#self.assertTrue(np.all(current_score == correct_score))
def test_generate_transition_update(self):
pass
def test_generate_transition_update_index(self):
pass
def test_sentence2index(self):
pass
def test_index2seq(self):
pass
def test_tags2words(self):
pass