Skip to content

Commit

Permalink
Upload Instruct-Distill
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnweiwei committed Oct 31, 2023
1 parent f646a97 commit e091fd1
Show file tree
Hide file tree
Showing 6 changed files with 1,147 additions and 0 deletions.
81 changes: 81 additions & 0 deletions InstructDistill/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Instruction Distillation

Code for paper [Instruction Distillation Makes LLMs Efficient Pointwise Rankers]()

This project aims to improve the efficiency of LLMs as rankers via instruction distillation.

## Pre-trained Models

| Model | Link |
| ---- | ---- |
| Rank-Flan-T5-XL | |
| Rank-Flan-T5-Large | |
| Rank-Flan-T5-Base | |
| Rank-LLaMA-2-7B | |

The following code show how to predict the relevance of a paired (query, passage).

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import torch

query = "How much impact do masks have on preventing the spread of the COVID-19?"
passage = "Title: Universal Masking is Urgent in the COVID-19 Pandemic: SEIR and Agent Based Models, Empirical Validation, Policy Recommendations Content: We present two models for the COVID-19 pandemic predicting the impact of universal face mask wearing upon the spread of the SARS-CoV-2 virus--one employing a stochastic dynamic network based compartmental SEIR (susceptible-exposed-infectious-recovered) approach, and the other employing individual ABM (agent-based modelling) Monte Carlo simulation--indicating (1) significant impact under (near) universal masking when at least 80% of a population is wearing masks, versus minimal impact when only 50% or less of the population is wearing masks, and (2) significant impact when universal masking is adopted early, by Day 50 of a regional outbreak, versus minimal impact when universal masking is adopted late. These effects hold even at the lower filtering rates of homemade masks. To validate these theoretical models, we compare their predictions against a new empirical data set we have collected"
instrcution = "Predict whether the given passage answer the question.\n\nQuestion: {0}\n\nPassage: {1}\n\nDoes the passage answer the question?"
instrcution = instrcution.format(query, passage)
```
Use case of flan-t5 models
```python
tokenizer = AutoTokenizer.from_pretrained("fireballoon/rank-flan-t5-xl")
model = AutoModelForSeq2SeqLM.from_pretrained("fireballoon/rank-flan-t5-xl", torch_dtype=torch.float16)
token_of_Yes = 2163
features = tokenizer([instrcution,], padding=True, truncation=True, return_tensors="pt", max_length=1024)
features['decoder_input_ids'] = torch.zeros(len(batch), 1).long()
scores = model(**features).logits[:, -1, token_of_Yes]
```
Use case of llama models
```python
tokenizer = AutoTokenizer.from_pretrained("fireballoon/rank-llama-2-7b", use_fast=False, padding_side="left")
model = AutoModelForCausalLM.from_pretrained("fireballoon/rank-llama-2-7b", torch_dtype=torch.float16)
token_of_Yes = 3869
features = tokenizer([instrcution,], padding=True, truncation=True, return_tensors="pt", max_length=1024)
scores = model(**features).logits[:, -1, token_of_Yes]
```

## Training
Retrieve passage using BM25
```
python bm25_retrieval.py
```
(optional) Evaluating Pairwise Ranking Prompting (PRP) on benchmarks.
```
python pairwise_ranking.py --model google/flan-t5-xl --eval true --generate false
```
Getting predictions of PRP on MS MARCO (`data/marco-train-10k.jsonl`, can be downloaded from [RankGPT](https://github.com/sunnweiwei/RankGPT/tree/main#download-data-and-model)). The ranking results will be saved at `out/marco-train-10k-flan-xl.json`.
```
python pairwise_ranking.py \
--model google/flan-t5-xl \
--eval false \
--generate true \
--data data/marco-train-10k.jsonl \
--save_path out/marco-train-10k-flan-xl.json
```
Training the pointwise ranker using PRP's predictions. The model checkpoints well be saved at `out/rank-flan-t5-xl`.
```
python instruction_distill.py \
--model google/flan-t5-xl \
--loss rank_net \
--data data/marco-train-10k.jsonl \
--save_path out/rank-flan-t5-xl \
--permutation out/marco-train-10k-flan-xl.json \
--do_train true \
--do_eval false
```
Converting deepspeed checkpoint.
```
python
```
###



167 changes: 167 additions & 0 deletions InstructDistill/bm25_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
THE_RESULTS = {
'dl19': 'data/rank_results/dl19.json',
'dl20': 'data/rank_results/dl20.json',
'covid': 'data/rank_results/beir-trec-covid.json',
'arguana': 'data/rank_results/beir-arguana.json',
'touche': 'data/rank_results/beir-touche.json',
'news': 'data/rank_results/beir-news.json',
'scifact': 'data/rank_results/beir-scifact.json',
'fiqa': 'data/rank_results/beir-fiqa.json',
'scidocs': 'data/rank_results/beir-scidocs.json',
'nfc': 'data/rank_results/beir-nfc.json',
'quora': 'data/rank_results/beir-quora.json',
'dbpedia': 'data/rank_results/beir-dbpedia.json',
'fever': 'data/rank_results/beir-fever.json',
'robust04': 'data/rank_results/beir-robust04.json',
'signal': 'data/rank_results/beir-signal.json',
}

THE_INDEX = {
'dl19': 'msmarco-v1-passage',
'dl20': 'msmarco-v1-passage',
'covid': 'beir-v1.0.0-trec-covid.flat',
'arguana': 'beir-v1.0.0-arguana.flat',
'touche': 'beir-v1.0.0-webis-touche2020.flat',
'news': 'beir-v1.0.0-trec-news.flat',
'scifact': 'beir-v1.0.0-scifact.flat',
'fiqa': 'beir-v1.0.0-fiqa.flat',
'scidocs': 'beir-v1.0.0-scidocs.flat',
'nfc': 'beir-v1.0.0-nfcorpus.flat',
'quora': 'beir-v1.0.0-quora.flat',
'dbpedia': 'beir-v1.0.0-dbpedia-entity.flat',
'fever': 'beir-v1.0.0-fever-flat',
'robust04': 'beir-v1.0.0-robust04.flat',
'signal': 'beir-v1.0.0-signal1m.flat',

'mrtydi-ar': 'mrtydi-v1.1-arabic',
'mrtydi-bn': 'mrtydi-v1.1-bengali',
'mrtydi-fi': 'mrtydi-v1.1-finnish',
'mrtydi-id': 'mrtydi-v1.1-indonesian',
'mrtydi-ja': 'mrtydi-v1.1-japanese',
'mrtydi-ko': 'mrtydi-v1.1-korean',
'mrtydi-ru': 'mrtydi-v1.1-russian',
'mrtydi-sw': 'mrtydi-v1.1-swahili',
'mrtydi-te': 'mrtydi-v1.1-telugu',
'mrtydi-th': 'mrtydi-v1.1-thai',
}

THE_TOPICS = {
'dl19': 'dl19-passage',
'dl20': 'dl20-passage',
'covid': 'beir-v1.0.0-trec-covid-test',
'arguana': 'beir-v1.0.0-arguana-test',
'touche': 'beir-v1.0.0-webis-touche2020-test',
'news': 'beir-v1.0.0-trec-news-test',
'scifact': 'beir-v1.0.0-scifact-test',
'fiqa': 'beir-v1.0.0-fiqa-test',
'scidocs': 'beir-v1.0.0-scidocs-test',
'nfc': 'beir-v1.0.0-nfcorpus-test',
'quora': 'beir-v1.0.0-quora-test',
'dbpedia': 'beir-v1.0.0-dbpedia-entity-test',
'fever': 'beir-v1.0.0-fever-test',
'robust04': 'beir-v1.0.0-robust04-test',
'signal': 'beir-v1.0.0-signal1m-test',

'mrtydi-ar': 'mrtydi-v1.1-arabic-test',
'mrtydi-bn': 'mrtydi-v1.1-bengali-test',
'mrtydi-fi': 'mrtydi-v1.1-finnish-test',
'mrtydi-id': 'mrtydi-v1.1-indonesian-test',
'mrtydi-ja': 'mrtydi-v1.1-japanese-test',
'mrtydi-ko': 'mrtydi-v1.1-korean-test',
'mrtydi-ru': 'mrtydi-v1.1-russian-test',
'mrtydi-sw': 'mrtydi-v1.1-swahili-test',
'mrtydi-te': 'mrtydi-v1.1-telugu-test',
'mrtydi-th': 'mrtydi-v1.1-thai-test',

}

from pyserini.search import LuceneSearcher, get_topics, get_qrels
import json
from tqdm import tqdm


def run_retriever(topics, searcher, qrels=None, k=100, qid=None):
ranks = []
if isinstance(topics, str):
hits = searcher.search(topics, k=k)
ranks.append({'query': topics, 'hits': []})
rank = 0
for hit in hits:
rank += 1
content = json.loads(searcher.doc(hit.docid).raw())
if 'title' in content:
content = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text']
else:
content = content['contents']
content = ' '.join(content.split())
ranks[-1]['hits'].append({
'content': content,
'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score})
return ranks[-1]

for qid in tqdm(topics):
if qid in qrels:
query = topics[qid]['title']
ranks.append({'query': query, 'hits': []})
hits = searcher.search(query, k=k)
rank = 0
for hit in hits:
rank += 1
content = json.loads(searcher.doc(hit.docid).raw())
if 'title' in content:
content = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text']
else:
content = content['contents']
content = ' '.join(content.split())
ranks[-1]['hits'].append({
'content': content,
'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score})
return ranks

def do_retrieval():
for data in ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04']:
print('#' * 20)
print(f'Evaluation on {data}')
print('#' * 20)

# Retrieve passages using pyserini BM25.
# Get a specific doc:
# * searcher.num_docs
# * json.loads(searcher.object.reader.document(4).fields[1].fieldsData) -> {"id": "1", "contents": ""}
try:
searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data])
topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20')
qrels = get_qrels(THE_TOPICS[data])
rank_results = run_retriever(topics, searcher, qrels, k=100)

# Store JSON in rank_results to a file
with open(f'rank_results_{data}.json', 'w') as f:
json.dump(rank_results, f, indent=2)
# Store the QRELS of the dataset
with open(f'qrels_{data}.json', 'w') as f:
json.dump(qrels, f, indent=2)
except:
print(f'Failed to retrieve passages for {data}')

for data in ['mrtydi-ar', 'mrtydi-bn', 'mrtydi-fi', 'mrtydi-id', 'mrtydi-ja', 'mrtydi-ko', 'mrtydi-ru', 'mrtydi-sw',
'mrtydi-te', 'mrtydi-th']:
print('#' * 20)
print(f'Evaluation on {data}')
print('#' * 20)

# Retrieve passages using pyserini BM25.
try:
searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data])
topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20')
qrels = get_qrels(THE_TOPICS[data])
rank_results = run_retriever(topics, searcher, qrels, k=100)
rank_results = rank_results[:100]

# Store JSON in rank_results to a file
with open(f'data/rank_results/{data}.json', 'w') as f:
json.dump(rank_results, f, indent=2)
# Store the QRELS of the dataset
with open(f'data/qrels/{data}.json', 'w') as f:
json.dump(qrels, f, indent=2)
except:
print(f'Failed to retrieve passages for {data}')
Loading

0 comments on commit e091fd1

Please sign in to comment.