init
orionw committed Sep 11, 2024
0 parents commit aec840d
Showing 41 changed files with 2,924 additions and 0 deletions.
114 changes: 114 additions & 0 deletions README.md
@@ -0,0 +1,114 @@
# Promptriever: Retrieval models can be controlled with prompts, just like language models

Official repository for the paper [Promptriever: Retrieval models can be controlled with prompts, just like language models](todo). It contains the code and resources for Promptriever, a dense retrieval model that can be controlled with natural-language prompts on a per-instance basis, much like a language model.

Evaluation can also be run through the MTEB repository; see [here for examples](todo).
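
For a quick sanity check outside this repository, a BEIR-style task can be run through the `mteb` package roughly as in the sketch below. This is a minimal, hedged example: the model name is a generic placeholder (substitute the Promptriever checkpoint and whatever prompt handling you need), and the task choice is arbitrary.

```python
# Minimal sketch of running a BEIR-style retrieval task with the mteb package.
# Assumes `pip install mteb sentence-transformers`; model and task names are placeholders.
import mteb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model
tasks = mteb.get_tasks(tasks=["NFCorpus"])  # any BEIR task name works here
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, output_folder="results/mteb")
print(results)
```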


## Table of Contents
- [Links](#links)
- [Setup](#setup)
- [Experiments](#experiments)
- [MSMARCO](#msmarco-experiments)
- [BEIR](#beir-experiments)
- [Analysis](#analysis)
- [Training](#training)
- [Utilities](#utilities)
- [Citation](#citation)


## Links

| Resource | Description |
|:------|:-------------------------------------------------------------------------------------------------------------------------------------------|
| [promptriever-llama2-v1](https://huggingface.co/jhu-clsp/FollowIR-7B) | The Promptriever dense retrieval model used in most of the paper, based on Llama-2 |
| [msmarco-w-instructions](https://huggingface.co/datasets/jhu-clsp/FollowIR-train) | The dataset used to train promptriever-llama2-v1, created by augmenting MS MARCO with instruction data and instruction negatives. |
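
To take a quick look at the training data, the dataset linked above can be pulled with the `datasets` library. This is a hedged sketch: the repository ID is simply the one linked in the table, and the split and field names are not guaranteed.

```python
# Hedged sketch: inspect the instruction-augmented MS MARCO training data.
# The repo ID is taken from the table above; adjust it if the link changes.
from datasets import load_dataset

ds = load_dataset("jhu-clsp/FollowIR-train")
print(ds)  # available splits and features

first_split = list(ds.keys())[0]
print(ds[first_split][0])  # first example from the first split
```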

## Setup

To initialize your research environment:

```bash
bash setup/install_conda.sh
bash setup/install_req.sh
python setup/download_dev_sets.py
```

These steps ensure consistent software versions and datasets across all research environments.

## Experiments

### MSMARCO Experiments

Run a complete MSMARCO experiment:

```bash
bash msmarco/encode_corpus.sh <output_path> <model_name>
bash msmarco/encode_queries.sh <output_path> <model_name>
bash msmarco/search.sh <output_path>
```

### BEIR Experiments

Execute comprehensive BEIR experiments:

```bash
bash beir/run_all.sh <model_name> <output_nickname>
bash beir/run_all_prompts.sh <model_name> <output_nickname>
bash beir/search_all_prompts.sh <output_path>
```

The `beir/bm25` subfolder contains scripts for BM25 baseline experiments.

## Analysis

### Visualization

Use the scripts in the `plotting` folder to generate visualizations and summary tables:

- `gather_results.py`: Aggregates results from different runs
- `get_sd_table.py`: Generates standard deviation tables
- `make_prompt_all_table.py`: Creates comprehensive prompt-based result tables
- `make_prompt_table_from_results.py`: Generates detailed tables for prompt effectiveness

### Error Analysis

Conduct in-depth error analysis:

```bash
python error_analysis/error_analysis.py <run1> <run2> <dataset> <output_dir>
```

Additional scripts: `error_analysis_bow.py` and `error_analysis_modeling.py`

## Training

Train or fine-tune retrieval models:

```bash
bash training/train.sh <model_args>
```

Available training scripts:
- `train_instruct_llama3_instruct.sh`
- `train_instruct_llama3.sh`
- `train_instruct_mistral_v1.sh`
- `train_instruct_mistral.sh`
- `train_instruct.sh`

## Utilities

- `utils/symlink_dev.sh` and `utils/symlink_msmarco.sh`: Symlink shared datasets to avoid duplicating them on disk
- `utils/upload_to_hf_all.py` and `utils/upload_to_hf.py`: Upload trained models to the Hugging Face Hub (see the sketch below)
- `utils/validate_all_present.py`: Validate dataset completeness
- `filtering/filter_query_doc_pairs_from_batch_gpt.py`: Filter query–document pairs using batched GPT model outputs
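
For reference, the upload step boils down to pushing a local checkpoint directory to the Hub with `huggingface_hub`, roughly as in the sketch below. The repo ID and local path are placeholders, and the actual scripts may differ.

```python
# Hedged sketch of what utils/upload_to_hf.py roughly does: push a local checkpoint
# directory to the Hugging Face Hub. Repo ID and folder path are placeholders.
# Requires authentication, e.g. `huggingface-cli login` or an HF_TOKEN env variable.
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("your-org/promptriever-checkpoint", repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path="outputs/checkpoint-final",  # local model directory (placeholder)
    repo_id="your-org/promptriever-checkpoint",
    repo_type="model",
)
```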

## Citation

If you find the code, data, or model useful, feel free to cite:

```bibtex
@misc{todo}
```
65 changes: 65 additions & 0 deletions scripts/beir/bm25/prompt_all_bm25.sh
@@ -0,0 +1,65 @@
#!/bin/bash

# example usage:
# bash scripts/beir/bm25/prompt_all_bm25.sh

nickname=bm25

mkdir -p $nickname

datasets=(
'arguana'
'fiqa'
'nfcorpus'
'scidocs'
'scifact'
'trec-covid'
'webis-touche2020'
'quora'
'nq'
'hotpotqa'
'climate-fever'
'dbpedia-entity'
'fever'
# 'msmarco-dl19'
# 'msmarco-dl20'
# 'msmarco-dev'
'nfcorpus-dev'
# 'nq-dev'
'scifact-dev'
'fiqa-dev'
'hotpotqa-dev'
'dbpedia-entity-dev'
'quora-dev'
'fever-dev'
)


# Read in each line of the generic_prompts.csv file where each line is a prompt
# Run it on each dataset, hashing the prompt and passing that as the fourth argument
while IFS= read -r prompt
do
    prompt_hash=$(echo -n "$prompt" | md5sum | awk '{print $1}')
    for dataset in "${datasets[@]}"; do
        mkdir -p "$nickname/$dataset"
        if [ -f "$nickname/$dataset/${dataset}_${prompt_hash}.trec" ]; then
            echo "Skipping $dataset because of existing file $nickname/$dataset/${dataset}_${prompt_hash}.trec"
            continue
        fi
        echo "Running prompt on dataset: $dataset"
        echo "Prompt: '$prompt'"
        python scripts/beir/bm25/run_bm25s.py --dataset_name "$dataset" --prompt "$prompt" --top_k 1000 --output_dir "$nickname/$dataset" --prompt_hash "$prompt_hash"
    done
done < generic_prompts.csv


# also run one without a prompt for each dataset
for dataset in "${datasets[@]}"; do
    if [ -f "$nickname/$dataset/$dataset.trec" ]; then
        echo "Skipping $dataset because of existing file $nickname/$dataset/$dataset.trec"
        continue
    fi
    echo "Running without prompt on dataset: $dataset"
    python scripts/beir/bm25/run_bm25s.py --dataset_name "$dataset" --top_k 1000 --output_dir "$nickname/$dataset"
done
88 changes: 88 additions & 0 deletions scripts/beir/bm25/run_bm25s.py
@@ -0,0 +1,88 @@
import os
import argparse
import logging
import Stemmer
import bm25s.hf
from datasets import load_dataset
from tqdm import tqdm

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def load_queries(dataset_name):
    if "msmarco-" in dataset_name:
        logging.info(f"Loading MS MARCO queries for dataset: {dataset_name}")
        dataset = load_dataset("tevatron/msmarco-passage", split=dataset_name.split("-")[-1], trust_remote_code=True)
        return {row['query_id']: row['query'] for row in dataset}
    else:
        logging.info(f"Loading queries for dataset: {dataset_name}")
        dataset = load_dataset("orionweller/beir", dataset_name, trust_remote_code=True)["test"]
        return {row['query_id']: row['query'] for row in dataset}

def main(args):
    logging.info(f"Starting BM25S search for dataset: {args.dataset_name}")

    # Load the BM25 index from Hugging Face Hub
    if "msmarco-" in args.dataset_name:
        cur_dataset_name = "msmarco"
    elif "-dev" in args.dataset_name:
        cur_dataset_name = args.dataset_name.replace("-dev", "")
    else:
        cur_dataset_name = args.dataset_name
    index_name = f"xhluca/bm25s-{cur_dataset_name}-index"
    logging.info(f"Loading BM25 index from: {index_name}")
    retriever = bm25s.hf.BM25HF.load_from_hub(
        index_name, load_corpus=True, mmap=True
    )
    logging.info("BM25 index loaded successfully")

    # Load queries
    queries = load_queries(args.dataset_name)
    logging.info(f"Loaded {len(queries)} queries")

    # Initialize stemmer
    stemmer = Stemmer.Stemmer("english")
    logging.info("Initialized English stemmer")

    # Prepare output file
    os.makedirs(args.output_dir, exist_ok=True)
    basename = f"{args.dataset_name}_{args.prompt_hash}.trec" if args.prompt_hash else f"{args.dataset_name}.trec"
    output_file = os.path.join(args.output_dir, basename)
    logging.info(f"Results will be saved to: {output_file}")

    with open(output_file, 'w') as f:
        for query_id, query in tqdm(queries.items(), desc="Processing queries"):
            # Append the prompt (if provided) to the query text
            if args.prompt.strip() != "":
                query += f" {args.prompt}"

            # Tokenize the query
            query_tokenized = bm25s.tokenize([query], stemmer=stemmer)

            # Retrieve the top-k results as (doc ids, scores); both are arrays of shape (n_queries, k)
            results, scores = retriever.retrieve(query_tokenized, k=args.top_k)
            # Since there is only one query, take the first row of each array
            results = results[0]
            scores = scores[0]

            # Write results in TREC run format: <query_id> Q0 <doc_id> <rank> <score> <run_tag>
            for rank, (doc, score) in enumerate(zip(results, scores)):
                f.write(f"{query_id} Q0 {doc['id']} {rank+1} {score} bm25s\n")

    logging.info(f"Search completed. Results saved to {output_file}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BM25S Search Script")
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (e.g., webis-touche2020)")
    parser.add_argument("--prompt", type=str, default="", help="Prompt to append to each query")
    parser.add_argument("--prompt_hash", type=str, default="", help="Hash of the prompt, used in the output filename")
    parser.add_argument("--top_k", type=int, default=1000, help="Number of top results to retrieve")
    parser.add_argument("--output_dir", type=str, default="results", help="Output directory for results")
    args = parser.parse_args()
    print(f"Arguments: {args}")

    main(args)

# example usage:
# python run_bm25s.py --dataset_name webis-touche2020 --prompt "Retrieve relevant documents for the given query:" --top_k 1000 --output_dir results