TREC DL and LLM AggreFact experiments for relevance benchmark + prompts comparisons and groundedness vs Bespoke Minicheck 7B #1660

Merged · 23 commits · Dec 10, 2024
Changes from 1 commit
notebook updates
sfc-gh-dhuang committed Dec 3, 2024
commit fbe2ae4804873ebae191ccdfba24dc30e38d81de
@@ -364,53 +364,231 @@ def generate_balanced_ms_marco_hard_negatives_dataset(
    })


def generate_trec_dl_doc_benchmark(
    max_samples_per_bucket: int = 3,
    dataset_path: str = "msmarco-document-v2/trec-dl-2021/judged",
):
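    """Yield TREC DL document-ranking judgments as benchmark rows.

    Each row carries query_id, query, doc_id, expected_response (the document
    body or text), and expected_score (the 0-3 TREC relevance grade normalized
    to [0, 1]), capped at max_samples_per_bucket rows per grade per query.
    """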
    dataset = ir_datasets.load(dataset_path)
    queries = {}
    qrels = defaultdict(dict)
    docs_store = None

    # Load queries and qrels
    queries.update({q.query_id: q for q in dataset.queries_iter()})
    for qid, docs in dataset.qrels_dict().items():
        qrels[qid].update(docs)

    # Pre-build a dictionary of qrels by query_id and doc_id for fast lookup
    qrels_by_query = defaultdict(list)
    for qrel in dataset.qrels_iter():
        qrels_by_query[qrel.query_id].append((qrel.doc_id, qrel.relevance))

    # Get docs_store
    if docs_store is None:
        docs_store = dataset.docs_store()
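    # docs_store() gives random access to documents by doc_id, so single
    # lookups below do not require streaming the full document collection.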

    # Generate samples for each query
    for query_id, query in queries.items():
        if query_id not in qrels:
            print(f"Query ID {query_id} not found in qrels, skipping...")
            continue  # Skip queries without relevance judgments

        # Initialize sample counts per relevance score for the current query
        sample_counts = {0: 0, 1: 0, 2: 0, 3: 0}

        # Get all document IDs and relevance scores for the query
        for doc_id, relevance in qrels_by_query[query_id]:
            if sample_counts[relevance] < max_samples_per_bucket:
                # Retrieve document content
                doc = docs_store.get(doc_id)
                if doc:
                    doc_content = (
                        doc.body
                        if hasattr(doc, "body")
                        else doc.text
                        if hasattr(doc, "text")
                        else None
                    )
                    if doc_content is None:
                        continue

                    # Yield the sample
                    yield {
                        "query_id": str(query.query_id),
                        "query": query.text,
                        "doc_id": doc_id,
                        "expected_response": doc_content,
                        "expected_score": relevance / 3,  # Normalize to [0, 1]
                    }

                    # Update the sample count for this relevance score
                    sample_counts[relevance] += 1

            # Stop if all buckets for this query are filled
            if all(
                count >= max_samples_per_bucket
                for count in sample_counts.values()
            ):
                break

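# A minimal usage sketch for the generator above; `preview_bucket_counts` is a
# hypothetical helper name, and it assumes the msmarco-document-v2 data has
# already been fetched by ir_datasets.
from collections import Counter


def preview_bucket_counts(max_samples_per_bucket: int = 3) -> Counter:
    """Count yielded samples per original 0-3 TREC relevance grade."""
    counts = Counter()
    for sample in generate_trec_dl_doc_benchmark(max_samples_per_bucket):
        # expected_score was normalized as relevance / 3; undo it here.
        counts[round(sample["expected_score"] * 3)] += 1
    return counts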

def generate_trec_dl_benchmark(
    max_samples_per_query_per_score: int = 3,
    dataset_path: str = "msmarco-passage-v2/trec-dl-2021/judged",
):
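    """Yield TREC DL passage judgments sampled per query and relevance grade.

    Candidates come from the dataset's scoreddocs ranking; a passage is kept
    only when its retrieval score falls in a quartile of the per-query score
    range within one step of its 0-3 relevance grade.
    """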
    # Combine queries and qrels from multiple datasets
    queries = {}
    qrels = defaultdict(dict)
    docs_store = None

    dataset = ir_datasets.load(dataset_path)
    # Merge queries
    queries.update({q.query_id: q for q in dataset.queries_iter()})
    # Merge qrels
    for qid, docs in dataset.qrels_dict().items():
        qrels[qid].update(docs)
    # Get docs_store
    if docs_store is None:
        docs_store = dataset.docs_store()

    print("Total number of queries:", len(queries))
    print("Total number of qrels:", len(qrels))

    # Sampling
    for query_id, query in queries.items():
        if query_id not in qrels:
            print("query_id not found in qrels")
            continue  # Skip queries without relevance judgments

        # Get documents by relevance scores
        relevant_docs = defaultdict(list)
        for doc_id, score in qrels[query_id].items():
            relevant_docs[score].append(doc_id)

        # Determine scoreddocs intervals for this query
        scored_docs = [
            scored_doc
            for scored_doc in ir_datasets.load(dataset_path).scoreddocs_iter()
            if scored_doc.query_id == query_id
        ]
        if not scored_docs:
            continue

        min_score = min(scored_doc.score for scored_doc in scored_docs)
        max_score = max(scored_doc.score for scored_doc in scored_docs)
        interval_size = (max_score - min_score) / 4
        intervals = [
            (min_score + i * interval_size, min_score + (i + 1) * interval_size)
            for i in range(4)
        ]
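        # Example: if min_score=0 and max_score=8, the quartiles are (0, 2),
        # (2, 4), (4, 6), (6, 8); a grade-2 passage is then kept below only if
        # its retrieval score falls in (2, 4), (4, 6), or (6, 8).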

        # Initialize sampling counts
        sampled_docs = []

        # Use scoreddocs for all scores (0, 1, 2, and 3)
        for score in [0, 1, 2, 3]:
            if score in relevant_docs:
                # Get ranked documents using scoreddocs
                ranked_docs = []
                for scored_doc in scored_docs:
                    if scored_doc.doc_id in relevant_docs[score]:
                        ranked_docs.append((
                            scored_doc.doc_id,
                            scored_doc.score,
                        ))

                # Filter documents based on interval alignment (-1, 0, +1)
                allowed_intervals = [
                    intervals[max(0, score - 1)],
                    intervals[score],
                    intervals[min(3, score + 1)],
                ]
                interval_docs = [
                    (doc_id, doc_score)
                    for doc_id, doc_score in ranked_docs
                    if any(
                        low <= doc_score <= high
                        for low, high in allowed_intervals
                    )
                ]

                # Sort by score (descending) and select top documents
                interval_docs.sort(key=lambda x: x[1], reverse=True)
                top_docs = [
                    doc_id
                    for doc_id, _ in interval_docs[
                        :max_samples_per_query_per_score
                    ]
                ]

                # Add to sampled documents
                sampled_docs.extend(top_docs)

        doc_text_seen = set()  # deduplication of identical passages
        # Yield the sampled data
        for doc_id in sampled_docs:
            doc = docs_store.get(doc_id)
            if doc and doc.text not in doc_text_seen:
                doc_text_seen.add(doc.text)
                yield {
                    "query_id": query_id,
                    "query": query.text,
                    "doc_id": doc_id,
                    "expected_response": doc.text
                    if hasattr(doc, "text")
                    else doc.body,
                    "expected_score": qrels[query_id][doc_id]
                    / 3,  # Normalize to [0, 1]
                }

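# A minimal usage sketch; `trec_dl_benchmark_to_df` is a hypothetical helper,
# and pandas is assumed to be available in this notebook environment.
import pandas as pd


def trec_dl_benchmark_to_df(max_samples_per_query_per_score: int = 3):
    """Materialize the passage benchmark into a DataFrame for inspection."""
    rows = list(generate_trec_dl_benchmark(max_samples_per_query_per_score))
    return pd.DataFrame(rows)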

# def generate_ms_marco_trec_dl_annotation_benchmark(
#     dataset_path: str = "msmarco-passage-v2/trec-dl-2022",
#     max_samples_per_bucket: int = 100,
# ):
#     dataset = ir_datasets.load(dataset_path)
#     sample_counts = {0: 0, 1: 0, 2: 0, 3: 0}

#     # Pre-build a dictionary of qrels by query_id and doc_id for fast lookup
#     qrels_by_query = defaultdict(list)
#     for qrel in dataset.qrels_iter():
#         qrels_by_query[qrel.query_id].append((qrel.doc_id, qrel.relevance))

#     # Pre-build a dictionary of documents by doc_id for quick access
#     docs_dict = {doc.doc_id: doc for doc in dataset.docs_iter()[: 1 / 50]}

#     # Generate samples
#     for query in dataset.queries_iter():
#         if query.query_id in qrels_by_query:
#             for doc_id, relevance in qrels_by_query[query.query_id]:
#                 if sample_counts[relevance] < max_samples_per_bucket:
#                     doc = docs_dict.get(doc_id)
#                     if doc:
#                         doc_content = (
#                             doc.body
#                             if hasattr(doc, "body")
#                             else doc.text
#                             if hasattr(doc, "text")
#                             else None
#                         )
#                         if doc_content is None:
#                             continue

#                         yield {
#                             "query_id": query.query_id,
#                             "query": query.text,
#                             "doc_id": doc_id,
#                             "expected_response": doc_content,
#                             "expected_score": relevance / 3,
#                         }
#                         sample_counts[relevance] += 1

#                 # Stop if all sample buckets are filled
#                 if all(
#                     count >= max_samples_per_bucket
#                     for count in sample_counts.values()
#                 ):
#                     return


def write_results(