diff --git a/retrieval_lm/run_short_form.py b/retrieval_lm/run_short_form.py
index 34ba0ca..49f9777 100644
--- a/retrieval_lm/run_short_form.py
+++ b/retrieval_lm/run_short_form.py
@@ -50,9 +50,8 @@ def postprocess_answer_option_conditioned(answer):
 def call_model_rerank_w_scores_batch(prompt, evidences, model, max_new_tokens=15,
                                      ret_tokens=None, rel_tokens=None,
                                      grd_tokens=None, ut_tokens=None,
-                                     use_seqscore=False, threshold=0.5, beam_width=2, max_depth=1,
+                                     use_seqscore=False, threshold=0.5, w_rel=1.0, w_sup=1.0, w_use=0.5,
                                      mode="adaptive_retrieval", closed=False):
-    # max_inpt_tokens = tokenizer.model_max_length if model_max_length is None else model_max_length
     results = {}
     if mode != "always_retrieve":
         sampling_params = SamplingParams(
@@ -285,8 +284,6 @@ def main():
                         help="reward weight for generation support (attribution)")
     parser.add_argument("--w_use", type=float, default=1.0,
                         help="reward weight for overall completeness / utility.")
-    parser.add_argument("--ignore_cont", action="store_true",
-                        help="filter out sentences that include [No support / Contradictory] ")
     parser.add_argument('--mode', type=str, help="mode to control retrieval.", default="default",
                         choices=['adaptive_retrieval', 'no_retrieval', 'always_retrieve'],)
     parser.add_argument('--metric', type=str, help="metric to be used during evaluation")
@@ -316,7 +313,7 @@ def generate(prompt, evidences, max_new_tokens):
         return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
                                                 rel_tokens=rel_tokens, ret_tokens=ret_tokens, grd_tokens=grd_tokens, ut_tokens=ut_tokens,
                                                 threshold=args.threshold, beam_width=args.beam_width, max_depth=args.max_depth, use_seqscore=args.use_seqscore,
-                                                w_rel=1.0, w_sup=1.0, w_use=0.5, mode=args.mode, closed=args.task in ["fever", "arc_c"])
+                                                w_rel=args.w_rel, w_sup=args.w_sup, w_use=args.w_use, mode=args.mode, closed=args.task in ["fever", "arc_c"])

     preds = []
     prompts = []
@@ -338,7 +335,6 @@ def generate(prompt, evidences, max_new_tokens):
         all_results.append(results)
         if do_retrieve is True:
             count += 1
-        # golds.append(row["output"])
         if "answers" not in row and "answer" in row:
             row["answers"] = [row["answer"]] if type(
                 row["answer"]) is str else row["answer"]
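
For context on what the newly wired-through weights control, below is a minimal, self-contained sketch of a weighted segment score in the style of Self-RAG-like reranking. It is not code from run_short_form.py: the function score_segment and the variable names (seq_logprob, rel_prob, sup_prob, use_prob) are hypothetical placeholders, and the combination rule is an assumption used only to illustrate the roles of w_rel, w_sup, and w_use.

    # Hypothetical sketch: combine per-segment critique signals into one
    # rerank score. None of these names come from run_short_form.py.
    def score_segment(seq_logprob, rel_prob, sup_prob, use_prob,
                      w_rel=1.0, w_sup=1.0, w_use=0.5, use_seqscore=False):
        """seq_logprob: length-normalized log-probability of the segment.
        rel_prob:    probability mass on the relevance critique token.
        sup_prob:    probability mass on the support/attribution critique token.
        use_prob:    normalized utility estimate for the segment.
        """
        score = w_rel * rel_prob + w_sup * sup_prob + w_use * use_prob
        if use_seqscore:
            score += seq_logprob
        return score

    # Example: with utility down-weighted (w_use=0.5), a well-supported
    # segment outranks one that is merely "useful" but unsupported.
    supported = score_segment(-0.3, rel_prob=0.9, sup_prob=0.8, use_prob=0.4)
    unsupported = score_segment(-0.3, rel_prob=0.9, sup_prob=0.1, use_prob=0.9)
    assert supported > unsupported

With the patch above, weights of this kind are no longer hardcoded at the call site (w_rel=1.0, w_sup=1.0, w_use=0.5) but come from the --w_rel, --w_sup, and --w_use command-line arguments.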