Commit 7137315 (1 parent: 2f80814)
EC2 Default User committed Nov 15, 2023
Showing 13 changed files with 419 additions and 116 deletions.
```
@@ -1,6 +1,7 @@
# customized
experiments-full-t5seq-aq/
wandb/
data/

# Byte-compiled / optimized / DLL files
__pycache__/
```
```
@@ -1,4 +1,71 @@
-# Package installation
-pip install -r requirement.txt
-pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
-conda install -c conda-forge faiss-cpu
```
# Scalable and Effective Generative Information Retrieval
This repo provides the source code and checkpoints for our paper [Scalable and Effective Generative Information Retrieval]() (RIPOR).

## Package installation
- pip install -r requirement.txt
- pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
- conda install -c conda-forge faiss-gpu

## Inference
We use 4 A100 GPUs to run the model. Preprocessing takes roughly 20 minutes and the full evaluation takes about 90 minutes. You can use other GPU types, such as the V100, but the run may take longer.
```
bash full_scripts/full_evaluate_t5seq_aq_encoder.sh
```

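The evaluation script reports standard MS MARCO-style ranking metrics. For reference only, here is a minimal MRR@10 computation over a run (query → ranked doc IDs) and binary qrels; it is independent of the repo's own evaluation code, and the data structures are illustrative.
```python
def mrr_at_10(run, qrels):
    """run: {qid: [docid, ...] ranked best-first}; qrels: {qid: set of relevant docids}."""
    total = 0.0
    for qid, ranking in run.items():
        relevant = qrels.get(qid, set())
        for rank, docid in enumerate(ranking[:10], start=1):
            if docid in relevant:
                total += 1.0 / rank
                break
    return total / max(len(run), 1)

# Example: a single query whose relevant document is ranked 2nd -> MRR@10 = 0.5
print(mrr_at_10({"q1": ["d3", "d7", "d9"]}, {"q1": {"d7"}}))
```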

## Training
Our framework contains multiple training phases (see Figure 2 in the paper for details). You can train sequentially from the starting phase, or use the checkpoint we provide for each phase and move directly to the subsequent phases.

### Phase 1: Relevance-Based DocID initialization ($M^0$)
You will start from `t5-base` and obtain the model $M^0$ after this phase. This phase treats the T5 model as a dense encoder, and we use a two-stage training strategy to train it. In the first stage, we use BM25 negatives. Run the following script to train the model:
```
bash full_scripts/full_train_t5seq_encoder_0.sh
```
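
Both stages train the dense encoder with a margin-MSE distillation loss (`loss_type=t5seq_pretrain_margin_mse` in the training scripts): the student is pushed to reproduce the teacher's score margin between a positive and a negative document. A minimal PyTorch sketch of that loss, with illustrative tensor names, is shown below.
```python
import torch
import torch.nn.functional as F

def margin_mse_loss(student_pos, student_neg, teacher_pos, teacher_neg):
    """Margin-MSE distillation: match the student's (pos - neg) score margin
    to the teacher's margin. All inputs are 1-D score tensors of equal length."""
    student_margin = student_pos - student_neg
    teacher_margin = teacher_pos - teacher_neg
    return F.mse_loss(student_margin, teacher_margin)

# Toy example with a batch of 3 (query, positive, negative) triples
loss = margin_mse_loss(
    torch.tensor([5.0, 4.2, 3.1]), torch.tensor([2.0, 3.9, 1.0]),   # student scores
    torch.tensor([9.0, 6.5, 4.0]), torch.tensor([5.5, 6.0, 2.2]),   # teacher scores
)
print(loss)
```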

Run the following script for the second-stage training:
```
bash full_scripts/full_train_t5seq_encoder_1.sh
```
Now you have the model $M^0$. Congrats! Let's use $M^0$ to get the DocID for each document. Before running the script `full_scripts/full_evaluate_t5seq_aq_encoder.sh`, change the `task` variable in line 3 to `task=all_aq_pipline`. Then run the script:
```
bash full_scripts/full_evaluate_t5seq_aq_encoder.sh
```
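
The `all_aq_pipline` task maps each document's dense embedding from $M^0$ to a short sequence of discrete codes (its DocID) via additive/residual quantization. The repo implements its own pipeline for this; for intuition only, here is a hedged sketch using faiss's `ResidualQuantizer` on random embeddings (dimensions, code sizes, and variable names are illustrative).
```python
import faiss
import numpy as np

d, num_docs = 768, 10000          # embedding dim and corpus size (illustrative)
M, nbits = 8, 8                   # 8 codebooks, 256 centroids each -> 8-token DocIDs

doc_embs = np.random.randn(num_docs, d).astype("float32")  # stand-in for M^0 embeddings

rq = faiss.ResidualQuantizer(d, M, nbits)
rq.train(doc_embs)
codes = rq.compute_codes(doc_embs)   # with nbits=8: shape (num_docs, M), one uint8 code per codebook

# Each document's DocID is its sequence of M code indices, which later
# become extra tokens in the T5 vocabulary.
docid_tokens = codes.astype(np.int64)
print(docid_tokens[0])
```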

### Phase 2: Seq2Seq Pretraining + Initial Fine-tuning ($M^2$)
You will start from $M^0$ and obtain $M^2$ after this phase.
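
Conceptually, this phase fine-tunes T5 to *generate* a DocID token sequence (for a document's text during pretraining, and for a relevant document given a query during fine-tuning). Below is a hedged, self-contained sketch of that seq2seq objective with Hugging Face Transformers; the DocID vocabulary handling and all names are illustrative, and the repo's `t5_pretrainer` contains the actual implementation.
```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Illustrative: 8 codebooks x 256 codes -> 2048 extra "DocID" tokens appended to the vocab.
num_docid_tokens = 8 * 256
docid_offset = len(tokenizer)
model.resize_token_embeddings(docid_offset + num_docid_tokens)

# One toy training example: a query and the 8-code DocID of its relevant document.
query = "what is generative retrieval"
docid_codes = [3, 117, 45, 200, 9, 88, 254, 31]   # produced by the quantizer in phase 1

inputs = tokenizer(query, return_tensors="pt")
labels = torch.tensor([[docid_offset + c for c in docid_codes]])

outputs = model(**inputs, labels=labels)   # standard seq2seq cross-entropy over DocID tokens
outputs.loss.backward()
print(outputs.loss.item())
```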

#### If you skip phase 1
Download all files from the folder `experiments-full-t5seq-aq/t5_docid_gen_encoder_1`, which contains the training files and the checkpoint you need for this phase.
Run the script:
```
bash full_scripts/full_train_t5seq_seq2seq_0_1_pipeline.sh
```
#### If you trained $M^0$ yourself in phase 1
You should create your own training set with the following procedure:
- Change the `task` variable in line 3 to `task=retrieve_train_queries` in the script `full_scripts/full_evaluate_t5seq_aq_encoder.sh`. Then run the script:
```
full_scripts/full_evaluate_t5seq_aq_encoder.sh
```
- Use the teacher model (a cross-encoder) to rerank the obtained run.json file (see the reranking sketch after this list):
```
full_scripts/rerank_for_create_trainset.sh
```
- Add the qrel (relevant docid) to the training set:
```
python t5_pretrainer/aq_preprocess/add_qrel_to_rerank_run.py
```
- Then run the script:
```
bash full_scripts/full_train_t5seq_seq2seq_0_1_pipeline.sh
```
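
The rerank step uses a cross-encoder teacher to rescore the retrieved candidates so that the training targets come from a stronger model. For intuition, here is a minimal sketch of cross-encoder reranking with `sentence-transformers`; the model name and data structures are illustrative, not the repo's actual configuration.
```python
from sentence_transformers import CrossEncoder

# Illustrative teacher; the repo's rerank script configures its own cross-encoder.
teacher = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=256)

query = "what is generative retrieval"
candidates = {                       # docid -> passage text, e.g. loaded from run.json + collection
    "d1": "Generative retrieval predicts document identifiers directly with a seq2seq model.",
    "d2": "The V100 is a data-center GPU released by NVIDIA.",
}

scores = teacher.predict([(query, text) for text in candidates.values()])
reranked = sorted(zip(candidates.keys(), scores), key=lambda x: x[1], reverse=True)
print(reranked)   # teacher scores can be stored as distillation targets
```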

### Phase 3: Prefix-Oriented Ranking Optimization ($M^3$)
#### If you skip phase 1 and phase 2
Download all files from `experiments-full-t5seq-aq/t5_docid_gen_encoder_1` and `experiments-full-t5seq-aq/t5seq_aq_encoder_seq2seq_1`; they provide the checkpoints, training data, and initialized DocIDs. You start from $M^2$ and obtain the checkpoint $M^3$ afterwards. Run the script:
```
bash full_scripts/full_lng_knp_train_pipline.sh
```
#### If you do not skip phase 1 and phase 2
You are a hard-working person who trains all the models yourself. You are only one step away from success! But be patient; it might take some time. Since we built the DocIDs ourselves, we need to generate our own training data. Follow the procedure below for data generation.
- Apply constrained beam search on $M^2$ to generate data for different prefix lengths (see the decoding sketch after this list). Change the `task` variable in line 3 to `t5seq_aq_get_qid_to_smtid_rankdata` in the script `full_scripts/full_evaluate_t5seq_aq_encoder.sh`. Then run the script:
```
full_scripts/full_evaluate_t5seq_aq_encoder.sh
```
-
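
Constrained beam search here means beam decoding in which, at every step, only DocID tokens that extend a prefix of some real document's DocID are allowed, so every finished beam is a valid DocID; truncating the beams then yields data per prefix length. The repo's `t5seq_aq_get_qid_to_smtid_rankdata` task implements its own version; the sketch below only illustrates the idea with Hugging Face's `prefix_allowed_tokens_fn` over a toy DocID set (all names illustrative).
```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

docid_offset = len(tokenizer)
model.resize_token_embeddings(docid_offset + 256)     # illustrative DocID vocabulary

# Toy corpus: every document is identified by a 4-code DocID (token ids after the offset).
docids = [
    [docid_offset + c for c in seq]
    for seq in ([3, 17, 45, 200], [3, 17, 9, 88], [120, 54, 254, 31])
]

def allowed_tokens(batch_id, generated):
    """Allow only tokens that extend the generated prefix into some valid DocID."""
    prefix = generated[1:].tolist()                    # drop the decoder start token
    step = len(prefix)
    allowed = {d[step] for d in docids if d[:step] == prefix and step < len(d)}
    return list(allowed) if allowed else [tokenizer.eos_token_id]

inputs = tokenizer("what is generative retrieval", return_tensors="pt")
beams = model.generate(
    **inputs,
    num_beams=2,
    num_return_sequences=2,
    max_new_tokens=4,
    prefix_allowed_tokens_fn=allowed_tokens,
)
print(beams)   # each returned sequence follows a valid DocID under the constraint
```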

@@ -0,0 +1,32 @@
```
#!/bin/bash

data_root_dir=./data/msmarco-full
collection_path=$data_root_dir/full_collection/
queries_path=./data/msmarco-full/all_train_queries/train_queries

# model dir
experiment_dir=experiments-full-t5seq-aq
pretrained_path=t5-base

# train_examples path
teacher_score_path=./data/msmarco-full/bm25_run/qrel_added_qid_docids_teacher_scores.train.json
run_name=t5_docid_gen_encoder_0
output_dir="./$experiment_dir/"

python -m torch.distributed.launch --nproc_per_node=8 -m t5_pretrainer.main \
    --epochs=50 \
    --run_name=$run_name \
    --learning_rate=1e-4 \
    --loss_type=t5seq_pretrain_margin_mse \
    --model_name_or_path=t5-base \
    --model_type=t5_docid_gen_encoder \
    --teacher_score_path=$teacher_score_path \
    --output_dir=$output_dir \
    --task_names='["rank"]' \
    --wandb_project_name=full_t5seq_encoder \
    --use_fp16 \
    --collection_path=$collection_path \
    --max_length=128 \
    --per_device_train_batch_size=64 \
    --queries_path=$queries_path \
    --pretrained_path=$pretrained_path
```

full_scripts/full_train_t5seq_encoder.sh → full_scripts/full_train_t5seq_encoder_1.sh
8 changes: 4 additions & 4 deletions