Skip to content

Commit

Permalink
De-lint + minor refactoring for PEP8 compliance (#1668)
Browse files: browse the repository at this point in the history
  • Loading branch information
lintool authored Oct 3, 2023
1 parent c91b9d1 commit 4f3da10
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 127 deletions.
41 changes: 23 additions & 18 deletions pyserini/2cr/beir.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,18 @@
'scifact'
]


def format_run_command(raw):
return raw.replace('--topics', '\\\n --topics')\
.replace('--index', '\\\n --index')\
.replace('--encoder-class', '\\\n --encoder-class')\
.replace('--output ', '\\\n --output ')\
return raw.replace('--topics', '\\\n --topics') \
.replace('--index', '\\\n --index') \
.replace('--encoder-class', '\\\n --encoder-class') \
.replace('--output ', '\\\n --output ') \
.replace('--output-format trec', '\\\n --output-format trec \\\n ') \
.replace('--hits ', '\\\n --hits ')


def format_eval_command(raw):
return raw.replace('-c ', '\\\n -c ')\
return raw.replace('-c ', '\\\n -c ') \
.replace('run.', '\\\n run.')


Expand All @@ -85,16 +86,19 @@ def read_file(f):

return text

def list_conditions(args):

def list_conditions():
with open(pkg_resources.resource_filename(__name__, 'beir.yaml')) as f:
yaml_data = yaml.safe_load(f)
for condition in yaml_data['conditions']:
print(condition['name'])

def list_datasets(args):


def list_datasets():
for dataset in beir_keys:
print(dataset)


def generate_report(args):
table = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.0)))
commands = defaultdict(lambda: defaultdict(lambda: ''))
Expand Down Expand Up @@ -149,8 +153,7 @@ def generate_report(args):
eval_cmd2=eval_commands[dataset]["bm25-multifield"].rstrip(),
eval_cmd3=eval_commands[dataset]["splade-distil-cocodenser-medium"].rstrip(),
eval_cmd4=eval_commands[dataset]["contriever"].rstrip(),
eval_cmd5=eval_commands[dataset]["contriever-msmarco"].rstrip(),
)
eval_cmd5=eval_commands[dataset]["contriever-msmarco"].rstrip())

html_rows.append(s)
row_cnt += 1
Expand All @@ -159,6 +162,7 @@ def generate_report(args):
with open(args.output, 'w') as out:
out.write(Template(html_template).substitute(title='BEIR', rows=all_rows))


def run_conditions(args):
start = time.time()

Expand Down Expand Up @@ -253,16 +257,17 @@ def run_conditions(args):
f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.4f}{table[dataset]["contriever-msmarco"]["R@100"]:8.4f}')
print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
print('avg' + ' ' * 22 + f'{final_scores["bm25-flat"]["nDCG@10"]:8.4f}{final_scores["bm25-flat"]["R@100"]:8.4f} ' +
f'{final_scores["bm25-multifield"]["nDCG@10"]:8.4f}{final_scores["bm25-multifield"]["R@100"]:8.4f} ' +
f'{final_scores["splade-distil-cocodenser-medium"]["nDCG@10"]:8.4f}{final_scores["splade-distil-cocodenser-medium"]["R@100"]:8.4f} ' +
f'{final_scores["contriever"]["nDCG@10"]:8.4f}{final_scores["contriever"]["R@100"]:8.4f} ' +
f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f}')
f'{final_scores["bm25-multifield"]["nDCG@10"]:8.4f}{final_scores["bm25-multifield"]["R@100"]:8.4f} ' +
f'{final_scores["splade-distil-cocodenser-medium"]["nDCG@10"]:8.4f}{final_scores["splade-distil-cocodenser-medium"]["R@100"]:8.4f} ' +
f'{final_scores["contriever"]["nDCG@10"]:8.4f}{final_scores["contriever"]["R@100"]:8.4f} ' +
f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f}')

end = time.time()

print('\n')
print(f'Total elapsed time: {end - start:.0f}s')

print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate regression matrix for BeIR corpora.')
# To list all conditions/datasets
Expand All @@ -282,11 +287,11 @@ def run_conditions(args):
args = parser.parse_args()

if args.list_conditions:
list_conditions(args)
list_conditions()
sys.exit()

if args.list_datasets:
list_datasets(args)
list_datasets()
sys.exit()

if args.generate_report:
Expand Down
11 changes: 5 additions & 6 deletions pyserini/2cr/miracl.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ def generate_table_rows(table, row_template, commands, eval_commands, table_id,
eval_cmd15=f'{eval_commands[keys["th"]][metric]}',
eval_cmd16=f'{eval_commands[keys["zh"]][metric]}',
eval_cmd17=f'{eval_commands[keys["de"]][metric]}',
eval_cmd18=f'{eval_commands[keys["yo"]][metric]}'
)
eval_cmd18=f'{eval_commands[keys["yo"]][metric]}')

