# Promptriever: Retrieval models can be controlled with prompts, just like language models

Official repository for the paper [Promptriever: Retrieval models can be controlled with prompts, just like language models](todo). This repository contains the code and resources for Promptriever, which demonstrates that retrieval models can be controlled with prompts on a per-instance basis, similar to language models.

Evaluation can also be done with the MTEB repository; see [here for examples](todo).
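
As a rough sketch, an MTEB evaluation might look like the following (this assumes a recent `mteb` release that ships the `mteb run` CLI; the model name and output folder are placeholders, and the flags should be checked against `mteb run --help`):

```bash
pip install mteb
# hypothetical model id and output path
mteb run -m <model_name> -t NFCorpus --output_folder results/mteb
```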

## Table of Contents

- [Links](#links)
- [Setup](#setup)
- [Experiments](#experiments)
  - [MSMARCO](#msmarco-experiments)
  - [BEIR](#beir-experiments)
- [Analysis](#analysis)
- [Training](#training)
- [Utilities](#utilities)
- [Citation](#citation)

## Links

| Binary | Description |
|:-------|:------------|
| [promptriever-llama2-v1](https://huggingface.co/jhu-clsp/FollowIR-7B) | The promptriever dense retrieval model used in the majority of the paper, based on Llama-2. |
| [msmarco-w-instructions](https://huggingface.co/datasets/jhu-clsp/FollowIR-train) | The dataset used to train promptriever-llama2-v1, created by augmenting MSMarco with instruction data and instruction-negatives. |

## Setup

To initialize your research environment:

```bash
bash setup/install_conda.sh
bash setup/install_req.sh
python setup/download_dev_sets.py
```

These steps ensure consistent software versions and datasets across all research environments.

## Experiments

### MSMARCO Experiments

Run a complete MSMARCO experiment:

```bash
bash msmarco/encode_corpus.sh <output_path> <model_name>
bash msmarco/encode_queries.sh <output_path> <model_name>
bash msmarco/search.sh <output_path>
```
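
Filled in with concrete values, this might look like the following (hypothetical output path; the model id is the checkpoint linked in the Links table above):

```bash
# hypothetical output path; all three steps share it
bash msmarco/encode_corpus.sh results/msmarco jhu-clsp/FollowIR-7B
bash msmarco/encode_queries.sh results/msmarco jhu-clsp/FollowIR-7B
bash msmarco/search.sh results/msmarco
```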

### BEIR Experiments

Execute comprehensive BEIR experiments:

```bash
bash beir/run_all.sh <model_name> <output_nickname>
bash beir/run_all_prompts.sh <model_name> <output_nickname>
bash beir/search_all_prompts.sh <output_path>
```

The `beir/bm25` subfolder contains scripts for the BM25 baseline experiments.

## Analysis

### Visualization

Use the scripts in the `plotting` folder to generate visualizations (a sketch of a typical invocation follows the list):

- `gather_results.py`: aggregates results from different runs
- `get_sd_table.py`: generates standard deviation tables
- `make_prompt_all_table.py`: creates comprehensive prompt-based result tables
- `make_prompt_table_from_results.py`: generates detailed tables of prompt effectiveness
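
For example (the flags here are hypothetical; consult each script's `--help` or source for its actual interface):

```bash
# hypothetical flags -- the real scripts may expect different arguments
python plotting/gather_results.py --input_dir results/ --output summary.csv
```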

### Error Analysis

Conduct in-depth error analysis:

```bash
python error_analysis/error_analysis.py <run1> <run2> <dataset> <output_dir>
```
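
For example, to compare two TREC run files on SciFact (hypothetical paths):

```bash
# hypothetical run files and output directory
python error_analysis/error_analysis.py runs/base.trec runs/prompted.trec scifact analysis_out/
```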

Additional scripts: `error_analysis_bow.py` and `error_analysis_modeling.py`.

## Training

Train or fine-tune retrieval models:

```bash
bash training/train.sh <model_args>
```

Available training scripts:
- `train_instruct_llama3_instruct.sh`
- `train_instruct_llama3.sh`
- `train_instruct_mistral_v1.sh`
- `train_instruct_mistral.sh`
- `train_instruct.sh`

## Utilities

- `utils/symlink_dev.sh` and `utils/symlink_msmarco.sh`: optimize storage usage via symlinks
- `utils/upload_to_hf_all.py` and `utils/upload_to_hf.py`: upload models to the Hugging Face Hub
- `utils/validate_all_present.py`: validate dataset completeness
- `filtering/filter_query_doc_pairs_from_batch_gpt.py`: filter query-document pairs using GPT model outputs

## Citation

If you found the code, data, or models useful, feel free to cite:

```bibtex
@misc{todo}
```

`scripts/beir/bm25/prompt_all_bm25.sh`:

```bash
#!/bin/bash

# example usage:
# bash scripts/beir/bm25/prompt_all_bm25.sh

nickname=bm25

mkdir -p "$nickname"

datasets=(
    'arguana'
    'fiqa'
    'nfcorpus'
    'scidocs'
    'scifact'
    'trec-covid'
    'webis-touche2020'
    'quora'
    'nq'
    'hotpotqa'
    'climate-fever'
    'dbpedia-entity'
    'fever'
    # 'msmarco-dl19'
    # 'msmarco-dl20'
    # 'msmarco-dev'
    'nfcorpus-dev'
    # 'nq-dev'
    'scifact-dev'
    'fiqa-dev'
    'hotpotqa-dev'
    'dbpedia-entity-dev'
    'quora-dev'
    'fever-dev'
)

# Read each line of generic_prompts.csv, where each line is a prompt.
# Run every dataset with that prompt, hashing the prompt and passing the
# hash along so it can be used in the output filename.
while IFS= read -r prompt; do
    prompt_hash=$(echo -n "$prompt" | md5sum | awk '{print $1}')
    for dataset in "${datasets[@]}"; do
        mkdir -p "$nickname/$dataset"
        if [ -f "$nickname/$dataset/${dataset}_${prompt_hash}.trec" ]; then
            echo "Skipping $dataset because of existing file $nickname/$dataset/${dataset}_${prompt_hash}.trec"
            continue
        fi
        echo "Running prompt on dataset: $dataset"
        echo "Prompt: '$prompt'"
        python scripts/beir/bm25/run_bm25s.py --dataset_name "$dataset" --prompt "$prompt" --top_k 1000 --output_dir "$nickname/$dataset" --prompt_hash "$prompt_hash"
    done
done < generic_prompts.csv

# Also run each dataset once without a prompt.
for dataset in "${datasets[@]}"; do
    if [ -f "$nickname/$dataset/$dataset.trec" ]; then
        echo "Skipping $dataset because of existing file $nickname/$dataset/$dataset.trec"
        continue
    fi
    echo "Running without prompt on dataset: $dataset"
    python scripts/beir/bm25/run_bm25s.py --dataset_name "$dataset" --top_k 1000 --output_dir "$nickname/$dataset"
done
```

`scripts/beir/bm25/run_bm25s.py`:

```python
import os
import argparse
import logging

import Stemmer
import bm25s.hf
from datasets import load_dataset
from tqdm import tqdm

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def load_queries(dataset_name):
    if "msmarco-" in dataset_name:
        logging.info(f"Loading MS MARCO queries for dataset: {dataset_name}")
        dataset = load_dataset("tevatron/msmarco-passage", split=dataset_name.split("-")[-1], trust_remote_code=True)
    else:
        logging.info(f"Loading queries for dataset: {dataset_name}")
        dataset = load_dataset("orionweller/beir", dataset_name, trust_remote_code=True)["test"]
    return {row['query_id']: row['query'] for row in dataset}


def main(args):
    logging.info(f"Starting BM25S search for dataset: {args.dataset_name}")

    # Load the BM25 index from the Hugging Face Hub
    if "msmarco-" in args.dataset_name:
        cur_dataset_name = "msmarco"
    elif "-dev" in args.dataset_name:
        cur_dataset_name = args.dataset_name.replace("-dev", "")
    else:
        cur_dataset_name = args.dataset_name
    index_name = f"xhluca/bm25s-{cur_dataset_name}-index"
    logging.info(f"Loading BM25 index from: {index_name}")
    retriever = bm25s.hf.BM25HF.load_from_hub(
        index_name, load_corpus=True, mmap=True
    )
    logging.info("BM25 index loaded successfully")

    # Load queries
    queries = load_queries(args.dataset_name)
    logging.info(f"Loaded {len(queries)} queries")

    # Initialize stemmer
    stemmer = Stemmer.Stemmer("english")
    logging.info("Initialized English stemmer")

    # Prepare output file
    os.makedirs(args.output_dir, exist_ok=True)
    basename = f"{args.dataset_name}_{args.prompt_hash}.trec" if args.prompt_hash else f"{args.dataset_name}.trec"
    output_file = os.path.join(args.output_dir, basename)
    logging.info(f"Results will be saved to: {output_file}")

    with open(output_file, 'w') as f:
        for query_id, query in tqdm(queries.items(), desc="Processing queries"):
            # Append the prompt, if one was provided
            if args.prompt.strip() != "":
                query += f" {args.prompt}"

            # Tokenize the query
            query_tokenized = bm25s.tokenize([query], stemmer=stemmer)

            # Retrieve the top-k results as a tuple of (doc ids, scores);
            # both are arrays of shape (n_queries, k)
            results, scores = retriever.retrieve(query_tokenized, k=args.top_k)
            # since there is only one query, take the first element of each
            results = results[0]
            scores = scores[0]

            # Write results in TREC format
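            # (one line per document: <query_id> Q0 <doc_id> <rank> <score> <run_tag>)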
            for rank, (doc, score) in enumerate(zip(results, scores)):
                f.write(f"{query_id} Q0 {doc['id']} {rank+1} {score} bm25s\n")

    logging.info(f"Search completed. Results saved to {output_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BM25S Search Script")
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (e.g., webis-touche2020)")
    parser.add_argument("--prompt", type=str, default="", help="Prompt to append to each query")
    parser.add_argument("--prompt_hash", type=str, default="", help="Hash of the prompt, used in the output filename")
    parser.add_argument("--top_k", type=int, default=1000, help="Number of top results to retrieve")
    parser.add_argument("--output_dir", type=str, default="results", help="Output directory for results")
    args = parser.parse_args()
    print(f"Arguments: {args}")

    main(args)

# example usage:
# python run_bm25s.py --dataset_name webis-touche2020 --prompt "Retrieve relevant documents for the given query:" --top_k 1000 --output_dir results
```