From aec840dc1995f2e96fe140ff363acb0702bddf1f Mon Sep 17 00:00:00 2001 From: Orion Weller Date: Wed, 11 Sep 2024 17:18:02 +0000 Subject: [PATCH] init --- README.md | 114 ++++++++ scripts/beir/bm25/prompt_all_bm25.sh | 65 +++++ scripts/beir/bm25/run_bm25s.py | 88 ++++++ scripts/beir/bm25/search_all_bm25.sh | 138 ++++++++++ scripts/beir/clear_all_directories.sh | 9 + scripts/beir/clear_directory.sh | 7 + scripts/beir/encode_beir_corpus.sh | 75 ++++++ scripts/beir/encode_beir_queries.sh | 83 ++++++ scripts/beir/force_redownload_all.py | 5 + scripts/beir/matrix_of_prompts.sh | 20 ++ scripts/beir/run_all.sh | 46 ++++ scripts/beir/run_all_prompts.sh | 85 ++++++ scripts/beir/search_all_prompts.sh | 160 +++++++++++ scripts/beir/search_beir.sh | 23 ++ scripts/error_analysis/error_analysis.py | 212 +++++++++++++++ scripts/error_analysis/error_analysis_bow.py | 145 ++++++++++ .../error_analysis/error_analysis_modeling.py | 100 +++++++ .../filter_query_doc_pairs_from_batch_gpt.py | 136 ++++++++++ scripts/msmarco/encode_corpus.sh | 41 +++ scripts/msmarco/encode_queries.sh | 36 +++ scripts/msmarco/search.sh | 33 +++ scripts/plotting/gather_results.py | 86 ++++++ scripts/plotting/get_sd_table.py | 119 +++++++++ scripts/plotting/make_prompt_all_table.py | 73 +++++ .../make_prompt_table_from_results.py | 251 ++++++++++++++++++ scripts/setup/download_dev_sets.py | 99 +++++++ scripts/setup/install_conda.sh | 9 + scripts/setup/install_req.sh | 17 ++ scripts/tevatron/hn_mining.py | 112 ++++++++ scripts/tevatron/reduce_results.py | 28 ++ scripts/training/train.sh | 32 +++ scripts/training/train_instruct.sh | 38 +++ scripts/training/train_instruct_llama3.sh | 38 +++ .../train_instruct_llama3_instruct.sh | 38 +++ scripts/training/train_instruct_mistral.sh | 38 +++ scripts/training/train_instruct_mistral_v1.sh | 38 +++ scripts/utils/symlink_dev.sh | 21 ++ scripts/utils/symlink_msmarco.sh | 34 +++ scripts/utils/upload_to_hf.py | 51 ++++ scripts/utils/upload_to_hf_all.py | 78 ++++++ scripts/utils/validate_all_present.py | 103 +++++++ 41 files changed, 2924 insertions(+) create mode 100644 README.md create mode 100644 scripts/beir/bm25/prompt_all_bm25.sh create mode 100644 scripts/beir/bm25/run_bm25s.py create mode 100644 scripts/beir/bm25/search_all_bm25.sh create mode 100644 scripts/beir/clear_all_directories.sh create mode 100644 scripts/beir/clear_directory.sh create mode 100644 scripts/beir/encode_beir_corpus.sh create mode 100644 scripts/beir/encode_beir_queries.sh create mode 100644 scripts/beir/force_redownload_all.py create mode 100644 scripts/beir/matrix_of_prompts.sh create mode 100644 scripts/beir/run_all.sh create mode 100644 scripts/beir/run_all_prompts.sh create mode 100644 scripts/beir/search_all_prompts.sh create mode 100644 scripts/beir/search_beir.sh create mode 100644 scripts/error_analysis/error_analysis.py create mode 100644 scripts/error_analysis/error_analysis_bow.py create mode 100644 scripts/error_analysis/error_analysis_modeling.py create mode 100644 scripts/filtering/filter_query_doc_pairs_from_batch_gpt.py create mode 100644 scripts/msmarco/encode_corpus.sh create mode 100644 scripts/msmarco/encode_queries.sh create mode 100644 scripts/msmarco/search.sh create mode 100644 scripts/plotting/gather_results.py create mode 100644 scripts/plotting/get_sd_table.py create mode 100644 scripts/plotting/make_prompt_all_table.py create mode 100644 scripts/plotting/make_prompt_table_from_results.py create mode 100644 scripts/setup/download_dev_sets.py create mode 100644 scripts/setup/install_conda.sh create mode 100644 scripts/setup/install_req.sh create mode 100644 scripts/tevatron/hn_mining.py create mode 100644 scripts/tevatron/reduce_results.py create mode 100644 scripts/training/train.sh create mode 100644 scripts/training/train_instruct.sh create mode 100644 scripts/training/train_instruct_llama3.sh create mode 100644 scripts/training/train_instruct_llama3_instruct.sh create mode 100644 scripts/training/train_instruct_mistral.sh create mode 100644 scripts/training/train_instruct_mistral_v1.sh create mode 100644 scripts/utils/symlink_dev.sh create mode 100644 scripts/utils/symlink_msmarco.sh create mode 100644 scripts/utils/upload_to_hf.py create mode 100644 scripts/utils/upload_to_hf_all.py create mode 100644 scripts/utils/validate_all_present.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..de5becb --- /dev/null +++ b/README.md @@ -0,0 +1,114 @@ +# Promptriever: Retrieval models can be controlled with prompts, just like language models + +Official repository for the paper [Promptriever: Retrieval models can be controlled with prompts, just like language models](todo). This repository contains the code and resources for Promptriever, which demonstrates that retrieval models can be controlled with prompts on a per-instance basis, similar to language models. + +Evaluation can also be done by using the MTEB repository, see [here for examples](todo). + + +## Table of Contents +- [Links](#links) +- [Setup](#setup) +- [Experiments](#experiments) + - [MSMARCO](#msmarco-experiments) + - [BEIR](#beir-experiments) +- [Analysis](#analysis) +- [Training](#training) +- [Utilities](#utilities) +- [Citation](#citation) + + +## Links + +| Binary | Description | +|:------|:-------------------------------------------------------------------------------------------------------------------------------------------| +| [promptriever-llama2-v1](https://huggingface.co/jhu-clsp/FollowIR-7B) | The promptriever dense retrieval model used in the majority of the paper, based on Llama-2 | +| [msmarco-w-instructions](https://huggingface.co/datasets/jhu-clsp/FollowIR-train) | The dataset used to train promptriever-llama2-v1, from augmenting MSMarco with instruction data and instruction-negatives. | + + +## Setup + +To initialize your research environment: + +```bash +bash setup/install_conda.sh +bash setup/install_req.sh +python setup/download_dev_sets.py +``` + +These steps ensure consistent software versions and datasets across all research environments. + +## Experiments + +### MSMARCO Experiments + +Run a complete MSMARCO experiment: + +```bash +bash msmarco/encode_corpus.sh +bash msmarco/encode_queries.sh +bash msmarco/search.sh +``` + +### BEIR Experiments + +Execute comprehensive BEIR experiments: + +```bash +bash beir/run_all.sh +bash beir/run_all_prompts.sh +bash beir/search_all_prompts.sh +``` + +The `beir/bm25` subfolder contains scripts for BM25 baseline experiments. + +## Analysis + +### Visualization + +Use scripts in the `plotting` folder to generate insightful visualizations: + +- `gather_results.py`: Aggregates results from different runs +- `get_sd_table.py`: Generates standard deviation tables +- `make_prompt_all_table.py`: Creates comprehensive prompt-based result tables +- `make_prompt_table_from_results.py`: Generates detailed tables for prompt effectiveness + +### Error Analysis + +Conduct in-depth error analysis: + +```bash +python error_analysis/error_analysis.py +``` + +Additional scripts: `error_analysis_bow.py` and `error_analysis_modeling.py` + +## Training + +Train or fine-tune retrieval models: + +```bash +bash training/train.sh +``` + +Available training scripts: +- `train_instruct_llama3_instruct.sh` +- `train_instruct_llama3.sh` +- `train_instruct_mistral_v1.sh` +- `train_instruct_mistral.sh` +- `train_instruct.sh` + +## Utilities + +- `utils/symlink_dev.sh` and `utils/symlink_msmarco.sh`: Optimize storage usage +- `utils/upload_to_hf_all.py` and `utils/upload_to_hf.py`: Upload models to Hugging Face Hub +- `utils/validate_all_present.py`: Validate dataset completeness +- `filtering/filter_query_doc_pairs_from_batch_gpt.py`: Implement advanced filtering using GPT model outputs + +## Citation + +If you found the code, data or model useful, free to cite: + +```bibtex +@misc{todo} +} +``` \ No newline at end of file diff --git a/scripts/beir/bm25/prompt_all_bm25.sh b/scripts/beir/bm25/prompt_all_bm25.sh new file mode 100644 index 0000000..e57f2d7 --- /dev/null +++ b/scripts/beir/bm25/prompt_all_bm25.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# example usage: +# bash scripts/beir/bm25/prompt_all_bm25.sh + +nickname=bm25 + +mkdir -p $nickname + +datasets=( + 'arguana' + 'fiqa' + 'nfcorpus' + 'scidocs' + 'scifact' + 'trec-covid' + 'webis-touche2020' + 'quora' + 'nq' + 'hotpotqa' + 'climate-fever' + 'dbpedia-entity' + 'fever' + # 'msmarco-dl19' + # 'msmarco-dl20' + # 'msmarco-dev' + 'nfcorpus-dev' + # 'nq-dev' + 'scifact-dev' + 'fiqa-dev' + 'hotpotqa-dev' + 'dbpedia-entity-dev' + 'quora-dev' + 'fever-dev' +) + + +# Read in each line of the generic_prompts.csv file where each line is a prompt +# Run it on each dataset, hashing the prompt and passing that as the fourth argument +while IFS= read -r prompt +do + prompt_hash=$(echo -n "$prompt" | md5sum | awk '{print $1}') + for dataset in "${datasets[@]}"; do + mkdir -p "$nickname/$dataset" + if [ -f "$nickname/$dataset/${dataset}_${prompt_hash}.trec" ]; then + echo "Skipping $dataset because of existing file $nickname/$dataset/$dataaset_$prompt_hash.trec" + continue + fi + echo "Running prompt on dataset: $dataset" + echo "Prompt: '$prompt'" + python scripts/beir/bm25/run_bm25s.py --dataset_name "$dataset" --prompt "$prompt" --top_k 1000 --output_dir "$nickname/$dataset" --prompt_hash "$prompt_hash" + done +done < generic_prompts.csv + + +# also run one without a prompt for each dataset +for dataset in "${datasets[@]}"; do +echo "Running without prompt on dataset: $dataset" + if [ -f "$nickname/$dataset/$dataset.trec" ]; then + echo "Skipping $dataset because of existing file $nickname/$dataset/$dataset.trec" + continue + fi + echo "Running without prompt on dataset: $dataset" + python scripts/beir/bm25/run_bm25s.py --dataset_name $dataset --top_k 1000 --output_dir "$nickname/$dataset" +done diff --git a/scripts/beir/bm25/run_bm25s.py b/scripts/beir/bm25/run_bm25s.py new file mode 100644 index 0000000..3ba4d3c --- /dev/null +++ b/scripts/beir/bm25/run_bm25s.py @@ -0,0 +1,88 @@ +import os +import argparse +import logging +import Stemmer +import bm25s.hf +from datasets import load_dataset +from tqdm import tqdm + +# Set up logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def load_queries(dataset_name): + if "msmarco-" in dataset_name: + logging.info(f"Loading MS MARCO queries for dataset: {dataset_name}") + dataset = load_dataset(f"tevatron/msmarco-passage", split=dataset_name.split("-")[-1], trust_remote_code=True) + return {row['query_id']: row['query'] for row in dataset} + else: + logging.info(f"Loading queries for dataset: {dataset_name}") + dataset = load_dataset(f"orionweller/beir", dataset_name, trust_remote_code=True)["test"] + return {row['query_id']: row['query'] for row in dataset} + +def main(args): + logging.info(f"Starting BM25S search for dataset: {args.dataset_name}") + + # Load the BM25 index from Hugging Face Hub + if "msmarco-" in args.dataset_name: + cur_dataset_name = "msmarco" + elif "-dev" in args.dataset_name: + cur_dataset_name = args.dataset_name.replace("-dev", "") + else: + cur_dataset_name = args.dataset_name + index_name = f"xhluca/bm25s-{cur_dataset_name}-index" + logging.info(f"Loading BM25 index from: {index_name}") + retriever = bm25s.hf.BM25HF.load_from_hub( + index_name, load_corpus=True, mmap=True + ) + logging.info("BM25 index loaded successfully") + + # Load queries + queries = load_queries(args.dataset_name) + logging.info(f"Loaded {len(queries)} queries") + + # Initialize stemmer + stemmer = Stemmer.Stemmer("english") + logging.info("Initialized English stemmer") + + # Prepare output file + os.makedirs(args.output_dir, exist_ok=True) + basename = f"{args.dataset_name}_{args.prompt_hash}.trec" if args.prompt_hash else f"{args.dataset_name}.trec" + output_file = os.path.join(args.output_dir, basename) + logging.info(f"Results will be saved to: {output_file}") + + with open(output_file, 'w') as f: + for query_id, query in tqdm(queries.items(), desc="Processing queries"): + # Append prompt if provided + if args.prompt.strip != "": + query += f" {args.prompt}" + + # Tokenize the query + query_tokenized = bm25s.tokenize([query], stemmer=stemmer) + + # Retrieve the top-k results + # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k) + results, scores = retriever.retrieve(query_tokenized, k=args.top_k) + # since there is only one query, we can just take the first element + results = results[0] + scores = scores[0] + + # Write results in TREC format + for rank, (doc, score) in enumerate(zip(results, scores)): + f.write(f"{query_id} Q0 {doc['id']} {rank+1} {score} bm25s\n") + + logging.info(f"Search completed. Results saved to {output_file}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="BM25S Search Script") + parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (e.g., webis-touche2020)") + parser.add_argument("--prompt", type=str, default="", help="Prompt to append to each query") + parser.add_argument("--prompt_hash", type=str, default="", help="Prompt hash to append to each query") + parser.add_argument("--top_k", type=int, default=1000, help="Number of top results to retrieve") + parser.add_argument("--output_dir", type=str, default="results", help="Output directory for results") + args = parser.parse_args() + print(f"Arguments: {args}") + + main(args) + + # example usage: + # python run_bm25s.py --dataset_name webis-touche2020 --prompt "Retrieve relevant documents for the given query:" --top_k 1000 --output_dir results \ No newline at end of file diff --git a/scripts/beir/bm25/search_all_bm25.sh b/scripts/beir/bm25/search_all_bm25.sh new file mode 100644 index 0000000..dae6e40 --- /dev/null +++ b/scripts/beir/bm25/search_all_bm25.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# example: bash scripts/beir/bm25/search_all_bm25.sh bm25 + +results_path=$1 + +datasets=( + 'nfcorpus-dev' + 'arguana' + 'fiqa' + 'nfcorpus' + 'scidocs' + 'scifact' + 'trec-covid' + 'webis-touche2020' + 'quora' + 'nq' + 'hotpotqa' + 'climate-fever' + 'dbpedia-entity' + 'fever' + # 'msmarco-dl19' + # 'msmarco-dl20' + # 'msmarco-dev' + # 'nq-dev' + 'scifact-dev' + 'fiqa-dev' + 'hotpotqa-dev' + 'dbpedia-entity-dev' + 'quora-dev' + 'fever-dev' +) + + +evaluate() { + local dataset_name=$1 + local trec_file=$2 + local output_suffix=$3 + + # if the final eval file exists and has a score, skip + if [[ -f "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" ]]; then + # if ndcg_cut_10 is in it or recip in it, skip + if [[ $(grep -c "ndcg_cut_10" "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval") -gt 0 ]] || [[ $(grep -c "recip_rank" "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval") -gt 0 ]]; then + echo "Skipping ${dataset_name}${output_suffix} because of existing file ${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" + return + fi + fi + + echo "Evaluating ${dataset_name} with ${trec_file}..." + + # if it is not msmarco and not -dev in the name + if [[ "$dataset_name" != *"msmarco"* ]] && [[ "$dataset_name" != *"-dev"* ]]; then + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 \ + "beir-v1.0.0-${dataset_name}-test" \ + "${trec_file}" \ + > "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" + # else if -dev in the name and is not msmarco + elif [[ "$dataset_name" == *"-dev"* ]] && [[ "$dataset_name" != *"msmarco"* ]]; then + # remove the -dev + new_dataset_name=$(echo $dataset_name | sed 's/-dev//') + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 \ + resources/downloaded/qrels/$new_dataset_name.qrels.sampled \ + "${trec_file}" \ + > "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" + else + dataset=$(echo $dataset_name | cut -d'-' -f2) + if [ $dataset == "dev" ]; then + echo "Evaluating ${dataset}..." + echo "python -m pyserini.eval.trec_eval -c -M 100 -m recip_rank msmarco-passage-dev-subset "${trec_file}"" + python -m pyserini.eval.trec_eval -c -M 100 -m recip_rank msmarco-passage-dev-subset "${trec_file}" > "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" + else + pyserini_dataset="${dataset}-passage" + echo "Evaluating ${dataset}..." + echo "python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 $pyserini_dataset "${trec_file}"" + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 $pyserini_dataset "${trec_file}" > "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" + fi + fi + + + echo "Score is saved at ${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" + cat "${results_path}/${dataset_name}/${dataset_name}${output_suffix}.eval" +} + +# Process all datasets +for dataset in "${datasets[@]}"; do + dataset_path="${results_path}/${dataset}" + + # Evaluate without prompt + if [[ -f "${dataset_path}/${dataset}.trec" ]]; then + evaluate "$dataset" "${dataset_path}/${dataset}.trec" "" + fi + + # Evaluate with prompts + for trec_file in "${dataset_path}/${dataset}_"*.trec; do + if [[ -f "$trec_file" ]]; then + prompt_hash=$(basename "$trec_file" | sed -n 's/.*_\(.*\)\.trec/\1/p') + evaluate "$dataset" "$trec_file" "_${prompt_hash}" + fi + done +done + +# Aggregate results +echo "Aggregating results..." +output_file="${results_path}/bm25_aggregate_results.csv" +echo "Dataset,Prompt,NDCG@10,Recall@100,MRR" > "$output_file" + +for dataset in "${datasets[@]}"; do + dataset_path="${results_path}/${dataset}" + + # Process results without prompt + eval_file="${dataset_path}/${dataset}.eval" + if [[ -f "$eval_file" ]]; then + if [[ "$dataset" == "msmarco-dev" ]]; then + mrr=$(awk '/recip_rank / {print $3}' "$eval_file") + echo "${dataset},no_prompt,,,${mrr}" >> "$output_file" + else + ndcg=$(awk '/ndcg_cut_10 / {print $3}' "$eval_file") + recall=$(awk '/recall_100 / {print $3}' "$eval_file") + echo "${dataset},no_prompt,${ndcg},${recall}," >> "$output_file" + fi + fi + + # Process results with prompts + for eval_file in "${dataset_path}/${dataset}_"*.eval; do + if [[ "$dataset" == "msmarco-dev" ]]; then + prompt_hash=$(basename "$eval_file" | sed -n 's/.*_\(.*\)\.eval/\1/p') + mrr=$(awk '/recip_rank / {print $3}' "$eval_file") + echo "${dataset},${prompt_hash},,,${mrr}" >> "$output_file" + else + prompt_hash=$(basename "$eval_file" | sed -n 's/.*_\(.*\)\.eval/\1/p') + ndcg=$(awk '/ndcg_cut_10 / {print $3}' "$eval_file") + recall=$(awk '/recall_100 / {print $3}' "$eval_file") + echo "${dataset},${prompt_hash},${ndcg},${recall}" >> "$output_file" + fi + done +done + +echo "BM25 aggregate results saved to ${output_file}" \ No newline at end of file diff --git a/scripts/beir/clear_all_directories.sh b/scripts/beir/clear_all_directories.sh new file mode 100644 index 0000000..7145841 --- /dev/null +++ b/scripts/beir/clear_all_directories.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# +for dataset in scifact fiqa hotpotqa fever nq dbpedia quora nfcorpus; do + bash scripts/beir/clear_directory.sh reproduced-v2 $dataset-dev +done + +# also redownload all +python scripts/beir/force_redownload_all.py \ No newline at end of file diff --git a/scripts/beir/clear_directory.sh b/scripts/beir/clear_directory.sh new file mode 100644 index 0000000..1b05fe6 --- /dev/null +++ b/scripts/beir/clear_directory.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +model_name=$1 +dataset=$2 + +rm $model_name/$dataset/rank.${dataset}* +rm $model_name/$dataset/${dataset}* \ No newline at end of file diff --git a/scripts/beir/encode_beir_corpus.sh b/scripts/beir/encode_beir_corpus.sh new file mode 100644 index 0000000..95bbc86 --- /dev/null +++ b/scripts/beir/encode_beir_corpus.sh @@ -0,0 +1,75 @@ +#!/bin/bash +path_to_save=$1 +model=$2 +dataset_name=$3 +base_model=$4 + +mkdir -p logs +mkdir -p logs/inference + +# if base model is empty use llama2 +if [ -z "$base_model" ]; then + base_model="meta-llama/Llama-2-7b-hf" +fi + +echo "Base model: $base_model" +mkdir -p $path_to_save +# ps aux | grep "[p]ython -m tevatron.retriever.driver.encode" | awk '{print $2}' | xargs kill +echo "Encoding BEIR corpus... $dataset_name" +# reverse these +# reverse it +for s in $(seq 7 -1 0); +do + # gpu id += 4 + # gpuid=$((s+4)) + gpuid=$s + # echo $gpuid + # if it's the last one (aka zero), don't run in background + if [ "$s" == "0" ]; then + # give it some time so that it's the last to run + sleep 60 + CUDA_VISIBLE_DEVICES=$gpuid python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path $base_model \ + --lora_name_or_path $model \ + --lora \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --per_device_eval_batch_size 32 \ + --query_max_len 512 \ + --passage_max_len 512 \ + --dataset_name "Tevatron/beir-corpus" \ + --dataset_config "$dataset_name" \ + --dataset_split "train" \ + --dataset_number_of_shards 8 \ + --dataset_shard_index ${s} \ + --encode_output_path $path_to_save/corpus_emb.${s}.pkl > logs/inference/encode_corpus_${dataset_name}_${s}.log 2>&1 + else + CUDA_VISIBLE_DEVICES=$gpuid python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path $base_model \ + --lora_name_or_path $model \ + --lora \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --per_device_eval_batch_size 32 \ + --query_max_len 512 \ + --passage_max_len 512 \ + --dataset_name "Tevatron/beir-corpus" \ + --dataset_config "$dataset_name" \ + --dataset_split "train" \ + --dataset_number_of_shards 8 \ + --dataset_shard_index ${s} \ + --encode_output_path $path_to_save/corpus_emb.${s}.pkl > logs/inference/encode_corpus_${dataset_name}_${s}.log 2>&1 & + fi +done + +# bash scripts/beir/encode_beir_corpus.sh reproduced-v2/scifact orionweller/repllama-reproduced-v2 scifact \ No newline at end of file diff --git a/scripts/beir/encode_beir_queries.sh b/scripts/beir/encode_beir_queries.sh new file mode 100644 index 0000000..43d9eaa --- /dev/null +++ b/scripts/beir/encode_beir_queries.sh @@ -0,0 +1,83 @@ +#!/bin/bash +base_model=$1 +encoded_save_path=$2 +model=$3 +dataset_name=$4 +gpu_num=$5 +prompt=$6 + +mkdir -p $encoded_save_path/$dataset_name + +echo "Base model: $base_model" +echo "Encoded save path: $encoded_save_path" +echo "Model: $model" +echo "Dataset name: $dataset_name" +echo "GPU number: $gpu_num" +echo "Prompt: $prompt" + +if [ -z "$prompt" ]; then + prompt_flag=() + final_output_path="$encoded_save_path/${dataset_name}_queries_emb.pkl" +else + prompt_flag=(--prompt "$prompt") + prompt_hash=$(echo -n "$prompt" | md5sum | awk '{print $1}') + echo "Prompt hash: $prompt_hash for prompt $prompt" + echo "Prompt flag: ${prompt_flag[*]}" + final_output_path="$encoded_save_path/${dataset_name}_queries_emb_${prompt_hash}.pkl" +fi + + +# if final_output_path exists, skip +if [ -f "$final_output_path" ]; then + echo "Skipping $dataset_name because of existing file $final_output_path" + exit 0 +fi + +echo "Saving to $final_output_path" + + +### if msmarco is in the name of the dataset, the new dataset name is after the - and use the tevatron msmarco-passages dataset +if [[ "$dataset_name" == *"msmarco"* ]]; then + dataset=$(echo $dataset_name | cut -d'-' -f2) + CUDA_VISIBLE_DEVICES=$gpu_num python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path $base_model \ + --lora_name_or_path "$model" \ + --lora \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --encode_is_query \ + --per_device_eval_batch_size 16 \ + --query_max_len 512 \ + --passage_max_len 512 \ + --dataset_name Tevatron/msmarco-passage \ + --dataset_split $dataset \ + --encode_output_path "$final_output_path" "${prompt_flag[@]}" + +else + + CUDA_VISIBLE_DEVICES=$gpu_num python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path $base_model \ + --lora_name_or_path "$model" \ + --lora \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --encode_is_query \ + --per_device_eval_batch_size 16 \ + --query_max_len 512 \ + --passage_max_len 512 \ + --dataset_name orionweller/beir \ + --dataset_config "$dataset_name" \ + --dataset_split test \ + --encode_output_path "$final_output_path" "${prompt_flag[@]}" + +fi \ No newline at end of file diff --git a/scripts/beir/force_redownload_all.py b/scripts/beir/force_redownload_all.py new file mode 100644 index 0000000..c9566f6 --- /dev/null +++ b/scripts/beir/force_redownload_all.py @@ -0,0 +1,5 @@ +from datasets import load_dataset +# for dataset in ['NFCorpus', 'FiQA', 'Quora', 'DBPedia-Entity', 'SciFact', 'HotpotQA', "NQ", "FEVER"]: +for dataset in ["NQ"]: + ds = load_dataset("orionweller/beir", f"{dataset.lower()}-dev", trust_remote_code=True, download_mode="force_redownload") + print(ds) \ No newline at end of file diff --git a/scripts/beir/matrix_of_prompts.sh b/scripts/beir/matrix_of_prompts.sh new file mode 100644 index 0000000..bc04ca8 --- /dev/null +++ b/scripts/beir/matrix_of_prompts.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# infinite loop +# while true; do + # bash scripts/beir/run_all_prompts.sh orionweller/repllama-instruct-hard-positives-v2-joint joint-full + # bash scripts/beir/run_all_prompts.sh orionweller/repllama-reproduced-v2 reproduced-v2 + # bash scripts/beir/run_all_prompts.sh orionweller/repllama-instruct-llama3.1 llama3.1 meta-llama/Meta-Llama-3.1-8B + bash scripts/beir/run_all_prompts.sh orionweller/repllama-instruct-llama3.1-instruct llama3.1-instruct meta-llama/Meta-Llama-3.1-8B-Instruct + # redo climate-fever llama3.1 + # bash scripts/beir/bm25/prompt_all_bm25.sh + + # now do the search + # export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + # bash scripts/beir/search_all_prompts.sh joint-full + # bash scripts/beir/search_all_prompts.sh reproduced-v2 + # bash scripts/beir/search_all_prompts.sh llama3.1 + bash scripts/beir/search_all_prompts.sh llama3.1-instruct + # sleep 600 + # bash scripts/beir/bm25/search_all_bm25.sh bm25 +# done \ No newline at end of file diff --git a/scripts/beir/run_all.sh b/scripts/beir/run_all.sh new file mode 100644 index 0000000..9127f8e --- /dev/null +++ b/scripts/beir/run_all.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# example usage: bash scripts/beir/run_all.sh orionweller/repllama-reproduced-v2 reproduced-v2 +# example usage: bash scripts/beir/run_all.sh orionweller/repllama-instruct-hard-positives-v2-joint joint-full +# example usage: bash scripts/beir/run_all.sh orionweller/repllama-instruct-mistral-v0.1 mistral-v1 mistralai/Mistral-7B-v0.1 + +# bash scripts/beir/run_all.sh orionweller/repllama-instruct-llama3.1-instruct llama3.1-instruct meta-llama/Meta-Llama-3.1-8B-Instruct + +# export CUDA_VISIBLE_DEVICES="4,5,6,7" +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +retriever_name=$1 +nickname=$2 +base_model=$3 + +echo "retriever_name: $retriever_name" +echo "nickname: $nickname" +echo "base_model: $base_model" + +mkdir -p $nickname + +datasets=( + 'fiqa' + 'nfcorpus' + 'scidocs' + 'scifact' + 'trec-covid' + 'webis-touche2020' + 'quora' + 'arguana' + 'hotpotqa' + 'fever' + 'climate-fever' + 'dbpedia-entity' + # 'nq' +) + +for dataset in "${datasets[@]}"; do + # if the dataset already exists (corpus_emb.0.pkl exists), skip it + if [ -d "$nickname/$dataset" ] && [ -f "$nickname/$dataset/corpus_emb.0.pkl" ]; then + echo "Skipping $dataset" + continue + fi + echo "Encoding corpus: $dataset" + bash scripts/beir/encode_beir_corpus.sh $nickname/$dataset $retriever_name $dataset $base_model +done diff --git a/scripts/beir/run_all_prompts.sh b/scripts/beir/run_all_prompts.sh new file mode 100644 index 0000000..697eab4 --- /dev/null +++ b/scripts/beir/run_all_prompts.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +# example usage: +# bash scripts/beir/run_all_prompts.sh orionweller/repllama-reproduced-v2 reproduced-v2 +# bash scripts/beir/run_all_prompts.sh orionweller/repllama-instruct-hard-positives-v2-joint joint-full + +# export CUDA_VISIBLE_DEVICES="0,1,2,3" + +retriever_name=$1 +nickname=$2 +base_model=$3 + +# if base model is empty, use meta-llama/Llama-2-7b-hf +if [ -z "$base_model" ]; then + base_model="meta-llama/Llama-2-7b-hf" +fi + +echo "Retriever name: $retriever_name" +echo "Nickname: $nickname" +echo "Base model: $base_model" + +mkdir -p $nickname + +datasets=( + 'fiqa' + 'nfcorpus' + 'scidocs' + 'scifact' + 'trec-covid' + 'webis-touche2020' + 'quora' + 'nq' + 'arguana' + 'hotpotqa' + 'fever' + 'climate-fever' + 'dbpedia-entity' + 'nfcorpus-dev' + # 'nq-dev' + 'scifact-dev' + 'fiqa-dev' + 'hotpotqa-dev' + 'dbpedia-entity-dev' + 'quora-dev' + 'fever-dev' +) + + +# Read in each line of the generic_prompts.csv file where each line is a prompt +# Run it on each dataset, hashing the prompt and passing that as the fourth argument +gpu_num=0 +gpu_max=7 +while IFS= read -r prompt +do + for dataset in "${datasets[@]}"; do + echo "Running prompt on dataset: $dataset" + echo "Prompt: '$prompt'" + # if the gpu_num is the max, don't run it in the background, otherwise run in the background + if [ $gpu_num -eq $gpu_max ]; then + bash scripts/beir/encode_beir_queries.sh $base_model "$nickname/$dataset" "$retriever_name" "$dataset" "$gpu_num" "$prompt" + # echo "Sleeping for 120 seconds..." + # sleep 10 + # echo "Done sleeping." + else + bash scripts/beir/encode_beir_queries.sh $base_model "$nickname/$dataset" "$retriever_name" "$dataset" "$gpu_num" "$prompt" & + fi + # update the GPU num looping if it hits the max + gpu_num=$((gpu_num+1)) + if [ $gpu_num -gt $gpu_max ]; then + gpu_num=0 + fi + done +done < generic_prompts.csv + + +# also run one without a prompt for each dataset +for dataset in "${datasets[@]}"; do + echo "Running without prompt on dataset: $dataset" + bash scripts/beir/encode_beir_queries.sh $base_model "$nickname/$dataset" "$retriever_name" "$dataset" "$gpu_num" + # update the GPU num looping if it hits the max + gpu_num=$((gpu_num+1)) + if [ $gpu_num -gt $gpu_max ]; then + gpu_num=0 + fi +done diff --git a/scripts/beir/search_all_prompts.sh b/scripts/beir/search_all_prompts.sh new file mode 100644 index 0000000..1260266 --- /dev/null +++ b/scripts/beir/search_all_prompts.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +# example: bash scripts/beir/search_all_prompts.sh reproduced-v2 +# example: bash scripts/beir/search_all_prompts.sh joint-full + +save_path=$1 + +datasets=( + 'fiqa' + 'nfcorpus' + 'scidocs' + 'scifact' + 'trec-covid' + 'webis-touche2020' + 'quora' + 'nq' + 'arguana' + 'hotpotqa' + 'fever' + 'climate-fever' + 'dbpedia-entity' + 'nfcorpus-dev' + # 'nq-dev' + 'scifact-dev' + 'fiqa-dev' + 'hotpotqa-dev' + 'dbpedia-entity-dev' + 'quora-dev' + 'fever-dev' +) + + +search_and_evaluate() { + local dataset_name=$1 + local query_emb_file=$2 + local output_suffix=$3 + + # if the final eval file exists and has a score, skip + if [[ -f "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval" ]]; then + # if there exists an ndcg_cut_10 in the file or a recip_rank, skip + if [[ $(grep -c "ndcg_cut_10" "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval") -gt 0 ]] || [[ $(grep -c "recip_rank" "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval") -gt 0 ]]; then + echo "Skipping ${dataset_name}${output_suffix} because of existing file ${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval" + return + fi + fi + + echo "Searching and evaluating ${dataset_name} with ${query_emb_file}..." + + python -m tevatron.retriever.driver.search \ + --query_reps "${query_emb_file}" \ + --passage_reps "${save_path}/${dataset_name}/corpus_emb.*.pkl" \ + --batch_size 64 \ + --depth 1000 \ + --save_text \ + --save_ranking_to "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.txt" + + # if the last command failed, exit + if [ $? -ne 0 ]; then + echo "Failed to search ${dataset_name}${output_suffix}" + exit 1 + fi + + echo "Ranking is saved at ${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.txt" + + python -m tevatron.utils.format.convert_result_to_trec \ + --input "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.txt" \ + --output "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.trec" \ + --remove_query + + # if msmarco is not in the name use beir + echo "Evaluating ${dataset_name}${output_suffix}..." + if [[ "$dataset_name" != *"msmarco"* ]] && [[ "$dataset_name" != *"-dev"* ]]; then + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 \ + "beir-v1.0.0-${dataset_name}-test" \ + "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.trec" \ + > "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval" + # else if -dev in the name + elif [[ "$dataset_name" == *"-dev"* ]] && [[ "$dataset_name" != *"msmarco"* ]]; then + # remove the -dev + new_dataset_name=$(echo $dataset_name | sed 's/-dev//') + # echo "NEw dataset name: $new_dataset_name" + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 \ + resources/downloaded/qrels/$new_dataset_name.qrels.sampled \ + "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.trec" \ + > "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval" + else + dataset=$(echo $dataset_name | cut -d'-' -f2) + if [ $dataset == "dev" ]; then + echo "Evaluating ${dataset}..." + echo "python -m pyserini.eval.trec_eval -c -M 100 -m recip_rank msmarco-passage-dev-subset $save_path//${dataset_name}/rank.${dataset_name}${output_suffix}.trec" + python -m pyserini.eval.trec_eval -c -M 100 -m recip_rank msmarco-passage-dev-subset $save_path/${dataset_name}/rank.${dataset_name}${output_suffix}.trec > $save_path/${dataset_name}/rank.${dataset_name}${output_suffix}.eval + else + pyserini_dataset="${dataset}-passage" + echo "Evaluating ${dataset}..." + echo "python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 $pyserini_dataset $save_path/${dataset_name}/rank.${dataset_name}${output_suffix}.trec" + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 $pyserini_dataset $save_path/${dataset_name}/rank.${dataset_name}${output_suffix}.trec > $save_path/${dataset_name}/rank.${dataset_name}${output_suffix}.eval + fi + fi + echo "Score is saved at ${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval" + cat "${save_path}/${dataset_name}/rank.${dataset_name}${output_suffix}.eval" + sleep 5 +} + +# Process all datasets +for dataset in "${datasets[@]}"; do + dataset_path="${save_path}/${dataset}" + + # Search without prompt + if [[ -f "${dataset_path}/${dataset}_queries_emb.pkl" ]]; then + search_and_evaluate "$dataset" "${dataset_path}/${dataset}_queries_emb.pkl" "" + fi + + # Search with generic prompts + for query_file in "${dataset_path}/${dataset}_queries_emb_"*.pkl; do + if [[ -f "$query_file" ]]; then + prompt_hash=$(basename "$query_file" | sed -n 's/.*_emb_\(.*\)\.pkl/\1/p') + search_and_evaluate "$dataset" "$query_file" "_${prompt_hash}" + fi + done +done +#!/bin/bash + +# Aggregate results +echo "Aggregating results..." +output_file="${save_path}/aggregate_results.csv" +echo "Dataset,Prompt,NDCG@10,Recall@100,MRR" > "$output_file" + +for dataset in "${datasets[@]}"; do + dataset_path="${save_path}/${dataset}" + + # Process results without prompt + eval_file="${dataset_path}/rank.${dataset}.eval" + if [[ -f "$eval_file" ]]; then + if [[ "$dataset" == "msmarco-dev" ]]; then + mrr=$(awk '/recip_rank / {print $3}' "$eval_file") + echo "${dataset},no_prompt,,,${mrr}" >> "$output_file" + else + ndcg=$(awk '/ndcg_cut_10 / {print $3}' "$eval_file") + recall=$(awk '/recall_100 / {print $3}' "$eval_file") + echo "${dataset},no_prompt,${ndcg},${recall}," >> "$output_file" + fi + fi + + # Process results with prompts + for eval_file in "${dataset_path}/rank.${dataset}_"*.eval; do + if [[ -f "$eval_file" ]]; then + prompt_hash=$(basename "$eval_file" | sed -n 's/.*_\(.*\)\.eval/\1/p') + if [[ "$dataset" == "msmarco-dev" ]]; then + mrr=$(awk '/recip_rank / {print $3}' "$eval_file") + echo "${dataset},${prompt_hash},,,${mrr}" >> "$output_file" + else + ndcg=$(awk '/ndcg_cut_10 / {print $3}' "$eval_file") + recall=$(awk '/recall_100 / {print $3}' "$eval_file") + echo "${dataset},${prompt_hash},${ndcg},${recall}," >> "$output_file" + fi + fi + done +done + +echo "Aggregate results saved to ${output_file}" \ No newline at end of file diff --git a/scripts/beir/search_beir.sh b/scripts/beir/search_beir.sh new file mode 100644 index 0000000..6cbe962 --- /dev/null +++ b/scripts/beir/search_beir.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +save_path=$1 +dataset_name=$2 + +python -m tevatron.retriever.driver.search \ +--query_reps $save_path/${dataset_name}_queries_emb.pkl \ +--passage_reps "$save_path/"'corpus_emb.*.pkl' \ +--batch_size 1024 \ +--depth 1000 \ +--save_text \ +--save_ranking_to $save_path/rank.${dataset_name}.txt + +python -m tevatron.utils.format.convert_result_to_trec --input $save_path/rank.${dataset_name}.txt \ + --output $save_path/rank.${dataset_name}.trec \ + --remove_query + + +echo "Evaluating ${dataset_name}..." +echo "python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-${dataset_name}-test $save_path/rank.${dataset_name}.trec" +python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 beir-v1.0.0-${dataset_name}-test $save_path/rank.${dataset_name}.trec > $save_path/rank.${dataset_name}.eval +echo "Score is saved at $save_path/rank.${dataset_name}.eval" +cat $save_path/rank.${dataset_name}.eval diff --git a/scripts/error_analysis/error_analysis.py b/scripts/error_analysis/error_analysis.py new file mode 100644 index 0000000..48e4b93 --- /dev/null +++ b/scripts/error_analysis/error_analysis.py @@ -0,0 +1,212 @@ +import argparse +import json +import os +from collections import defaultdict +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pytrec_eval +import seaborn as sns +from tqdm import tqdm +import ir_datasets +from scipy import stats as scipy_stats + +import nltk +from nltk.tokenize import word_tokenize +from nltk.corpus import stopwords +from collections import Counter +from sklearn.feature_extraction.text import TfidfVectorizer + +# Download necessary NLTK data +# nltk.download('punkt') +# nltk.download('averaged_perceptron_tagger') +# nltk.download('stopwords') + +def load_run(file_path: str) -> Dict[str, Dict[str, float]]: + run = defaultdict(dict) + with open(file_path, 'r') as f: + for line in f: + query_id, _, doc_id, rank, score, _ = line.strip().split() + run[query_id][doc_id] = float(score) + return run + +def load_dataset(dataset_name: str) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]: + try: + dataset = ir_datasets.load(f'beir/{dataset_name}/test') + except Exception: + dataset = ir_datasets.load(f'beir/{dataset_name}') + + qrels = defaultdict(dict) + queries = {} + + for query in dataset.queries_iter(): + queries[query.query_id] = query.text + + for qrel in dataset.qrels_iter(): + qrels[qrel.query_id][qrel.doc_id] = qrel.relevance + + return qrels, queries + +def evaluate_runs(run1: Dict[str, Dict[str, float]], run2: Dict[str, Dict[str, float]], + qrels: Dict[str, Dict[str, int]]) -> Tuple[Dict[str, float], Dict[str, float]]: + evaluator = pytrec_eval.RelevanceEvaluator(qrels, {'ndcg_cut.10'}) + + results1 = evaluator.evaluate(run1) + results2 = evaluator.evaluate(run2) + + return results1, results2 + +def compare_runs(results1: Dict[str, Dict[str, float]], results2: Dict[str, Dict[str, float]]) -> List[Dict]: + comparison = [] + for query_id in results1.keys(): + score1 = results1[query_id]['ndcg_cut_10'] + score2 = results2[query_id]['ndcg_cut_10'] + diff = score2 - score1 + if diff == 0: + label = 'tie' + continue + else: + label = 'run2' if diff > 0 else 'run1' + comparison.append({ + 'query_id': query_id, + 'diff': diff, + 'label': label + }) + + # sort comparison by label + comparison.sort(key=lambda x: x['label']) + return comparison + +def save_jsonl(data: List[Dict], output_file: str): + with open(output_file, 'w') as f: + for item in data: + json.dump(item, f) + f.write('\n') + + +def plot_statistics(stats: Dict[str, Dict[str, List[float]]], output_dir: str): + os.makedirs(output_dir, exist_ok=True) + + for stat_name, stat_data in stats.items(): + plt.figure(figsize=(10, 6)) + sns.boxplot(data=[stat_data['run1'], stat_data['run2']]) + plt.title(f'{stat_name.replace("_", " ").title()} Distribution') + plt.xticks([0, 1], ['Run 1', 'Run 2']) + plt.ylabel(stat_name.replace('_', ' ').title()) + plt.savefig(os.path.join(output_dir, f'{stat_name}_boxplot.png')) + plt.close() + + # Add histogram for this statistic + plt.figure(figsize=(10, 6)) + sns.histplot(data=stat_data, kde=True) + plt.title(f'{stat_name.replace("_", " ").title()} Distribution') + plt.xlabel(stat_name.replace('_', ' ').title()) + plt.ylabel('Frequency') + plt.savefig(os.path.join(output_dir, f'{stat_name}_histogram.png')) + plt.close() + + + +def compute_statistics(queries: Dict[str, str], comparison: List[Dict], results1: Dict[str, Dict[str, float]], results2: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, List[float]]]: + stats = defaultdict(lambda: defaultdict(list)) + stop_words = set(stopwords.words('english')) + + for item in comparison: + query = queries[item['query_id']] + label = item['label'] + + # Existing statistics + stats['length'][label].append(len(query)) + stats['question_marks'][label].append(query.count('?')) + stats['exclamation_marks'][label].append(query.count('!')) + stats['commas'][label].append(query.count(',')) + stats['word_count'][label].append(len(query.split())) + + # New statistics + words = word_tokenize(query.lower()) + words_no_stop = [w for w in words if w not in stop_words] + + # Average word length + stats['avg_word_length'][label].append(np.mean([len(w) for w in words])) + + # Unique words + stats['unique_words'][label].append(len(set(words))) + + # Part-of-speech distribution + pos_tags = nltk.pos_tag(words) + pos_counts = Counter(tag for word, tag in pos_tags) + stats['noun_count'][label].append(pos_counts.get('NN', 0) + pos_counts.get('NNS', 0)) + stats['verb_count'][label].append(pos_counts.get('VB', 0) + pos_counts.get('VBD', 0) + pos_counts.get('VBG', 0)) + stats['adj_count'][label].append(pos_counts.get('JJ', 0)) + + # Performance-based metrics + stats['abs_diff'][label].append(abs(item['diff'])) + stats['rel_improvement'][label].append((item['diff'] / results1[item['query_id']]['ndcg_cut_10']) * 100 if results1[item['query_id']]['ndcg_cut_10'] != 0 else 0) + + # Query difficulty estimation (across all queries) + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform([queries[item['query_id']] for item in comparison]) + query_idf = np.array(tfidf_matrix.sum(axis=0)).flatten() + + for i, item in enumerate(comparison): + label = item['label'] + stats['avg_idf'][label].append(np.mean(query_idf[i])) + + return stats + +def main(args): + # make the output directories and plots + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(os.path.join(args.output_dir, 'plots'), exist_ok=True) + + # Load data + run1 = load_run(args.run1) + run2 = load_run(args.run2) + qrels, queries = load_dataset(args.dataset_name) + + # Evaluate runs + results1, results2 = evaluate_runs(run1, run2, qrels) + + # Compare runs + comparison = compare_runs(results1, results2) + + # Add query text to comparison + for item in comparison: + item['query'] = queries[item['query_id']] + + # Save comparison to JSONL file + save_jsonl(comparison, os.path.join(args.output_dir, 'comparison.jsonl')) + + # Compute statistics + stats = compute_statistics(queries, comparison, results1, results2) + + # Plot statistics + plot_statistics(stats, os.path.join(args.output_dir, 'plots')) + + # Print summary statistics + print("Summary Statistics:") + for stat_name, stat_data in stats.items(): + print(f"\n{stat_name.replace('_', ' ').title()}:") + for run, values in stat_data.items(): + print(f" {run}: Mean = {np.mean(values):.2f}, Median = {np.median(values):.2f}, " + f"Min = {np.min(values):.2f}, Max = {np.max(values):.2f}") + + # Perform statistical significance tests + for stat_name, stat_data in stats.items(): + if 'run1' in stat_data and 'run2' in stat_data: + t_stat, p_value = scipy_stats.ttest_ind(stat_data['run1'], stat_data['run2']) + print(f"\n{stat_name.replace('_', ' ').title()} - T-test results:") + print(f" T-statistic: {t_stat:.4f}") + print(f" P-value: {p_value:.4f}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Compare two IR model run files") + parser.add_argument("run1", help="Path to the first run file") + parser.add_argument("run2", help="Path to the second run file") + parser.add_argument("dataset_name", help="Name of the dataset") + parser.add_argument("output_dir", help="Path to the output folder") + args = parser.parse_args() + + main(args) + # example usage: python scripts/error_analysis.py joint-full/scifact/rank.scifact.trec reproduced-v2/scifact/rank.scifact.trec scifact error_analysis/ \ No newline at end of file diff --git a/scripts/error_analysis/error_analysis_bow.py b/scripts/error_analysis/error_analysis_bow.py new file mode 100644 index 0000000..a22bc73 --- /dev/null +++ b/scripts/error_analysis/error_analysis_bow.py @@ -0,0 +1,145 @@ +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.naive_bayes import MultinomialNB +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, confusion_matrix, classification_report +import seaborn as sns + +# Load and preprocess the data +def load_data(file_path): + data = [] + with open(file_path, 'r') as f: + for line in f: + item = json.loads(line) + data.append({ + 'text': item['query'], + 'label': 0 if item['label'] == 'run1' else 1 + }) + return data + +data = load_data('error_analysis/comparison.jsonl') + +# Split the data into train and test sets +train_data, test_data = train_test_split(data, test_size=0.2, random_state=42) + +# Prepare the data +X_train = [item['text'] for item in train_data] +y_train = [item['label'] for item in train_data] +X_test = [item['text'] for item in test_data] +y_test = [item['label'] for item in test_data] + +# Create and train the BoW classifier +vectorizer = CountVectorizer() +X_train_vec = vectorizer.fit_transform(X_train) +X_test_vec = vectorizer.transform(X_test) + +bow_classifier = MultinomialNB() +bow_classifier.fit(X_train_vec, y_train) + +# Evaluate the classifier +y_pred = bow_classifier.predict(X_test_vec) +accuracy = accuracy_score(y_test, y_pred) +print(f"Bag of Words Classifier Accuracy: {accuracy:.4f}") + +# Print classification report +print("\nClassification Report:") +print(classification_report(y_test, y_pred, target_names=['run1', 'run2'])) + +# Confusion Matrix +cm = confusion_matrix(y_test, y_pred) +plt.figure(figsize=(8, 6)) +sns.heatmap(cm, annot=True, fmt='d', cmap='Blues') +plt.title('Confusion Matrix') +plt.ylabel('True Label') +plt.xlabel('Predicted Label') +plt.savefig('error_analysis/bow_confusion_matrix.png') +plt.close() + +# Feature Importance +feature_importance = bow_classifier.feature_log_prob_[1] - bow_classifier.feature_log_prob_[0] +feature_names = vectorizer.get_feature_names_out() + +# Top N most important features +N = 20 +top_features = sorted(zip(feature_names, feature_importance), key=lambda x: abs(x[1]), reverse=True)[:N] + +plt.figure(figsize=(12, 8)) +plt.barh([f[0] for f in top_features], [f[1] for f in top_features]) +plt.title(f'Top {N} Most Important Features') +plt.xlabel('Log Probability Difference') +plt.ylabel('Features') +plt.savefig('error_analysis/bow_feature_importance.png') +plt.close() + +# Sample Explanation +def explain_prediction(classifier, vectorizer, text): + x = vectorizer.transform([text]) + proba = classifier.predict_proba(x)[0] + feature_names = vectorizer.get_feature_names_out() + feature_values = x.toarray()[0] + + feature_importance = classifier.feature_log_prob_[1] - classifier.feature_log_prob_[0] + sorted_idx = feature_importance.argsort() + + top_positive = [(feature_names[i], feature_importance[i]) for i in sorted_idx[-10:] if feature_values[i] > 0] + top_negative = [(feature_names[i], feature_importance[i]) for i in sorted_idx[:10] if feature_values[i] > 0] + + return proba, top_positive, top_negative + +sample_idx = 0 +sample_text = X_test[sample_idx] +true_label = y_test[sample_idx] +pred_label = y_pred[sample_idx] + +print(f"\nSample Text: {sample_text}") +print(f"True Label: {'run1' if true_label == 0 else 'run2'}") +print(f"Predicted Label: {'run1' if pred_label == 0 else 'run2'}") + +proba, top_positive, top_negative = explain_prediction(bow_classifier, vectorizer, sample_text) +print(f"Probability: run1 - {proba[0]:.4f}, run2 - {proba[1]:.4f}") +print("\nTop positive features:") +for feature, importance in top_positive: + print(f"{feature}: {importance:.4f}") +print("\nTop negative features:") +for feature, importance in top_negative: + print(f"{feature}: {importance:.4f}") + +# Analysis of misclassifications +misclassified = [(X_test[i], y_test[i], y_pred[i]) for i in range(len(y_test)) if y_test[i] != y_pred[i]] + +print("\nSample Misclassifications:") +for text, true_label, pred_label in misclassified[:5]: + print(f"Text: {text}") + print(f"True Label: {'run1' if true_label == 0 else 'run2'}") + print(f"Predicted Label: {'run1' if pred_label == 0 else 'run2'}") + print("---") + +# Length analysis +train_lengths = [len(text.split()) for text in X_train] +test_lengths = [len(text.split()) for text in X_test] + +plt.figure(figsize=(10, 5)) +plt.hist(train_lengths, bins=20, alpha=0.5, label='Train') +plt.hist(test_lengths, bins=20, alpha=0.5, label='Test') +plt.title('Distribution of Text Lengths') +plt.xlabel('Number of Words') +plt.ylabel('Frequency') +plt.legend() +plt.savefig('error_analysis/text_lengths.png') +plt.close() + +# Class distribution +class_distribution = pd.Series(y_train + y_test).value_counts() +plt.figure(figsize=(8, 6)) +class_distribution.plot(kind='bar') +plt.title('Class Distribution') +plt.xlabel('Class') +plt.ylabel('Count') +plt.xticks([0, 1], ['run1', 'run2']) +plt.savefig('error_analysis/class_distribution.png') +plt.close() + +print("\nExplainability Analysis Complete") \ No newline at end of file diff --git a/scripts/error_analysis/error_analysis_modeling.py b/scripts/error_analysis/error_analysis_modeling.py new file mode 100644 index 0000000..dff326c --- /dev/null +++ b/scripts/error_analysis/error_analysis_modeling.py @@ -0,0 +1,100 @@ +import json +import numpy as np +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.naive_bayes import MultinomialNB +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split +from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments +from datasets import Dataset + +# Load and preprocess the data +def load_data(file_path): + data = [] + with open(file_path, 'r') as f: + for line in f: + item = json.loads(line) + data.append({ + 'text': item['query'], + 'label': 0 if item['label'] == 'run1' else 1 + }) + return data + +data = load_data('error_analysis/comparison.jsonl') + +# Split the data into train and test sets +train_data, test_data = train_test_split(data, test_size=0.2, random_state=42) + +# Create Hugging Face Datasets +train_dataset = Dataset.from_dict({ + 'text': [item['text'] for item in train_data], + 'label': [item['label'] for item in train_data] +}) + +test_dataset = Dataset.from_dict({ + 'text': [item['text'] for item in test_data], + 'label': [item['label'] for item in test_data] +}) + +# BERT Model Training +model_name = 'bert-base-uncased' +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) + +def tokenize_function(examples): + return tokenizer(examples['text'], padding='max_length', truncation=True) + +tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True) +tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True) + +training_args = TrainingArguments( + output_dir='./results', + num_train_epochs=3, + per_device_train_batch_size=8, + per_device_eval_batch_size=8, + warmup_steps=500, + weight_decay=0.01, + logging_dir='./logs', + evaluation_strategy="epoch", + fp16=True, +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_train_dataset, + eval_dataset=tokenized_test_dataset, +) + +# Train the model +trainer.train() + +# Evaluate the BERT model +bert_eval_results = trainer.evaluate() +print(f"BERT Model Evaluation Results: {bert_eval_results}") + +# Calculate accuracy manually +predictions = trainer.predict(tokenized_test_dataset) +bert_predictions = np.argmax(predictions.predictions, axis=1) +bert_accuracy = accuracy_score(test_dataset['label'], bert_predictions) +print(f"BERT Model Accuracy on Test Set: {bert_accuracy}") + +# Bag of Words Classifier +vectorizer = CountVectorizer() +X_train = vectorizer.fit_transform([item['text'] for item in train_data]) +y_train = [item['label'] for item in train_data] + +X_test = vectorizer.transform([item['text'] for item in test_data]) +y_test = [item['label'] for item in test_data] + +bow_classifier = MultinomialNB() +bow_classifier.fit(X_train, y_train) + +bow_predictions = bow_classifier.predict(X_test) +bow_accuracy = accuracy_score(y_test, bow_predictions) +print(f"Bag of Words Classifier Accuracy on Test Set: {bow_accuracy}") + +# majority class baseline +majority_class = max(set(y_test), key=y_test.count) +majority_class_predictions = [majority_class] * len(y_test) +majority_class_accuracy = accuracy_score(y_test, majority_class_predictions) +print(f"Majority Class Baseline Accuracy on Test Set: {majority_class_accuracy}") \ No newline at end of file diff --git a/scripts/filtering/filter_query_doc_pairs_from_batch_gpt.py b/scripts/filtering/filter_query_doc_pairs_from_batch_gpt.py new file mode 100644 index 0000000..f38af81 --- /dev/null +++ b/scripts/filtering/filter_query_doc_pairs_from_batch_gpt.py @@ -0,0 +1,136 @@ +import argparse +import os +import json +import pandas as pd +import tqdm +from datasets import load_dataset + + +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer +import os +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, +) +import torch +from peft import PeftModel, PeftConfig + + +TEMPLATE = """ [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices. + +Query: {query} +Document: {text} +Relevant (only output one word, either "true" or "false"): [/INST] """ + + + +def load_followir(model_name: str = "jhu-clsp/FollowIR-7B"): + print(f"Loading model {model_name}") + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16 + ).cuda() + tokenizer = AutoTokenizer.from_pretrained( + model_name, padding_side="left", fast=True + ) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + token_false_id = tokenizer.get_vocab()["false"] + token_true_id = tokenizer.get_vocab()["true"] + tokenizer.model_max_length = 768 + model.config.max_length = 768 + return model, tokenizer, token_true_id, token_false_id + + + + +def rank_batch_followir(queries, passages, tokenizer, model, true_token_id, false_token_id): + assert len(queries) == len(passages) + + prompts = [ + TEMPLATE.format(query=query, text=text) for (query, text) in zip(queries, passages) + ] + inputs = tokenizer( + prompts, + padding=True, + truncation=True, + return_tensors="pt", + pad_to_multiple_of=None, + ) + + model = model.cuda() + with torch.no_grad(): + inputs = {k: v.cuda() for k, v in inputs.items()} + # calculate the scores by comparing true and false tokens + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, true_token_id] + false_vector = batch_scores[:, false_token_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + + +def get_doc(doc_dict: dict) -> str: + return (doc_dict.get("title", "") + " " + doc_dict.get("text", "")).strip() + + +def filter_query_doc_pairs(args): + print(f"Opening output file {args.output_file}...") + out_file = open(args.output_file, "w") + + # load batch output file + print(f"Loading batch output file {args.batch_input}...") + input_data = [] + with open(args.batch_input, "r") as f: + for line in f: + input_data.append(json.loads(line)) + + + print(f"Loading model...") + model, tokenizer, true_token_id, false_token_id = load_followir() + + print(f"Scoring...") + batch_size = args.batch_size + num_batches = len(input_data) // batch_size + 1 + for idx, batch in tqdm.tqdm(enumerate(range(0, len(input_data), batch_size)), total=num_batches): + if args.debug and j > 10: + break + + # get the batch + batch_data = input_data[batch:batch+batch_size] + batch_queries = [d["query"] + " " + d["instruction"] for d in batch_data] + passages = [get_doc(d["passage"]) for d in batch_data] + doc_ids = [d["joint_id"] for d in batch_data] + scores = rank_batch_followir(batch_queries, passages, tokenizer, model, true_token_id, false_token_id) + + # cache each one by appending to the output_file + for i, (doc_id, score) in enumerate(zip(doc_ids, scores)): + out_file.write(f"{doc_id}\t{score:.3f}\n") + # flush it + out_file.flush() + + out_file.close() + + # read it and print stats + df = pd.read_csv(args.output_file, sep="\t", names=["doc_id", "score"], index_col=None, header=None) + print(f"Output file has {len(df)} rows") + df["group"] = df["doc_id"].apply(lambda x: x.split("_")[-1] if "_" in x else "real") + df["pred_label"] = df["score"].apply(lambda x: 1 if x > 0.5 else 0) + # print value counts for each one of the pred_label grouped by group + print(df.groupby(["group", "pred_label"]).size()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--batch_input", type=str, required=True) + parser.add_argument("-b", "--batch_size", type=int, default=8) + parser.add_argument("-o", "--output_file", type=str, required=True) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + filter_query_doc_pairs(args) + + # example usage + # python scripts/filter_query_doc_pairs_from_batch_gpt -i batch_outputs/batch_instances_Y57xfvrFKYSyxp0SSXIaJXUa.jsonl -o batch_outputs/followir_batch_scores_Y57xfvrFKYSyxp0SSXIaJXUa.tsv + \ No newline at end of file diff --git a/scripts/msmarco/encode_corpus.sh b/scripts/msmarco/encode_corpus.sh new file mode 100644 index 0000000..729b240 --- /dev/null +++ b/scripts/msmarco/encode_corpus.sh @@ -0,0 +1,41 @@ +#!/bin/bash +path_to_save=$1 +model=$2 +model_type=$3 + +# if model type is empty then use llama2 +if [ -z "$model_type" ] +then + model_type="meta-llama/Llama-2-7b-hf" +fi + +echo "Path to save: $path_to_save" +echo "Model and Model Type: $model $model_type" + +mkdir -p $path_to_save +for s in $(seq -f "%01g" 0 7) +do +# add 4 to the gpu_id +# gpuid=$((s+4)) +gpuid=$s +echo $gpuid +CUDA_VISIBLE_DEVICES=$gpuid python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path $model_type \ + --lora_name_or_path $model \ + --lora \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --per_device_eval_batch_size 128 \ + --query_max_len 32 \ + --passage_max_len 156 \ + --dataset_name Tevatron/msmarco-passage-corpus \ + --dataset_number_of_shards 8 \ + --dataset_shard_index ${s} \ + --encode_output_path $path_to_save/corpus_emb.${s}.pkl > logs/msmarco_encode_corpus_$s.log 2>&1 & +done + diff --git a/scripts/msmarco/encode_queries.sh b/scripts/msmarco/encode_queries.sh new file mode 100644 index 0000000..5baa234 --- /dev/null +++ b/scripts/msmarco/encode_queries.sh @@ -0,0 +1,36 @@ +#!/bin/bash +encoded_save_path=$1 +model=$2 +model_type=$3 + +# if model type is empty then use llama2 +if [ -z "$model_type" ] +then + model_type="meta-llama/Llama-2-7b-hf" +fi + +echo "Path to save: $encoded_save_path" +echo "Model and Model Type: $model $model_type" + +for dataset in dl19 dl20 dev; do +echo $dataset +CUDA_VISIBLE_DEVICES=0 python -m tevatron.retriever.driver.encode \ + --output_dir=temp \ + --model_name_or_path $model_type \ + --lora_name_or_path $model \ + --lora \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --encode_is_query \ + --per_device_eval_batch_size 128 \ + --query_max_len 32 \ + --passage_max_len 156 \ + --dataset_name Tevatron/msmarco-passage \ + --dataset_split $dataset \ + --encode_output_path $encoded_save_path/${dataset}_queries_emb.pkl +done + diff --git a/scripts/msmarco/search.sh b/scripts/msmarco/search.sh new file mode 100644 index 0000000..5621ac5 --- /dev/null +++ b/scripts/msmarco/search.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# sleep 14400 && bash scripts/search.sh retriever-llama2 +save_path=$1 +for dataset in dl19 dl20 dev; do # dev + + python -m tevatron.retriever.driver.search \ + --query_reps $save_path/${dataset}_queries_emb.pkl \ + --passage_reps "$save_path/"'corpus_emb.*.pkl' \ + --depth 1000 \ + --batch_size 64 \ + --save_text \ + --save_ranking_to $save_path/rank.${dataset}.txt + + python -m tevatron.utils.format.convert_result_to_trec --input $save_path/rank.${dataset}.txt \ + --output $save_path/rank.${dataset}.trec \ + --remove_query + + # if dataset is dev use msmarco-passage-dev-subset + if [ $dataset == "dev" ]; then + echo "Evaluating ${dataset}..." + echo "python -m pyserini.eval.trec_eval -c -M 100 -m recip_rank msmarco-passage-dev-subset $save_path/rank.${dataset}.trec" + python -m pyserini.eval.trec_eval -c -M 100 -m recip_rank msmarco-passage-dev-subset $save_path/rank.dev.trec > $save_path/rank.dev.eval + else + pyserini_dataset="${dataset}-passage" + echo "Evaluating ${dataset}..." + echo "python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 $pyserini_dataset $save_path/rank.${dataset}.trec" + + python -m pyserini.eval.trec_eval -c -mrecall.100 -mndcg_cut.10 $pyserini_dataset $save_path/rank.${dataset}.trec > $save_path/rank.${dataset}.eval + fi + + +done diff --git a/scripts/plotting/gather_results.py b/scripts/plotting/gather_results.py new file mode 100644 index 0000000..4ac4445 --- /dev/null +++ b/scripts/plotting/gather_results.py @@ -0,0 +1,86 @@ +import os +import csv +import re + +def extract_scores(file_path): + with open(file_path, 'r') as file: + content = file.read() + scores = {} + + patterns = { + 'recall@100': r'recall_100\s+all\s+([\d.]+)', + 'ndcg@10': r'ndcg_cut_10\s+all\s+([\d.]+)', + 'mrr': r'recip_rank\s+all\s+([\d.]+)' + } + + for metric, pattern in patterns.items(): + match = re.search(pattern, content) + if match: + scores[metric] = float(match.group(1)) * 100 # Convert to percentage + + return scores + +def extract_hash_from_filename(filename): + # Split the filename by underscore and take the last part before .eval + parts = filename.split('_') + if len(parts) > 1: + return parts[-1].split('.')[0] + return 'none' + +def process_directory(directory): + results = [] + for root, _, files in os.walk(directory): + for file in files: + if file.endswith('.eval'): + file_path = os.path.join(root, file) + dataset_name = os.path.basename(file).split("_")[0].replace("rank.", "").replace(".eval", "") + + prompt_hash = extract_hash_from_filename(file) + + scores = extract_scores(file_path) + + result = { + 'dataset': dataset_name, + 'prompt_hash': prompt_hash, + 'filename': file, + **scores + } + + results.append(result) + + return results + +def write_to_csv(results, output_file): + if not results: + print(f"No results found for {output_file}. Skipping CSV creation.") + return + + fieldnames = set(['dataset', 'prompt_hash']) + for result in results: + fieldnames.update(result.keys()) + + fieldnames = sorted(list(fieldnames)) + + with open(output_file, 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for row in results: + writer.writerow(row) + +def main(): + directories = ['joint-full', 'bm25', "reproduced-v2"] + results_folder = 'results' + + # Create results folder if it doesn't exist + if not os.path.exists(results_folder): + os.makedirs(results_folder) + + for directory in directories: + results = process_directory(directory) + output_file = os.path.join(results_folder, f'{os.path.basename(directory)}_results.csv') + write_to_csv(results, output_file) + print(f"Results for {directory} written to {output_file}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/plotting/get_sd_table.py b/scripts/plotting/get_sd_table.py new file mode 100644 index 0000000..015f601 --- /dev/null +++ b/scripts/plotting/get_sd_table.py @@ -0,0 +1,119 @@ +import pandas as pd +import numpy as np +import glob +from collections import defaultdict + +SKIP_OLD_HASHES = [ + "0ab0de14665a035b4ce74ea58f0aeb0b", # + # "d2b1fa425e0198eb5ba2f9ceaa946389", + # "bc8581d1f8b9b223247df82aa13707fc", + # "b09133128f72179896830b2f10a6fa9e", + "11c51cdccc21293fad66b37e75bbdc94", + # "eeee229082555a0f22c493370c12651e", + "476c48e5591c52d8000c65bc88421652" # remove it, match key phrases +] + + +PRETTY_NAMES = { + "arguana": "Arguana", + "climate-fever": "Climate-FEVER", + "dbpedia-entity": "DBPedia", + "fever": "FEVER", + "fiqa": "FiQA", + "hotpotqa": "HotpotQA", + "nfcorpus": "NFCorpus", + "nq": "NQ", + "quora": "Quora", + "scidocs": "SCIDOCS", + "scifact": "SciFact", + "trec-covid": "TREC-COVID", + "webis-touche2020": "Touche-2020" +} + +def process_file(file_path): + # Read the CSV file + df = pd.read_csv(file_path) + + # Filter out MSMarco and -dev datasets + df = df[~df['dataset'].str.contains('msmarco|dev', case=False, na=False)] + + # remove old hashes + df = df[~df['prompt_hash'].isin(SKIP_OLD_HASHES)] + + # Group by dataset and calculate standard deviation of ndcg@10 + std_devs = df.groupby('dataset')['ndcg@10'].std() + + # # print the max score;s prompt hash of each dataset, excluding none + # for dataset in df['dataset'].unique(): + # # get the scores and prompt hashes for this dataset + # cur_set = df[df['dataset'] == dataset] + # # find the max score index + # max_score_index = cur_set['ndcg@10'].idxmax() + # # get the prompt hash for the max score + # max_prompt_hash = cur_set.loc[max_score_index, 'prompt_hash'] + # print(f"Max prompt hash for {dataset}: {max_prompt_hash}") + + # Calculate average standard deviation + avg_std_dev = std_devs.mean() + + return std_devs, avg_std_dev + +def create_latex_table(results_bm25, results_reproduced, results_joint): + latex_table = """ +\\begin{table}[h] +\\centering +\\begin{tabular}{lcc} +\\hline +Dataset & BM25 SD & Joint-Full SD \\\\ +\\hline +""" + + # Combine and sort datasets by average standard deviation in descending order + all_datasets = set(results_bm25.keys()) | set(results_joint.keys()) | set(results_reproduced.keys()) + all_datasets.remove('Average') + # sort the datasets by name + sorted_datasets = sorted(all_datasets) + + for dataset in sorted_datasets: + bm25_sd = results_bm25.get(dataset, 0) + joint_sd = results_joint.get(dataset, 0) + reproduced_sd = results_reproduced.get(dataset, 0) + pretty_name = PRETTY_NAMES.get(dataset, dataset) + latex_table += f"{pretty_name} & {bm25_sd:.1f} & {reproduced_sd:.1f} & {joint_sd:.1f} \\\\\n" + + latex_table += "\\hline\n" + latex_table += f"Average & {results_bm25['Average']:.1f} & {results_reproduced['Average']:.1f} & {results_joint['Average']:.1f} \\\\\n" + latex_table += "\\hline\n" + latex_table += "\\end{tabular}\n" + latex_table += "\\caption{Standard Deviations of NDCG@10 Scores Across Prompt Hashes}\n" + latex_table += "\\label{tab:std_devs}\n" + latex_table += "\\end{table}" + + return latex_table + +def main(file_paths): + results = {} + + print(file_paths) + + for file_path in file_paths: + std_devs, avg_std_dev = process_file(file_path) + if "bm25" in file_path.lower(): + model_name = 'BM25' + elif "reproduced-v2" in file_path.lower(): + model_name = 'Reproduced-v2' + elif "joint-full" in file_path.lower(): + model_name = 'Joint-Full' + results[model_name] = dict(std_devs) + results[model_name]['Average'] = avg_std_dev + + # Create LaTeX table + latex_table = create_latex_table(results['BM25'], results["Reproduced-v2"], results['Joint-Full']) + + print(latex_table) + +# Usage +# read in file paths from results/*.csv +file_paths = list(glob.glob("results/*.csv")) +print(file_paths) +main(file_paths) \ No newline at end of file diff --git a/scripts/plotting/make_prompt_all_table.py b/scripts/plotting/make_prompt_all_table.py new file mode 100644 index 0000000..742c62c --- /dev/null +++ b/scripts/plotting/make_prompt_all_table.py @@ -0,0 +1,73 @@ +import csv +from collections import defaultdict +import pandas as pd + +# Read the CSV data +data = defaultdict(lambda: defaultdict(float)) +datasets = set() +prompts = set() + +PRETTY_NAMES = { + "arguana": "ARG", + "climate-fever": "CFV", + "dbpedia-entity": "DBP", + "fever": "FEV", + "fiqa": "FQA", + "hotpotqa": "HQA", + "nfcorpus": "NFC", + "nq": "NQ", + "quora": "QUO", + "scidocs": "SCD", + "scifact": "SCF", + "trec-covid": "COV", + "webis-touche2020": "TOU" +} + +generic_prompts = pd.read_csv("results/generic_prompts.csv_hashes.csv", index_col=None) +generic_prompt_hashes = generic_prompts['prompt_hash'].to_list() +generic_hash_to_text = dict(zip(generic_prompts['prompt_hash'], generic_prompts['prompt'])) + +with open('results/joint-full_results.csv', 'r') as f: + reader = csv.DictReader(f) + for row in reader: + dataset = row['dataset'] + + # Skip msmarco datasets and those with "-dev" in the name + if 'msmarco' in dataset or '-dev' in dataset: + continue + + filename = row['filename'] + prompt = filename.split('_')[-1].split('.')[0] if '_' in filename else 'none' + if prompt not in generic_prompt_hashes: + continue + + # Use ndcg@10 if available, otherwise use mrr + score = float(row['ndcg@10']) if row['ndcg@10'] else float(row['mrr']) + + data[prompt][dataset] = score + datasets.add(dataset) + prompts.add(prompt) + +# Sort datasets and prompts +sorted_datasets = sorted(datasets) +sorted_prompts = sorted(prompts) +# sorted_prompts = [generic_hash_to_text[hash] for hash in sorted_prompts] + +# Generate LaTeX table +latex_table = "\\begin{table}[h]\n\\centering\n\\begin{tabular}{l" + "c" * len(sorted_datasets) + "}\n" +latex_table += "\\hline\nPrompt & " + " & ".join([PRETTY_NAMES[dataset] for dataset in sorted_datasets]) + " \\\\\n\\hline\n" + +for prompt in sorted_prompts: + row = [generic_hash_to_text[prompt]] + for dataset in sorted_datasets: + score = data[prompt][dataset] + row.append(f"{score:.1f}" if score != 0 else "-") + latex_table += " & ".join(row) + " \\\\\n" + +latex_table += "\\hline\n\\end{tabular}\n\\caption{Dataset scores for different prompts}\n\\label{tab:dataset_scores}\n\\end{table}" + +# Write the LaTeX table to a file +with open('dataset_scores_table.tex', 'w') as f: + f.write(latex_table) + +print("LaTeX table has been generated and saved to 'dataset_scores_table.tex'") \ No newline at end of file diff --git a/scripts/plotting/make_prompt_table_from_results.py b/scripts/plotting/make_prompt_table_from_results.py new file mode 100644 index 0000000..0ad0560 --- /dev/null +++ b/scripts/plotting/make_prompt_table_from_results.py @@ -0,0 +1,251 @@ +import csv +from collections import defaultdict +import statistics +import os +import random + + +SKIP_OLD_HASHES = [ + "0ab0de14665a035b4ce74ea58f0aeb0b", # + # "d2b1fa425e0198eb5ba2f9ceaa946389", + # "bc8581d1f8b9b223247df82aa13707fc", + # "b09133128f72179896830b2f10a6fa9e", + "11c51cdccc21293fad66b37e75bbdc94", + # "eeee229082555a0f22c493370c12651e", + "476c48e5591c52d8000c65bc88421652" # remove it, match key phrases +] + + +PRETTY_NAMES = { + "arguana": "Arguana", + "climate-fever": "Climate-FEVER", + "dbpedia-entity": "DBPedia", + "fever": "FEVER", + "fiqa": "FiQA", + "hotpotqa": "HotpotQA", + "nfcorpus": "NFCorpus", + "nq": "NQ", + "quora": "Quora", + "scidocs": "SCIDOCS", + "scifact": "SciFact", + "trec-covid": "TREC-COVID", + "webis-touche2020": "Touche-2020" +} + +def read_csv(filename): + data = defaultdict(lambda: defaultdict(dict)) + if not os.path.exists(filename): + print(f"Warning: {filename} does not exist. Skipping this file.") + return data + + with open(filename, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + dataset = row['dataset'].lower() + prompt_hash = row['prompt_hash'] + if prompt_hash in SKIP_OLD_HASHES: + continue + ndcg = float(row['ndcg@10']) if row['ndcg@10'] else None + recall = float(row['recall@100']) if 'recall@100' in row and row['recall@100'] else None + + if prompt_hash == 'none': + data[dataset]['None'] = ndcg + else: + data[dataset]['Prompted'][prompt_hash] = (ndcg, recall) + return data + +def format_value(value): + return f"{value:.1f}" if type(value) == float else "-" + +def calculate_average(values): + real_vals = [v for v in values if v is not None and type(v) == float] + if not len(real_vals): + return None + return statistics.mean(real_vals) + +def get_best_dev_prompt(dataset_name, dev_data, is_model: bool = False): + if not len(dev_data) or not dev_data['Prompted']: + return None, None + + try: + rounded_scores = {k: (f"{v[0]:.1f}", v[1]) for k, v in dev_data['Prompted'].items()} + except Exception: + breakpoint() + max_ndcg = max(v[0] for v in rounded_scores.values()) + + # Filter prompts with the highest NDCG + # prompt can't be None + best_prompts = {k: v for k, v in rounded_scores.items() if v[0] == max_ndcg and k not in ['none', None]} + + + # also print the name of the prompt_hash + # if is_model: + # print(f"Best prompt for {dataset_name} is {best_prompts}") + + + if len(best_prompts) == 1: + return list(best_prompts.keys())[0], max_ndcg + + # If there's a tie, use recall as a tiebreaker + max_recall = max(v[1] or 0 for v in best_prompts.values()) + best_prompts = {k: v for k, v in best_prompts.items() if v[1] == max_recall} + + if len(best_prompts) == 1: + return list(best_prompts.keys())[0], max_ndcg + + # If there's still a tie, choose the prompt with the lowest hash value + best_prompt = min(best_prompts.keys()) + + return best_prompt, max_ndcg + +import csv +from collections import defaultdict +import statistics +import os +import random + +def generate_latex_table(bm25_data, repllama_data, modelname_data): + datasets = list(PRETTY_NAMES.keys()) + latex_rows = [] + + tuned_datasets = { + 'bm25': set(), + 'repllama': set(), + 'modelname': set() + } + + for dataset in datasets: + pretty_name = PRETTY_NAMES[dataset] + dev_dataset = f"{dataset}-dev" + + bm25_prompt, bm25_prompt_score = get_best_dev_prompt(dev_dataset, bm25_data.get(dev_dataset, {})) + repllama_prompt, repllama_prompt_score = get_best_dev_prompt(dev_dataset, repllama_data.get(dev_dataset, {}), True) + modelname_prompt, modelname_prompt_score = get_best_dev_prompt(dev_dataset, modelname_data.get(dev_dataset, {})) + + print(dataset, repllama_prompt) + + bm25_prompt_value = bm25_data[dataset]['Prompted'].get(bm25_prompt, (None, None))[0] if bm25_prompt else None + repllama_prompt_value = repllama_data[dataset]['Prompted'].get(repllama_prompt, (None, None))[0] if repllama_prompt else None + modelname_prompt_value = modelname_data[dataset]['Prompted'].get(modelname_prompt, (None, None))[0] if modelname_prompt else None + + if bm25_prompt_value is not None: + tuned_datasets['bm25'].add(dataset) + if repllama_prompt_value is not None: + tuned_datasets['repllama'].add(dataset) + if modelname_prompt_value is not None: + tuned_datasets['modelname'].add(dataset) + + row = [ + pretty_name, + format_value(bm25_data[dataset]['None']), + format_value(bm25_prompt_value), + format_value(max(v[0] for v in bm25_data[dataset]['Prompted'].values()) if bm25_data[dataset]['Prompted'] else None), + format_value(repllama_data[dataset]['None']), + format_value(repllama_prompt_value), + format_value(max(v[0] for v in repllama_data[dataset]['Prompted'].values()) if repllama_data[dataset]['Prompted'] else None), + format_value(modelname_data[dataset]['None']), + format_value(modelname_prompt_value), + format_value(max(v[0] for v in modelname_data[dataset]['Prompted'].values()) if modelname_data[dataset]['Prompted'] else None) + ] + latex_rows.append(" & ".join(row) + " \\\\") + + averages = [ + "Average", + format_value(calculate_average([bm25_data[d]['None'] for d in datasets])), + "-", + format_value(calculate_average([max(v[0] for v in bm25_data[d]['Prompted'].values()) if bm25_data[d]['Prompted'] else None for d in datasets])), + format_value(calculate_average([repllama_data[d]['None'] for d in datasets])), + "-", + format_value(calculate_average([max(v[0] for v in repllama_data[d]['Prompted'].values()) if repllama_data[d]['Prompted'] else None for d in datasets])), + format_value(calculate_average([modelname_data[d]['None'] for d in datasets])), + "-", + format_value(calculate_average([max(v[0] for v in modelname_data[d]['Prompted'].values()) if modelname_data[d]['Prompted'] else None for d in datasets])) + ] + + averages_tuned = [ + "Average (Tuned)", + format_value(calculate_average([bm25_data[d]['None'] for d in tuned_datasets['bm25']])), + format_value(calculate_average([bm25_data[d]['Prompted'].get(get_best_dev_prompt(f"{d}-dev", bm25_data.get(f"{d}-dev", {}))[0], (None, None))[0] for d in tuned_datasets['bm25']])), + format_value(calculate_average([max(v[0] for v in bm25_data[d]['Prompted'].values()) if bm25_data[d]['Prompted'] else None for d in tuned_datasets['bm25']])), + format_value(calculate_average([repllama_data[d]['None'] for d in tuned_datasets['repllama']])), + format_value(calculate_average([repllama_data[d]['Prompted'].get(get_best_dev_prompt(f"{d}-dev", repllama_data.get(f"{d}-dev", {}))[0], (None, None))[0] for d in tuned_datasets['repllama']])), + format_value(calculate_average([max(v[0] for v in repllama_data[d]['Prompted'].values()) if repllama_data[d]['Prompted'] else None for d in tuned_datasets['repllama']])), + format_value(calculate_average([modelname_data[d]['None'] for d in tuned_datasets['modelname']])), + format_value(calculate_average([modelname_data[d]['Prompted'].get(get_best_dev_prompt(f"{d}-dev", modelname_data.get(f"{d}-dev", {}), True)[0], (None, None))[0] for d in tuned_datasets['modelname']])), + format_value(calculate_average([max(v[0] for v in modelname_data[d]['Prompted'].values()) if modelname_data[d]['Prompted'] else None for d in tuned_datasets['modelname']])) + ] + + latex_table = f""" +\\begin{{table*}}[t] +\\centering +\\begin{{tabular}}{{l|ccc|ccc|ccc}} +\\toprule +\\multirow{{2}}{{*}}{{Dataset}} & \\multicolumn{{3}}{{c|}}{{BM25}} & \\multicolumn{{3}}{{c|}}{{RepLLaMA}} & \\multicolumn{{3}}{{c}}{{\\modelname}} \\\\ +\\cmidrule(l){{2-4}} \\cmidrule(l){{5-7}} \\cmidrule(l){{8-10}} + & None & Prompt & Oracle & None & Prompt & Oracle & None & Prompt & Oracle \\\\ +\\midrule +{chr(10).join(latex_rows)} +\\midrule +{" & ".join(averages_tuned)} \\\\ +{" & ".join(averages)} \\\\ +\\bottomrule +\\end{{tabular}} +\\caption{{Effectiveness of BM25, RepLLaMA, and \\modelname on BEIR datasets. Results are shown for standard retrieval (None), best prompt from dev set (Prompt), and best overall prompt (Oracle). Missing values are indicated by "-".}} +\\label{{tab:beir-results}} +\\end{{table*}} +""" + return latex_table + +# Main execution +bm25_data = read_csv('results/bm25_results.csv') +repllama_data = read_csv('results/reproduced-v2_results.csv') +modelname_data = read_csv('results/joint-full_results.csv') + + +latex_table = generate_latex_table(bm25_data, repllama_data, modelname_data) + +with open('results/final_table.tex', 'w') as f: + f.write(latex_table) + +print("Final table has been generated and saved as 'final_table.tex'") + +def remove_prompt_columns_refined(latex_table): + lines = latex_table.split('\n') + modified_lines = [] + + for line in lines: + if '\\begin{tabular}' in line: + # Update the tabular environment + line = line.replace('{l|ccc|ccc|ccc}', '{l|cc|cc|cc}') + elif '\\multirow{2}{*}{Dataset}' in line: + # The header is already correct, keep it as is + modified_lines.append(line) + elif 'None & Prompt & Oracle' in line: + # Remove 'Prompt' from the subheader + line = line.replace('None & Prompt & Oracle', 'None & Oracle') + elif '\\cmidrule' in line: + # The cmidrule specifications are already correct, keep them as is + modified_lines.append(line) + elif ' & ' in line and not any(keyword in line for keyword in ['\\midrule', '\\bottomrule', 'Average']): + # This is a data row, remove 'Prompt' columns + parts = line.split('&') + new_parts = [parts[0]] + [parts[i] for i in [1, 3, 4, 6, 7, 9]] + line = ' & '.join(new_parts) + elif 'Average' in line: + # Handle average rows + parts = line.split('&') + new_parts = [parts[0]] + [parts[i] for i in [1, 3, 4, 6, 7, 9]] + line = ' & '.join(new_parts) + + modified_lines.append(line) + + return '\n'.join(modified_lines) + +# Apply the modification +modified_latex_table = remove_prompt_columns_refined(latex_table) + +# Write the modified table to a new file +with open('results/final_table_without_prompt_refined.tex', 'w') as f: + f.write(modified_latex_table) + +print("Modified table without 'Prompt' columns has been generated and saved as 'final_table_without_prompt_refined.tex'") \ No newline at end of file diff --git a/scripts/setup/download_dev_sets.py b/scripts/setup/download_dev_sets.py new file mode 100644 index 0000000..3111e80 --- /dev/null +++ b/scripts/setup/download_dev_sets.py @@ -0,0 +1,99 @@ +import ir_datasets +import json +import os +import argparse +import random +from collections import defaultdict + +def process_dataset(dataset_name, output_dir, split='dev'): + print(f"Processing dataset: {dataset_name}") + + # Determine the correct split to use + if dataset_name.lower() == "nq": + # use natural-questions/dev instead + loading_dataset_name = "dpr-w100/natural-questions/dev" + dataset = ir_datasets.load(loading_dataset_name) + else: + try: + dataset = ir_datasets.load(f"beir/{dataset_name.lower()}/{split}") + except KeyError: + print(f"Dev split not found for {dataset_name}, falling back to train split.") + dataset = ir_datasets.load(f"beir/{dataset_name.lower()}/train") + + # Create output directories + queries_dir = os.path.join(output_dir, 'queries') + qrels_dir = os.path.join(output_dir, 'qrels') + os.makedirs(queries_dir, exist_ok=True) + os.makedirs(qrels_dir, exist_ok=True) + + # Save queries + queries_file = os.path.join(queries_dir, f"{dataset_name.lower()}.dev.jsonl") + with open(queries_file, 'w') as f: + for query in dataset.queries_iter(): + json.dump({"query_id": query.query_id, "query": query.text}, f) + f.write('\n') + + # Save qrels + qrels_file = os.path.join(qrels_dir, f"{dataset_name.lower()}.qrels") + if dataset_name.lower() == "nq": + # add "doc" before each docid + with open(qrels_file, 'w') as f: + for qrel in dataset.qrels_iter(): + f.write(f"{qrel.query_id} 0 doc{qrel.doc_id} {qrel.relevance}\n") + else: + with open(qrels_file, 'w') as f: + for qrel in dataset.qrels_iter(): + f.write(f"{qrel.query_id} 0 {qrel.doc_id} {qrel.relevance}\n") + + print(f"Saved queries to {queries_file}") + print(f"Saved qrels to {qrels_file}") + + return queries_file + +def sample_queries(file, output_file, n): + all_queries = defaultdict(list) + query_ids_sampled = set() + + dataset_name = os.path.basename(file).split('.')[0] + with open(file, 'r') as f: + queries = [json.loads(line) for line in f] + all_queries[dataset_name] = random.sample(queries, min(n, len(queries))) + + with open(output_file, 'w') as f: + for dataset, queries in all_queries.items(): + for query in queries: + query['dataset'] = dataset + json.dump(query, f) + query_ids_sampled.add(query['query_id']) + f.write('\n') + + print(f"Sampled queries saved to {output_file}") + + # now load and save the qrels for only the sampled queries + qrels_dir = os.path.dirname(file).replace("/queries", "/qrels") + with open(os.path.join(qrels_dir, f"{dataset_name}.qrels"), 'r') as f: + qrels = [line.split() for line in f] + qrels_sampled = [qrel for qrel in qrels if qrel[0] in query_ids_sampled] + with open(os.path.join(qrels_dir, f"{dataset_name}.qrels.sampled"), 'w') as f: + for qrel in qrels_sampled: + f.write(' '.join(qrel) + '\n') + +def main(): + parser = argparse.ArgumentParser(description="Process BEIR datasets and sample queries.") + parser.add_argument("output_dir", help="Directory to save output files") + parser.add_argument("--sample_size", type=int, default=10, help="Number of queries to sample from each dataset") + args = parser.parse_args() + + datasets = ['arguana'] + processed_files = [] + + for dataset in datasets: + processed_files.append(process_dataset(dataset, args.output_dir)) + + sample_output = "resources/beir" + sample_file = os.path.join(sample_output, f'{dataset.lower()}.dev.jsonl') + sample_queries(processed_files[-1], sample_file, args.sample_size) + +if __name__ == "__main__": + main() + # python scripts/download_dev_sets.py resources/downloaded \ No newline at end of file diff --git a/scripts/setup/install_conda.sh b/scripts/setup/install_conda.sh new file mode 100644 index 0000000..027bed2 --- /dev/null +++ b/scripts/setup/install_conda.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +cd ~/ +# if the file doesn't exist, download it +if [ ! -f ./Anaconda3-5.1.0-Linux-x86_64.sh ]; then + wget https://repo.anaconda.com/archive/Anaconda3-5.1.0-Linux-x86_64.sh +fi +# install manually +# echo ". /home/ubuntu/anaconda3/etc/profile.d/conda.sh" >> ~/.bashrc diff --git a/scripts/setup/install_req.sh b/scripts/setup/install_req.sh new file mode 100644 index 0000000..38293d7 --- /dev/null +++ b/scripts/setup/install_req.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# conda config --set ssl_verify false +# conda create -n tevatron python=3.10 -y +# conda activate tevatron +# git config --global user.email "wellerorion@gmail.com" +# git config --global user.name "Orion Weller" + +pip install deepspeed accelerate +pip install transformers datasets peft +pip install faiss-cpu +pip install -r requirements.txt +pip install -e . + + +git config --global credential.helper store +# huggingface-cli login --token $TOKEN --add-to-git-credential +# conda install -c conda-forge openjdk=11 -y diff --git a/scripts/tevatron/hn_mining.py b/scripts/tevatron/hn_mining.py new file mode 100644 index 0000000..b446058 --- /dev/null +++ b/scripts/tevatron/hn_mining.py @@ -0,0 +1,112 @@ +import json +from argparse import ArgumentParser +from datasets import load_dataset, concatenate_datasets +from multiprocessing import Manager +from tqdm import tqdm +from pyserini.eval.evaluate_dpr_retrieval import SimpleTokenizer, has_answers + + +class BasicHardNegativeMiner: + def __init__(self, results_path, corpus_dataset, depth): + self.corpus_data = corpus_dataset + self.depth=depth + manager = Manager() + self.retrieval_results = manager.dict(self._read_result(results_path)) + self.docid_to_idx = manager.dict({k: v for v, k in enumerate(self.corpus_data['docid'])}) + + @staticmethod + def _read_result(path): + retrieval_results = {} + with open(path) as f: + for line in f: + qid, pid, _ = line.rstrip().split() + if qid not in retrieval_results: + retrieval_results[qid] = [] + retrieval_results[qid].append(pid) + return retrieval_results + + def __call__(self, example): + query_id = example['query_id'] + retrieved_docid = self.retrieval_results[query_id] + positive_ids = [pos['docid'] for pos in example['positive_passages']] + hard_negatives = [] + for docid in retrieved_docid[:self.depth]: + doc_info = self.corpus_data[self.docid_to_idx[docid]] + text = doc_info['text'] + title = doc_info['title'] if 'title' in doc_info else None + if docid not in positive_ids: + hn_doc = {'docid': docid, 'text': text} + if title: + hn_doc['title'] = title + hard_negatives.append(hn_doc) + example['negative_passages'] = hard_negatives + return example + + +class EMHardNegativeMiner(BasicHardNegativeMiner): + def __init__(self, results_path, corpus_dataset, depth, tokenzier, regex=False): + self.tokenizer = tokenzier + self.regex = regex + super().__init__(results_path, corpus_dataset, depth) + + def __call__(self, example): + query_id = example['query_id'] + retrieved_docid = self.retrieval_results[query_id] + answers = example['answers'] + positives = [] + hard_negatives = [] + for docid in retrieved_docid[:self.depth]: + doc_info = self.corpus_data[self.docid_to_idx[docid]] + text = doc_info['text'] + title = doc_info['title'] if 'title' in doc_info else None + if not has_answers(text, answers, self.tokenizer, self.regex): + hn_doc = {'docid': docid, 'text': text} + if title: + hn_doc['title'] = title + hard_negatives.append(hn_doc) + else: + pos_doc = {'docid': docid, 'text': text} + if title: + pos_doc['title'] = title + positives.append(pos_doc) + example['negative_passages'] = hard_negatives + example['positive_passages'] = positives + return example + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--train_data_name', type=str, required=True) + parser.add_argument('--corpus_data_name', type=str, required=True) + parser.add_argument('--result_path', type=str, required=True) + parser.add_argument('--depth', type=int, default=100, required=False) + parser.add_argument('--min_hn', type=int, default=1, required=False) + parser.add_argument('--output', type=str, required=True) + parser.add_argument('--cache_dir', type=str, required=False) + parser.add_argument('--proc_num', type=int, default=12, required=False) + parser.add_argument('--em', action='store_true', required=False) + parser.add_argument('--regex', action='store_true', required=False) + + args = parser.parse_args() + train_data = load_dataset(args.train_data_name, cache_dir=args.cache_dir)['train'] + corpus_data = load_dataset(args.corpus_data_name, cache_dir=args.cache_dir)['train'] + if args.em: + miner = EMHardNegativeMiner(args.result_path, corpus_data, args.depth, SimpleTokenizer(), regex=args.regex) + else: + miner = BasicHardNegativeMiner(args.result_path, corpus_data, args.depth) + + hn_data = train_data.map( + miner, + batched=False, + num_proc=args.proc_num, + desc="Running hard negative mining", + ) + + combined_data = concatenate_datasets([train_data, hn_data]) + combined_data = combined_data.filter( + function=lambda data: len(data["positive_passages"]) >= 1 and len(data["negative_passages"]) >= args.min_hn + ) + + with open(args.output, 'w') as f: + for e in tqdm(combined_data): + f.write(json.dumps(e, ensure_ascii=False)+'\n') diff --git a/scripts/tevatron/reduce_results.py b/scripts/tevatron/reduce_results.py new file mode 100644 index 0000000..d0c1a26 --- /dev/null +++ b/scripts/tevatron/reduce_results.py @@ -0,0 +1,28 @@ +import argparse +import os + +parser = argparse.ArgumentParser(description='Reduce retrieval results from multiple shards.') +parser.add_argument('--results_dir', type=str, help='Directory that contains results from all shards', required=True) +parser.add_argument('--output', help='Path to final results file', required=True) +parser.add_argument('--depth', type=int, help='Number of retrieved doc for each query', required=False, default=100) +args = parser.parse_args() + + +all_results = {} +print(f'Merging results from {len(os.listdir(args.results_dir))} result files.') +for filename in os.listdir(args.results_dir): + path = os.path.join(args.results_dir, filename) + with open(path) as f: + for line in f: + qid, docid, score = line.split() + score = float(score) + if qid not in all_results: + all_results[qid] = [] + all_results[qid].append((docid, score)) + +with open(args.output, 'w') as f: + print(f'Writing output to {args.output} with depth {args.depth}') + for qid in all_results: + results = sorted(all_results[qid], key=lambda x: x[1], reverse=True)[:args.depth] + for docid, score in results: + f.write(f'{qid}\t{docid}\t{score}\n') diff --git a/scripts/training/train.sh b/scripts/training/train.sh new file mode 100644 index 0000000..c396547 --- /dev/null +++ b/scripts/training/train.sh @@ -0,0 +1,32 @@ +#!/bin/bash +deepspeed --include localhost:0,1,2,3 --master_port 60000 --module tevatron.retriever.driver.train \ + --deepspeed deepspeed/ds_zero3_config.json \ + --output_dir retriever-llama2-4gpu \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --lora \ + --lora_r 32 \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 200 \ + --dataset_name Tevatron/msmarco-passage-aug \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 8 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 32 \ + --passage_max_len 196 \ + --num_train_epochs 1 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --warmup_steps 100 \ + --gradient_accumulation_steps 4 + + + + # > repro.log 2>&1 | tee repro.log \ No newline at end of file diff --git a/scripts/training/train_instruct.sh b/scripts/training/train_instruct.sh new file mode 100644 index 0000000..624fbfd --- /dev/null +++ b/scripts/training/train_instruct.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# args are (1) name of run (2) dataset, and (3) nodes e.g. "0,1,2,3" (4) port num either 1 or 2 or something +# bash train_instruct.sh standard orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 2 > standard-percent-4gpu.log 2>&1 + +# bash train_instruct.sh old_standard orionweller/instruction-msmarco-passage-aug-50-percent "4,5,6,7" 2 > old_standard-percent-4gpu.log +# bash train_instruct.sh standard_fixed orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 0 > generic-4gpu-long.log + +echo "Args are $1 $2 $3 $4" +deepspeed --include localhost:$3 --master_port "6000$4" --module tevatron.retriever.driver.train \ + --deepspeed deepspeed/ds_zero3_config.json \ + --output_dir retriever-llama2-instruct-$1 \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --lora \ + --lora_r 32 \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 500 \ + --dataset_name $2 \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 8 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 304 \ + --passage_max_len 196 \ + --num_train_epochs 1 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --warmup_steps 100 \ + --gradient_accumulation_steps 4 \ + --negatives_first_n 3 + # --dont_shuffle + \ No newline at end of file diff --git a/scripts/training/train_instruct_llama3.sh b/scripts/training/train_instruct_llama3.sh new file mode 100644 index 0000000..e215fa2 --- /dev/null +++ b/scripts/training/train_instruct_llama3.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# args are (1) name of run (2) dataset, and (3) nodes e.g. "0,1,2,3" (4) port num either 1 or 2 or something +# bash train_instruct.sh standard orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 2 > standard-percent-4gpu.log 2>&1 + +# bash train_instruct.sh old_standard orionweller/instruction-msmarco-passage-aug-50-percent "4,5,6,7" 2 > old_standard-percent-4gpu.log +# bash train_instruct.sh standard_fixed orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 0 > generic-4gpu-long.log + +echo "Args are $1 $2 $3 $4" +deepspeed --include localhost:$3 --master_port "6000$4" --module tevatron.retriever.driver.train \ + --deepspeed deepspeed/ds_zero3_config.json \ + --output_dir retriever-llama3-$1 \ + --model_name_or_path meta-llama/Meta-Llama-3.1-8B \ + --lora \ + --lora_r 32 \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 500 \ + --dataset_name $2 \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 8 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 304 \ + --passage_max_len 196 \ + --num_train_epochs 1 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --warmup_steps 100 \ + --gradient_accumulation_steps 4 \ + --negatives_first_n 3 + # --dont_shuffle + \ No newline at end of file diff --git a/scripts/training/train_instruct_llama3_instruct.sh b/scripts/training/train_instruct_llama3_instruct.sh new file mode 100644 index 0000000..5c17d7b --- /dev/null +++ b/scripts/training/train_instruct_llama3_instruct.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# args are (1) name of run (2) dataset, and (3) nodes e.g. "0,1,2,3" (4) port num either 1 or 2 or something +# bash train_instruct.sh standard orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 2 > standard-percent-4gpu.log 2>&1 + +# bash train_instruct.sh old_standard orionweller/instruction-msmarco-passage-aug-50-percent "4,5,6,7" 2 > old_standard-percent-4gpu.log +# bash train_instruct.sh standard_fixed orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 0 > generic-4gpu-long.log + +echo "Args are $1 $2 $3 $4" +deepspeed --include localhost:$3 --master_port "6000$4" --module tevatron.retriever.driver.train \ + --deepspeed deepspeed/ds_zero3_config.json \ + --output_dir retriever-llama3-instruct-$1 \ + --model_name_or_path meta-llama/Meta-Llama-3.1-8B-Instruct \ + --lora \ + --lora_r 32 \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 500 \ + --dataset_name $2 \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 8 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 304 \ + --passage_max_len 196 \ + --num_train_epochs 1 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --warmup_steps 100 \ + --gradient_accumulation_steps 4 \ + --negatives_first_n 3 + # --dont_shuffle + \ No newline at end of file diff --git a/scripts/training/train_instruct_mistral.sh b/scripts/training/train_instruct_mistral.sh new file mode 100644 index 0000000..0b9ea82 --- /dev/null +++ b/scripts/training/train_instruct_mistral.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# args are (1) name of run (2) dataset, and (3) nodes e.g. "0,1,2,3" (4) port num either 1 or 2 or something +# bash train_instruct.sh standard orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 2 > standard-percent-4gpu.log 2>&1 + +# bash train_instruct.sh old_standard orionweller/instruction-msmarco-passage-aug-50-percent "4,5,6,7" 2 > old_standard-percent-4gpu.log +# bash train_instruct.sh standard_fixed orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 0 > generic-4gpu-long.log + +echo "Args are $1 $2 $3 $4" +deepspeed --include localhost:$3 --master_port "6000$4" --module tevatron.retriever.driver.train \ + --deepspeed deepspeed/ds_zero3_config.json \ + --output_dir retriever-mistral-v1-$1 \ + --model_name_or_path mistralai/Mistral-7B-v0.3 \ + --lora \ + --lora_r 32 \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 500 \ + --dataset_name $2 \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 8 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 304 \ + --passage_max_len 196 \ + --num_train_epochs 1 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --warmup_steps 100 \ + --gradient_accumulation_steps 4 \ + --negatives_first_n 3 + # --dont_shuffle + \ No newline at end of file diff --git a/scripts/training/train_instruct_mistral_v1.sh b/scripts/training/train_instruct_mistral_v1.sh new file mode 100644 index 0000000..184d6c5 --- /dev/null +++ b/scripts/training/train_instruct_mistral_v1.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# args are (1) name of run (2) dataset, and (3) nodes e.g. "0,1,2,3" (4) port num either 1 or 2 or something +# bash train_instruct.sh standard orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 2 > standard-percent-4gpu.log 2>&1 + +# bash train_instruct.sh old_standard orionweller/instruction-msmarco-passage-aug-50-percent "4,5,6,7" 2 > old_standard-percent-4gpu.log +# bash train_instruct.sh standard_fixed orionweller/instruction-msmarco-passage-aug-50-fixed-standard "0,1,2,3" 0 > generic-4gpu-long.log + +echo "Args are $1 $2 $3 $4" +deepspeed --include localhost:$3 --master_port "6000$4" --module tevatron.retriever.driver.train \ + --deepspeed deepspeed/ds_zero3_config.json \ + --output_dir retriever-mistral-$1 \ + --model_name_or_path mistralai/Mistral-7B-v0.1 \ + --lora \ + --lora_r 32 \ + --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ + --save_steps 500 \ + --dataset_name $2 \ + --query_prefix "query: " \ + --passage_prefix "passage: " \ + --bf16 \ + --pooling eos \ + --append_eos_token \ + --normalize \ + --temperature 0.01 \ + --per_device_train_batch_size 8 \ + --gradient_checkpointing \ + --train_group_size 16 \ + --learning_rate 1e-4 \ + --query_max_len 304 \ + --passage_max_len 196 \ + --num_train_epochs 1 \ + --logging_steps 10 \ + --overwrite_output_dir \ + --warmup_steps 100 \ + --gradient_accumulation_steps 4 \ + --negatives_first_n 3 + # --dont_shuffle + \ No newline at end of file diff --git a/scripts/utils/symlink_dev.sh b/scripts/utils/symlink_dev.sh new file mode 100644 index 0000000..b6c4f1c --- /dev/null +++ b/scripts/utils/symlink_dev.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +DATASETS=( + 'nq-dev' + 'scifact-dev' + 'fiqa-dev' + 'hotpotqa-dev' + 'dbpedia-entity-dev' + 'quora-dev' + 'fever-dev' + 'nfcorpus-dev' +) + +for file_path in llama3.1-instruct llama3.1; do # reproduced-v2 joint-full + for dataset in "${DATASETS[@]}"; do + echo "Symlinking ${dataset}..." + short_dataset=$(echo $dataset | sed 's/-dev//') + echo "bash ./scripts/symlink_msmarco.sh $file_path/$short_dataset $file_path/$dataset" + bash ./scripts/symlink_msmarco.sh $file_path/$short_dataset $file_path/$dataset + done +done \ No newline at end of file diff --git a/scripts/utils/symlink_msmarco.sh b/scripts/utils/symlink_msmarco.sh new file mode 100644 index 0000000..d3a85d1 --- /dev/null +++ b/scripts/utils/symlink_msmarco.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Check if both input and output directories are provided +if [ $# -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +input_dir="$1" +output_dir="$2" + +# Check if input directory exists +if [ ! -d "$input_dir" ]; then + echo "Error: Input directory '$input_dir' does not exist." + exit 1 +fi + +# Create output directory if it doesn't exist +mkdir -p "$output_dir" + +# Create symbolic links for corpus_emb.*.pkl files +for file in "$input_dir"/corpus_emb.*.pkl; do + if [ -f "$file" ]; then + basename=$(basename "$file") + # Get the absolute path of the input file + absolute_path=$(readlink -f "$file") + # Create a relative path from the output directory to the input file + relative_path=$(realpath --relative-to="$output_dir" "$absolute_path") + ln -sf "$relative_path" "$output_dir/$basename" + echo "Created symlink: $output_dir/$basename -> $relative_path" + fi +done + +echo "Symlinking complete." \ No newline at end of file diff --git a/scripts/utils/upload_to_hf.py b/scripts/utils/upload_to_hf.py new file mode 100644 index 0000000..6a1634d --- /dev/null +++ b/scripts/utils/upload_to_hf.py @@ -0,0 +1,51 @@ +import huggingface_hub +import os +import argparse + + +def upload_folder(args): + print(f"Creating a new repo {args.repo}") + api = huggingface_hub.HfApi() + if not args.skip_create: + repo_url = api.create_repo( + args.repo, + repo_type="model", + exist_ok=False, + private=True + ) + # Upload all the content from the local folder to your remote Space. + # By default, files are uploaded at the root of the repo + print(f"Uploading {args.folder} to {args.repo}") + try: + api.upload_folder( + folder_path=args.folder, + repo_id=args.repo, + repo_type="model", + multi_commits=True, + multi_commits_verbose=True, + ) + except Exception as e: + print(e) + import time + time.sleep(30) + print(f"Error Uploading {args.folder} to {args.repo}, trying again") + api.upload_folder( + folder_path=args.folder, + repo_id=args.repo, + repo_type="model", + multi_commits=True, + multi_commits_verbose=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload a folder to Hugging Face Hub") + parser.add_argument("-f", "--folder", type=str, help="The folder to upload", required=True) + parser.add_argument("-r", "--repo", type=str, help="The repo to upload to", required=True) + parser.add_argument("--skip_create", action="store_true", help="Skip creating") + args = parser.parse_args() + upload_folder(args) + + + # example usage: + # python scripts/upload_to_hf.py -f retriever-llama2-4gpu/final -r orionweller/repllama-reproduced-v2 \ No newline at end of file diff --git a/scripts/utils/upload_to_hf_all.py b/scripts/utils/upload_to_hf_all.py new file mode 100644 index 0000000..82eb09e --- /dev/null +++ b/scripts/utils/upload_to_hf_all.py @@ -0,0 +1,78 @@ +import huggingface_hub +import os +import argparse +import shutil + +def create_final_folder(source_dir): + final_dir = os.path.join(source_dir, "final") + os.makedirs(final_dir, exist_ok=True) + + files_moved = [] + for item in os.listdir(source_dir): + item_path = os.path.join(source_dir, item) + if os.path.isfile(item_path): + shutil.copy2(item_path, final_dir) + files_moved.append(item) + + return files_moved + + +def upload_folder(args): + print(f"Creating a new repo {args.repo}") + api = huggingface_hub.HfApi() + if not args.skip_create: + repo_url = api.create_repo( + args.repo, + repo_type="model", + exist_ok=False, + private=True + ) + + # Create the 'final' folder and copy non-folder files into it + # if there are files in the root of the folder but not folders + files_not_folders = [f for f in os.listdir(args.folder) if os.path.isfile(os.path.join(args.folder, f))] + if files_not_folders: + files_moved = create_final_folder(args.folder) + print(f"Files moved to final folder: {files_moved}") + else: + files_moved = None + + print(f"Uploading folder structure from {args.folder} to {args.repo}") + try: + api.upload_folder( + folder_path=args.folder, + repo_id=args.repo, + repo_type="model", + ignore_patterns=files_moved, # Ignore all files in the root + multi_commits=True, + multi_commits_verbose=True, + ) + except Exception as e: + print(e) + import time + time.sleep(30) + print(f"Error Uploading {args.folder} to {args.repo}, trying again") + api.upload_folder( + folder_path=args.folder, + repo_id=args.repo, + repo_type="model", + ignore_patterns=files_moved, # Ignore all files in the root + multi_commits=True, + multi_commits_verbose=True, + + ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload a folder structure to Hugging Face Hub") + parser.add_argument("-f", "--folder", type=str, help="The folder to upload", required=True) + parser.add_argument("-r", "--repo", type=str, help="The repo to upload to", required=True) + parser.add_argument("--skip_create", action="store_true", help="Skip creating the repository") + args = parser.parse_args() + upload_folder(args) + +# Example usage: +# python scripts/upload_to_hf_all.py -f /home/ubuntu/LLaMA-Factory/followir-samaya -r orionweller/followir-samaya +# python scripts/upload_to_hf_all.py -f /home/ubuntu/LLaMA-Factory/followir-samaya-exported -r orionweller/followir-samaya-full + + +# python scripts/upload_to_hf_all.py -f retriever-llama3-full -r orionweller/followir-joint-llama3.1 \ No newline at end of file diff --git a/scripts/utils/validate_all_present.py b/scripts/utils/validate_all_present.py new file mode 100644 index 0000000..b0b8af0 --- /dev/null +++ b/scripts/utils/validate_all_present.py @@ -0,0 +1,103 @@ +import pandas as pd +import hashlib +import argparse +import sys +import json +from collections import defaultdict + +def calculate_md5(text): + return hashlib.md5(text.encode('utf-8')).hexdigest() + +def load_and_hash_csv(filename, column_name): + try: + df = pd.read_csv(filename, header=None, index_col=None) + except Exception: # load as newlines + df = pd.read_csv(filename, header=None, index_col=None, sep='\t') + # set the columns name to prompt if there is only one else dataset,prompt + df.columns = ['prompt'] if len(df.columns) == 1 else ['dataset', 'prompt'] + # save json of hash to text + df['prompt_hash'] = df[column_name].apply(calculate_md5) + df.to_csv(f"results/{filename}_hashes.csv", index=False) + return set(df[column_name].apply(calculate_md5)) + +def validate_hashes(data_csv, generic_hashes, domain_hashes): + df = pd.read_csv(f"results/{data_csv}_results.csv") + + dataset_hashes = defaultdict(set) + for _, row in df.iterrows(): + if pd.notna(row['prompt_hash']) and pd.notna(row['dataset']): + dataset_hashes[row['dataset']].add(row['prompt_hash']) + + print(f"Validation results for {data_csv}:") + + for dataset, hashes in dataset_hashes.items(): + + + missing_generic = generic_hashes - hashes + if missing_generic: + print(f"\nDataset: {dataset}") + print(f"Total hashes in dataset: {len(hashes)}") + print(f"Missing generic hashes: {len(missing_generic)}") + print(missing_generic) + + if dataset in domain_hashes: # not every dataset has domain-specific, e.g. dev sets + missing_domain = domain_hashes[dataset] - hashes + if missing_domain: + print(f"\nDataset: {dataset}") + print(f"Total hashes in dataset: {len(hashes)}") + print(f"Missing domain hashes: {len(missing_domain)}") + print(missing_domain) + + # if not missing_generic and (dataset not in domain_hashes or not missing_domain): + # print(f"All expected hashes present for dataset {dataset}.") + + # Check for datasets in domain_hashes that are not in the results + missing_datasets = set(domain_hashes.keys()) - set(dataset_hashes.keys()) + if missing_datasets: + print("\nDatasets missing from results:") + print(missing_datasets) + +def main(): + parser = argparse.ArgumentParser(description="Validate CSV file hashes against generic and domain CSV files.") + parser.add_argument("file_to_validate", help="The CSV file to validate") + parser.add_argument("--generic", default="generic_prompts.csv", help="Path to the generic CSV file (default: generic_prompts.csv)") + parser.add_argument("--domain", default="domain_prompts.csv", help="Path to the domain CSV file (default: domain_prompts.csv)") + + args = parser.parse_args() + + try: + # Load and hash the generic CSV file + generic_hashes = load_and_hash_csv(args.generic, 'prompt') + + # Load and hash the domain CSV file + domain_df = pd.read_csv(args.domain, header=None, index_col=None) + domain_df.columns = ["dataset", "prompt"] + domain_hashes = defaultdict(set) + hash_map = {} + for _, row in domain_df.iterrows(): + domain_hashes[row['dataset']].add(calculate_md5(row['prompt'])) + hash_map[calculate_md5(row['prompt'])] = row['prompt'] + + # save hash map + with open("results/hash_map_domain.json", "w") as f: + json.dump(hash_map, f) + + # Validate the data CSV file + validate_hashes(args.file_to_validate, generic_hashes, domain_hashes) + except FileNotFoundError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + except pd.errors.EmptyDataError: + print(f"Error: The file {args.file_to_validate} is empty.", file=sys.stderr) + sys.exit(1) + except Exception as e: + import traceback + print(traceback.format_exc()) + print(f"An unexpected error occurred: {e}", file=sys.stderr) + sys.exit(1) + +if __name__ == "__main__": + main() + + # python scripts/validate_all_present.py joint-full + # python scripts/validate_all_present.py bm25 \ No newline at end of file