s = s.replace("0.000", "--")
html_rows.append(s)
Expand Down Expand Up @@ -389,15 +388,15 @@ def run_conditions(args):
and split == 'dev' and metric == 'nDCG@10'
and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
(name == 'mdpr-tied-pft-msmarco-ft-all.ru'
# Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3932 -> expected 0.3933
# Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3932 -> expected 0.3933
and split == 'dev' and metric == 'nDCG@10'
and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
(name == 'bm25-mdpr-tied-pft-msmarco-hybrid.te'
# Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.6000 -> expected 0.5999
# Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.6000 -> expected 0.5999
and split == 'train' and metric == 'nDCG@10'
and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
(name == 'mcontriever-tied-pft-msmarco.id'
# Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3748 -> expected 0.3749
# Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3748 -> expected 0.3749
and split == 'train' and metric == 'nDCG@10'
and math.isclose(score, float(expected[metric]), abs_tol=2e-4)):
result_str = okish_str
Expand All @@ -415,7 +414,7 @@ def run_conditions(args):
print_results(table, metric, split)

end = time.time()
print(f'Total elapsed time: {end - start:.0f}s')
print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')


if __name__ == '__main__':
Expand Down
17 changes: 8 additions & 9 deletions pyserini/2cr/mrtydi.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@


def format_run_command(raw):
return raw.replace('--lang', '\\\n --lang')\
.replace('--encoder', '\\\n --encoder')\
.replace('--topics', '\\\n --topics')\
.replace('--index', '\\\n --index')\
.replace('--output ', '\\\n --output ')\
return raw.replace('--lang', '\\\n --lang') \
.replace('--encoder', '\\\n --encoder') \
.replace('--topics', '\\\n --topics') \
.replace('--index', '\\\n --index') \
.replace('--output ', '\\\n --output ') \
.replace('--batch ', '\\\n --batch ') \
.replace('--threads 12', '--threads 12 \\\n ')


def format_eval_command(raw):
return raw.replace('-c ', '\\\n -c ')\
return raw.replace('-c ', '\\\n -c ') \
.replace(raw.split()[-1], f'\\\n {raw.split()[-1]}')


Expand Down Expand Up @@ -164,8 +164,7 @@ def generate_table_rows(table, row_template, commands, eval_commands, table_id,
eval_cmd8=f'{eval_commands[keys["ru"]][metric]}',
eval_cmd9=f'{eval_commands[keys["sw"]][metric]}',
eval_cmd10=f'{eval_commands[keys["te"]][metric]}',
eval_cmd11=f'{eval_commands[keys["th"]][metric]}'
)
eval_cmd11=f'{eval_commands[keys["th"]][metric]}')

html_rows.append(s)
row_cnt += 1
Expand Down Expand Up @@ -292,7 +291,7 @@ def run_conditions(args):
print_results(table, metric, split)

end = time.time()
print(f'Total elapsed time: {end - start:.0f}s')
print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')


if __name__ == '__main__':
Expand Down
19 changes: 9 additions & 10 deletions pyserini/2cr/msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def format_command(raw):
# Format hybrid commands differently.
if 'pyserini.search.hybrid' in raw:
return raw.replace('dense', '\\\n dense ') \
.replace('--encoder', '\\\n --encoder')\
.replace('--encoder', '\\\n --encoder') \
.replace('sparse', '\\\n sparse') \
.replace('fusion', '\\\n fusion') \
.replace('run --', '\\\n run --') \
Expand All @@ -282,12 +282,12 @@ def format_command(raw):
# We want these on a separate line for better readability, but note that sometimes that might
# be the end of the command, in which case we don't want to add an extra line break.
return raw.replace('--topics', '\\\n --topics') \
.replace('--threads', '\\\n --threads')\
.replace('--index', '\\\n --index')\
.replace('--output ', '\\\n --output ')\
.replace('--encoder', '\\\n --encoder')\
.replace('--onnx-encoder', '\\\n --onnx-encoder')\
.replace('--encoded-corpus', '\\\n --encoded-corpus')\
.replace('--threads', '\\\n --threads') \
.replace('--index', '\\\n --index') \
.replace('--output ', '\\\n --output ') \
.replace('--encoder', '\\\n --encoder') \
.replace('--onnx-encoder', '\\\n --onnx-encoder') \
.replace('--encoded-corpus', '\\\n --encoded-corpus') \
.replace('.txt ', '.txt \\\n ')


Expand Down Expand Up @@ -339,7 +339,7 @@ def generate_report(args):
row_id = condition['display-row'] if 'display-row' in condition else ''
cmd_template = condition['command']

row_ids[name] =row_id
row_ids[name] = row_id
table_keys[name] = display

for topic_set in condition['topics']:
Expand Down Expand Up @@ -483,7 +483,6 @@ def run_conditions(args):
topic_key = topic_set['topic_key']
eval_key = topic_set['eval_key']

short_topic_key = ''
if args.collection == 'msmarco-v1-passage' or args.collection == 'msmarco-v1-doc':
short_topic_key = find_msmarco_table_topic_set_key_v1(topic_key)
else:
Expand Down Expand Up @@ -608,7 +607,7 @@ def run_conditions(args):
end = time.time()

print('\n')
print(f'Total elapsed time: {end - start:.0f}s')
print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 4f3da10

Please sign in to comment.