Skip to content

Commit

Permalink
remove unused parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AkariAsai committed Nov 18, 2023
1 parent 33e321e commit 6947748
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions retrieval_lm/run_short_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ def postprocess_answer_option_conditioned(answer):

def call_model_rerank_w_scores_batch(prompt, evidences, model, max_new_tokens=15,
ret_tokens=None, rel_tokens=None, grd_tokens=None, ut_tokens=None,
use_seqscore=False, threshold=0.5, beam_width=2, max_depth=1,
use_seqscore=False, threshold=0.5,
w_rel=1.0, w_sup=1.0, w_use=0.5, mode="adaptive_retrieval", closed=False):
# max_inpt_tokens = tokenizer.model_max_length if model_max_length is None else model_max_length
results = {}
if mode != "always_retrieve":
sampling_params = SamplingParams(
Expand Down Expand Up @@ -285,8 +284,6 @@ def main():
help="reward weight for generation support (attribution)")
parser.add_argument("--w_use", type=float, default=1.0,
help="reward weight for overall completeness / utility.")
parser.add_argument("--ignore_cont", action="store_true",
help="filter out sentences that include [No support / Contradictory] ")
parser.add_argument('--mode', type=str, help="mode to control retrieval.",
default="default", choices=['adaptive_retrieval', 'no_retrieval', 'always_retrieve'],)
parser.add_argument('--metric', type=str, help="metric to be used during evaluation")
Expand Down Expand Up @@ -316,7 +313,7 @@ def generate(prompt, evidences, max_new_tokens):
return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
rel_tokens=rel_tokens, ret_tokens=ret_tokens, grd_tokens=grd_tokens, ut_tokens=ut_tokens,
threshold=args.threshold, beam_width=args.beam_width, max_depth=args.max_depth, use_seqscore=args.use_seqscore,
w_rel=1.0, w_sup=1.0, w_use=0.5, mode=args.mode, closed=args.task in ["fever", "arc_c"])
w_rel=args.w_rel, w_sup=args.w_sup, w_use=args.w_use, mode=args.mode, closed=args.task in ["fever", "arc_c"])

preds = []
prompts = []
Expand All @@ -338,7 +335,6 @@ def generate(prompt, evidences, max_new_tokens):
all_results.append(results)
if do_retrieve is True:
count += 1
# golds.append(row["output"])
if "answers" not in row and "answer" in row:
row["answers"] = [row["answer"]] if type(
row["answer"]) is str else row["answer"]
Expand Down

0 comments on commit 6947748

Please sign in to comment.