TREC DL and LLM AggreFact experiments for relevance benchmark + prompts comparisons and groundedness vs Bespoke Minicheck 7B #1660

Merged · 23 commits · Dec 10, 2024

Changes from 1 commit
Commit message: save
sfc-gh-dhuang committed Nov 25, 2024
commit 65d92cea4687a27d7bffcaa13003f25f1c65cfa8
@@ -196,45 +196,6 @@
"session.reset_database()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from trulens.core import Provider\n",
"\n",
"\n",
"THRESHOLD = 0.33\n",
"class CustomTermFeedback(Provider):\n",
" def true_positive(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 1 and binary_gt_score == 1 else 0.0\n",
" \n",
" def true_negative(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 0 and binary_gt_score == 0 else 0.0\n",
"\n",
" def false_positive(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 1 and binary_gt_score == 0 else 0.0\n",
" def false_negative(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 0 and binary_gt_score == 1 else 0.0\n",
" \n",
" def term_absolute_error(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" return abs(feedback_score - gt_score) "
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -790,15 +751,17 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (\n",
" generate_ms_marco_trec_dl_annotation_benchmark,\n",
")\n",
"\n",
"trec_doc_2022 = list(generate_ms_marco_trec_dl_annotation_benchmark(dataset_path=\"msmarco-document-v2/trec-dl-2022\", max_samples_per_bucket=150))"
"# trec_doc_2022 = list(generate_ms_marco_trec_dl_annotation_benchmark(dataset_path=\"msmarco-document-v2/trec-dl-2022\", max_samples_per_bucket=150))\n",
"\n",
"trec_passage_2022 = list(generate_ms_marco_trec_dl_annotation_benchmark(dataset_path=\"msmarco-passage-v2/trec-dl-2022\", max_samples_per_bucket=150))"
]
},
{
@@ -818,7 +781,7 @@
}
],
"source": [
"len(trec_doc_2022)"
"len(trec_passage_2022)"
]
},
{
@@ -827,7 +790,7 @@
"metadata": {},
"outputs": [],
"source": [
"trec_doc_2022_true_labels = [entry[\"expected_score\"] for entry in trec_doc_2022]"
"trec_passage_2022_true_labels = [entry[\"expected_score\"] for entry in trec_passage_2022]"
]
},
{
@@ -851,7 +814,7 @@
" visualize_expected_score_distribution,\n",
")\n",
"\n",
"visualize_expected_score_distribution(trec_doc_2022_true_labels)"
"visualize_expected_score_distribution(trec_passage_2022_true_labels)"
]
},
{
@@ -937,6 +900,43 @@
"from trulens.core import Feedback\n",
"\n",
"\n",
"from trulens.core import Provider\n",
"\n",
"\n",
"THRESHOLD = 0.5 # for passage retrieval annotation, we consider a score of 0.5 or above as relevant\n",
"class CustomTermFeedback(Provider):\n",
" def true_positive(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 1 and binary_gt_score == 1 else 0.0\n",
" \n",
" def true_negative(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 0 and binary_gt_score == 0 else 0.0\n",
"\n",
" def false_positive(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 1 and binary_gt_score == 0 else 0.0\n",
" def false_negative(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 0 and binary_gt_score == 1 else 0.0\n",
" \n",
" def term_absolute_error(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" return abs(feedback_score - gt_score) \n",
" \n",
" def raw_gt_score(self, output: str) -> float:\n",
" return float(output.split(\";\")[1])\n",
" \n",
" def raw_feedback_score(self, output: str) -> float:\n",
" return float(output.split(\";\")[0])\n",
"\n",
"custom_term_feedback = CustomTermFeedback()\n",
"\n",
@@ -945,14 +945,15 @@
"f_fp = Feedback(custom_term_feedback.false_positive, name=\"False Positive\", higher_is_better=False).on_output()\n",
"f_fn = Feedback(custom_term_feedback.false_negative, name=\"False Negative\", higher_is_better=False).on_output()\n",
"f_abs_err = Feedback(custom_term_feedback.term_absolute_error, name=\"Absolute Error\", higher_is_better=False).on_output()\n",
"f_raw_gt_score = Feedback(custom_term_feedback.raw_gt_score, name=\"Raw GT Score\", higher_is_better=True).on_output()\n",
"f_raw_feedback_score = Feedback(custom_term_feedback.raw_feedback_score, name=\"Raw Feedback Score\", higher_is_better=True).on_output()\n",
"\n",
"\n",
"CUSTOM_FEEDBACK_FUNCS = [f_tp, f_tn, f_fp, f_fn, f_abs_err]\n",
"CUSTOM_FEEDBACK_FUNCS = [f_tp, f_tn, f_fp, f_fn, f_abs_err, f_raw_gt_score, f_raw_feedback_score]\n",
"\n",
"def run_experiment_for_provider(provider, dataset_df):\n",
" tru_wrapped_context_relevance_app = TruBasicApp(\n",
" trulens_context_relevance,\n",
" app_name=\"trec-dl-doc-2022-11102024\",\n",
" app_name=\"trec-dl-passage-2022\",\n",
" app_version=f\"{provider.model_engine}-context-relevance\",\n",
" feedbacks=CUSTOM_FEEDBACK_FUNCS,\n",
" )\n",
@@ -973,7 +974,7 @@
"\n",
"for provider in PROVIDERS:\n",
" print(f\"Running provider: {provider.model_engine}\")\n",
" run_experiment_for_provider(provider, trec_doc_2022)\n",
" run_experiment_for_provider(provider, trec_passage_2022)\n",
"\n",
"# with concurrent.futures.ThreadPoolExecutor() as executor:\n",
"# futures = [executor.submit(run_experiment_for_provider, provider, trec_doc_2022) for provider in PROVIDERS]\n",
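Side note on the re-added `CustomTermFeedback` cell: all of its methods repeat the same `output.split(";")` parsing of the `"feedback_score;ground_truth_score"` string. The sketch below is illustrative only and not part of this PR; the class name and the `_parse` helper are hypothetical, and it assumes the wrapped app keeps emitting scores in that semicolon-separated format.

```python
# Illustrative sketch only -- class name and _parse helper are not part of the PR.
from trulens.core import Provider

THRESHOLD = 0.5  # mirrors the notebook's passage-relevance cutoff


class TermFeedbackSketch(Provider):
    """Parses "feedback_score;ground_truth_score" strings, as the notebook cell does."""

    def _parse(self, output: str) -> tuple[float, float]:
        # Split once and convert both parts, instead of repeating this in every method.
        feedback_score, gt_score = (float(part) for part in output.split(";"))
        return feedback_score, gt_score

    def true_positive(self, output: str) -> float:
        feedback_score, gt_score = self._parse(output)
        return 1.0 if feedback_score >= 0.5 and gt_score >= THRESHOLD else 0.0


if __name__ == "__main__":
    fb = TermFeedbackSketch()
    print(fb.true_positive("0.8;1.0"))  # 1.0 -- both scores clear their thresholds
    print(fb.true_positive("0.8;0.2"))  # 0.0 -- ground truth below THRESHOLD
```

Factoring the split into one helper would make it easier to change the delimiter or add validation later without touching every metric method.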
@@ -365,7 +365,7 @@ def generate_balanced_ms_marco_hard_negatives_dataset(


def generate_ms_marco_trec_dl_annotation_benchmark(
dataset_path: str = "msmarco-document-v2/trec-dl-2022",
dataset_path: str = "msmarco-passage-v2/trec-dl-2022",
max_samples_per_bucket: int = 100,
):
dataset = ir_datasets.load(dataset_path)
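For reference, a minimal usage sketch of the updated default path (hypothetical driver code, not part of the diff). It assumes the TREC DL 2022 passage qrels are available to `ir_datasets`; the sample cap and the `expected_score` field mirror the notebook cells above.

```python
# Usage sketch -- mirrors the notebook; argument and field names follow the diff.
from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (
    generate_ms_marco_trec_dl_annotation_benchmark,
)

# With this commit the passage collection ("msmarco-passage-v2/trec-dl-2022")
# is the default dataset_path, so no path argument is needed for the passage run.
trec_passage_2022 = list(
    generate_ms_marco_trec_dl_annotation_benchmark(max_samples_per_bucket=150)
)

print(len(trec_passage_2022))
print(trec_passage_2022[0]["expected_score"])  # ground-truth relevance label used above
```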