Skip to content

Commit

Permalink
fix answer extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed Jul 10, 2024
1 parent 4ebeac5 commit a4d8a76
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion utilization/dataset/multiple_choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _post_processing_generation(self, predictions: List[str]) -> List[int]:
option_regex = [re.compile(matches.format(op=chr(ord("A") + i))) for i in range(max_option_num)]

for text, option_num in zip(predictions, self.option_nums):
label_found = option_regex[option_num - 1].findall(text.strip().split("\n")[0])
label_found = option_regex[option_num - 1].findall(text)
if not label_found:
labels.append(-1)
else:
Expand Down
13 changes: 7 additions & 6 deletions utilization/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pathlib
import sys
from collections import defaultdict
from dataclasses import fields
from functools import lru_cache
from typing import TYPE_CHECKING, List, Optional
Expand Down Expand Up @@ -37,17 +38,17 @@
"translation_dataset", "warn_once"
}

LOGGED = set()
LOGGED = defaultdict(int)


def log_once(call_log: callable, msg: str, identifier: str, stacklevel=2):
if identifier not in LOGGED:
def log_once(call_log: callable, msg: str, identifier: str, log_times=1, stacklevel=2):
if LOGGED.get(identifier, 0) < log_times:
call_log(msg, stacklevel=stacklevel)
LOGGED.add(identifier)
LOGGED[identifier] += 1


def warn_once(logger: logging.Logger, msg: str, identifier: str):
log_once(logger.warning, msg, identifier, stacklevel=3)
def warn_once(logger: logging.Logger, msg: str, identifier: str, log_times=1):
log_once(logger.warning, msg, identifier, log_times=log_times, stacklevel=3)


@lru_cache
Expand Down

0 comments on commit a4d8a76

Please sign in to comment.