"""Unit tests for the MLLM backend in Annif"""
import pytest
import annif
import annif.backend
from annif.exception import NotInitializedException, NotSupportedException


def test_mllm_default_params(project):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(backend_id="mllm", config_params={}, project=project)
    expected_default_params = {
        "limit": 100,  # from AnnifBackend class
        "min_samples_leaf": 20,
        "max_leaf_nodes": 1000,
        "max_samples": 0.9,
    }
    actual_params = mllm.params
    for param, val in expected_default_params.items():
        assert param in actual_params and actual_params[param] == val


def test_mllm_train(datadir, fulltext_corpus, project):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm",
        config_params={"limit": 10, "language": "fi"},
        project=project,
    )
    mllm.train(fulltext_corpus)
    assert mllm._model is not None
    assert datadir.join("mllm-train.gz").exists()
    assert datadir.join("mllm-train.gz").size() > 0
    assert datadir.join("mllm-model.gz").exists()
    assert datadir.join("mllm-model.gz").size() > 0


def test_mllm_train_cached(datadir, project):
    modelfile = datadir.join("mllm-model.gz")
    assert modelfile.exists()
    old_size = modelfile.size()
    old_mtime = modelfile.mtime()
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm",
        config_params={"limit": 10, "language": "fi"},
        project=project,
    )
    mllm.train("cached")
    assert mllm._model is not None
    assert modelfile.exists()
    assert modelfile.size() > 0
    # retraining from the cached training data should rewrite the model file
    assert modelfile.size() != old_size or modelfile.mtime() != old_mtime


def test_mllm_train_nodocuments(project, empty_corpus):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm",
        config_params={"limit": 10, "language": "fi"},
        project=project,
    )
    with pytest.raises(NotSupportedException) as excinfo:
        mllm.train(empty_corpus)
    assert "training backend mllm with no documents" in str(excinfo.value)


def test_mllm_suggest(project):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm", config_params={"limit": 8, "language": "fi"}, project=project
    )
    results = mllm.suggest(
        [
            """Arkeologia on tieteenala, jota sanotaan joskus
            muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
            tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
            Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
            joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
            pohjaan."""
        ]
    )[0]
    assert len(results) > 0
    assert len(results) <= 8
    archaeology = project.subjects.by_uri("http://www.yso.fi/onto/yso/p1265")
    assert archaeology in [result.subject_id for result in results]


def test_mllm_suggest_no_matches(project):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm", config_params={"limit": 8, "language": "fi"}, project=project
    )
    results = mllm.suggest(["Nothing matches this."])[0]
    assert len(results) == 0


def test_mllm_hyperopt(project, fulltext_corpus):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm",
        config_params={"limit": 10, "language": "fi"},
        project=project,
    )
    optimizer = mllm.get_hp_optimizer(fulltext_corpus, metric="NDCG")
    optimizer.optimize(n_trials=3, n_jobs=1, results_file=None)


def test_mllm_train_cached_no_data(datadir, project):
    modelfile = datadir.join("mllm-model.gz")
    assert modelfile.exists()
    # remove the cached training data so that cached training cannot succeed
    trainfile = datadir.join("mllm-train.gz")
    trainfile.remove()
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm",
        config_params={"limit": 10, "language": "fi"},
        project=project,
    )
    with pytest.raises(NotInitializedException):
        mllm.train("cached")


def test_mllm_suggest_no_model(datadir, project):
    mllm_type = annif.backend.get_backend("mllm")
    mllm = mllm_type(
        backend_id="mllm", config_params={"limit": 8, "language": "fi"}, project=project
    )
    datadir.join("mllm-model.gz").remove()
    with pytest.raises(NotInitializedException):
        mllm.suggest(["example text"])