save
sfc-gh-dhuang committed Nov 20, 2024
1 parent 95ceef0 commit 222f201
Showing 2 changed files with 50 additions and 49 deletions.
@@ -196,45 +196,6 @@
"session.reset_database()"
]
},
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [],
- "source": [
- "from trulens.core import Provider\n",
- "\n",
- "\n",
- "THRESHOLD = 0.33\n",
- "class CustomTermFeedback(Provider):\n",
- " def true_positive(self, output: str) -> float:\n",
- " feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
- " binary_score = 1 if feedback_score >= 0.5 else 0\n",
- " binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
- " return 1.0 if binary_score == 1 and binary_gt_score == 1 else 0.0\n",
- " \n",
- " def true_negative(self, output: str) -> float:\n",
- " feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
- " binary_score = 1 if feedback_score >= 0.5 else 0\n",
- " binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
- " return 1.0 if binary_score == 0 and binary_gt_score == 0 else 0.0\n",
- "\n",
- " def false_positive(self, output: str) -> float:\n",
- " feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
- " binary_score = 1 if feedback_score >= 0.5 else 0\n",
- " binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
- " return 1.0 if binary_score == 1 and binary_gt_score == 0 else 0.0\n",
- " def false_negative(self, output: str) -> float:\n",
- " feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
- " binary_score = 1 if feedback_score >= 0.5 else 0\n",
- " binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
- " return 1.0 if binary_score == 0 and binary_gt_score == 1 else 0.0\n",
- " \n",
- " def term_absolute_error(self, output: str) -> float:\n",
- " feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
- " return abs(feedback_score - gt_score) "
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -790,15 +751,17 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"from trulens.benchmark.benchmark_frameworks.experiments.dataset_preprocessing import (\n",
" generate_ms_marco_trec_dl_annotation_benchmark,\n",
")\n",
"\n",
"trec_doc_2022 = list(generate_ms_marco_trec_dl_annotation_benchmark(dataset_path=\"msmarco-document-v2/trec-dl-2022\", max_samples_per_bucket=150))"
"# trec_doc_2022 = list(generate_ms_marco_trec_dl_annotation_benchmark(dataset_path=\"msmarco-document-v2/trec-dl-2022\", max_samples_per_bucket=150))\n",
"\n",
"trec_passage_2022 = list(generate_ms_marco_trec_dl_annotation_benchmark(dataset_path=\"msmarco-passage-v2/trec-dl-2022\", max_samples_per_bucket=150))"
]
},
{
@@ -818,7 +781,7 @@
}
],
"source": [
"len(trec_doc_2022)"
"len(trec_passage_2022)"
]
},
{
@@ -827,7 +790,7 @@
"metadata": {},
"outputs": [],
"source": [
"trec_doc_2022_true_labels = [entry[\"expected_score\"] for entry in trec_doc_2022]"
"trec_passage_2022_true_labels = [entry[\"expected_score\"] for entry in trec_passage_2022]"
]
},
{
@@ -851,7 +814,7 @@
" visualize_expected_score_distribution,\n",
")\n",
"\n",
"visualize_expected_score_distribution(trec_doc_2022_true_labels)"
"visualize_expected_score_distribution(trec_passage_2022_true_labels)"
]
},
{
@@ -937,6 +900,43 @@
"from trulens.core import Feedback\n",
"\n",
"\n",
"from trulens.core import Provider\n",
"\n",
"\n",
"THRESHOLD = 0.5 # for passage retrieval annotation, we consider a score of 0.5 or above as relevant\n",
"class CustomTermFeedback(Provider):\n",
" def true_positive(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 1 and binary_gt_score == 1 else 0.0\n",
" \n",
" def true_negative(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 0 and binary_gt_score == 0 else 0.0\n",
"\n",
" def false_positive(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 1 and binary_gt_score == 0 else 0.0\n",
" def false_negative(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" binary_score = 1 if feedback_score >= 0.5 else 0\n",
" binary_gt_score = 1 if gt_score >= THRESHOLD else 0\n",
" return 1.0 if binary_score == 0 and binary_gt_score == 1 else 0.0\n",
" \n",
" def term_absolute_error(self, output: str) -> float:\n",
" feedback_score, gt_score = float(output.split(\";\")[0]), float(output.split(\";\")[1])\n",
" return abs(feedback_score - gt_score) \n",
" \n",
" def raw_gt_score(self, output: str) -> float:\n",
" return float(output.split(\";\")[1])\n",
" \n",
" def raw_feedback_score(self, output: str) -> float:\n",
" return float(output.split(\";\")[0])\n",
"\n",
"custom_term_feedback = CustomTermFeedback()\n",
"\n",
@@ -945,14 +945,15 @@
"f_fp = Feedback(custom_term_feedback.false_positive, name=\"False Positive\", higher_is_better=False).on_output()\n",
"f_fn = Feedback(custom_term_feedback.false_negative, name=\"False Negative\", higher_is_better=False).on_output()\n",
"f_abs_err = Feedback(custom_term_feedback.term_absolute_error, name=\"Absolute Error\", higher_is_better=False).on_output()\n",
"f_raw_gt_score = Feedback(custom_term_feedback.raw_gt_score, name=\"Raw GT Score\", higher_is_better=True).on_output()\n",
"f_raw_feedback_score = Feedback(custom_term_feedback.raw_feedback_score, name=\"Raw Feedback Score\", higher_is_better=True).on_output()\n",
"\n",
"\n",
"CUSTOM_FEEDBACK_FUNCS = [f_tp, f_tn, f_fp, f_fn, f_abs_err]\n",
"CUSTOM_FEEDBACK_FUNCS = [f_tp, f_tn, f_fp, f_fn, f_abs_err, f_raw_gt_score, f_raw_feedback_score]\n",
"\n",
"def run_experiment_for_provider(provider, dataset_df):\n",
" tru_wrapped_context_relevance_app = TruBasicApp(\n",
" trulens_context_relevance,\n",
" app_name=\"trec-dl-doc-2022-11102024\",\n",
" app_name=\"trec-dl-passage-2022\",\n",
" app_version=f\"{provider.model_engine}-context-relevance\",\n",
" feedbacks=CUSTOM_FEEDBACK_FUNCS,\n",
" )\n",
@@ -973,7 +974,7 @@
"\n",
"for provider in PROVIDERS:\n",
" print(f\"Running provider: {provider.model_engine}\")\n",
" run_experiment_for_provider(provider, trec_doc_2022)\n",
" run_experiment_for_provider(provider, trec_passage_2022)\n",
"\n",
"# with concurrent.futures.ThreadPoolExecutor() as executor:\n",
"# futures = [executor.submit(run_experiment_for_provider, provider, trec_doc_2022) for provider in PROVIDERS]\n",
@@ -365,7 +365,7 @@ def generate_balanced_ms_marco_hard_negatives_dataset(


def generate_ms_marco_trec_dl_annotation_benchmark(
dataset_path: str = "msmarco-document-v2/trec-dl-2022",
dataset_path: str = "msmarco-passage-v2/trec-dl-2022",
max_samples_per_bucket: int = 100,
):
dataset = ir_datasets.load(dataset_path)
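The custom term feedbacks in the notebook all rely on one string contract: the wrapped context-relevance app emits "<feedback_score>;<ground_truth_score>", and each metric splits on the semicolon. A minimal sketch of that contract, using a hypothetical score pair (not taken from the dataset):

# Sketch: how CustomTermFeedback consumes the "feedback;ground-truth" output string.
output = "0.75;0.5"  # hypothetical "<feedback_score>;<ground_truth_score>" pair

THRESHOLD = 0.5  # passage-level relevance cutoff used in the notebook
feedback_score, gt_score = float(output.split(";")[0]), float(output.split(";")[1])
binary_score = 1 if feedback_score >= 0.5 else 0
binary_gt_score = 1 if gt_score >= THRESHOLD else 0

print(1.0 if binary_score == 1 and binary_gt_score == 1 else 0.0)  # true positive -> 1.0
print(abs(feedback_score - gt_score))                              # absolute error -> 0.25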
