Skip to content

Commit

Permalink
update code style
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouchengjie committed Dec 30, 2020
1 parent 98c0594 commit 9bcb1db
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/LAC/models.py
Original file line number Diff line number Diff line change
@@ -72,7 +72,8 @@ def run(self, texts):
"""文本输入经过模型转为运行结果Tensor"""
self.empty_string = []
self.empty_return = {
"crf_result": [[], [], []]
"empty": [[], [], []],
"crf_result": [[[], [], []]]
}
if isinstance(texts, list) or isinstance(texts, tuple):
self.batch = True
@@ -89,15 +90,15 @@ def run(self, texts):

if len(texts) == 0:
return {
"crf_result": [self.empty_return["crf_result"]] * len(self.empty_string)
"crf_result": self.empty_return["crf_result"] * len(self.empty_string)
}

tensor_words, words_length = self.texts2tensor(texts)
crf_decode = self.predictor.run([tensor_words])
crf_result = self.parse_result(texts, crf_decode[0], self.dataset, words_length)

for _ in self.empty_string:
crf_result.insert(_, self.empty_return["crf_result"])
crf_result.insert(_, self.empty_return["empty"])

return {
"crf_decode": crf_decode,
@@ -236,7 +237,7 @@ def __init__(self, model_path, mode, use_cuda):

def run(self, texts):
crf_result = super(LacModel, self).run(texts)['crf_result']
result = [[word, tag] for word, tag, tag_for_rank in crf_result] if self.batch else crf_result[:-1]
result = [[word, tag] for word, tag, tag_for_rank in crf_result] if self.batch else crf_result[0][:-1]
return result

def call_run(self, texts):
@@ -252,7 +253,7 @@ def __init__(self, model_path, mode, use_cuda):

def run(self, texts):
crf_result = super(SegModel, self).run(texts)["crf_result"]
result = [word for word, tag, tag_for_rank in crf_result] if self.batch else crf_result[0]
result = [word for word, tag, tag_for_rank in crf_result] if self.batch else crf_result[0][0]
return result

def texts2tensor(self, texts):
@@ -319,6 +320,7 @@ def run(self, texts):
self.lac.custom = self.custom

lac_result = self.lac.call_run(texts)
self.batch = self.lac.batch

if len(lac_result) == 1:
return lac_result["crf_result"]
@@ -340,7 +342,7 @@ def run(self, texts):
else:
result[_].append(weight.pop(0))

return result
return result if self.batch else result[0]

def parse_result(self, tags_for_rank, result, words_length):
"""将RANK模型输出的Tensor转为明文"""

0 comments on commit 9bcb1db

Please sign in to comment.