Repository for the trial task of NUS WING: a reimplementation of kNN-MT using Hugging Face and PyTorch.
Link to paper:
We use DE-EN translation with Facebook's pretrained WMT19 model as the running example. The original model can be found here.
The repo supports WMT19 data (Wikimedia Foundation, 2019) in general. The instructions below only cover DE-EN translation, but other language pairs should work as well.
If you want to use a pretrained index as the datastore and run inference directly, skip ahead to 3. Evaluation with datastore.
If you want to start from scratch and go through every step, start from 1. Extracting raw features.
Please create a suitable Python environment (using venv or conda).
Inside the environment, install the dependencies with:
pip install -r requirements.txt
Because we modified the transformers library, install the bundled copy in editable mode:
cd transformers
pip install -e .
Optional (if you want to use the datastore that we trained): download the trained index from here.
After downloading, move it to ./datastore_1/, so that this directory contains a
and an index.trained.
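1. Extracting raw features
Before building the datastore, we need to extract all features in the form of pairs $(h_i, w_{i+1})$, where $h_i$ is the decoder hidden state at target position $i$ and $w_{i+1}$ is the target word that follows it. We do so as follows: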
python \
--model_name_or_path facebook/wmt19-de-en \
--source_lang de \
--target_lang en \
--dataset_name wmt19 \
--dataset_config_name de-en \
--save_path saved_gen \
The --dataset_name and --dataset_config_name arguments specify which dataset and language-pair configuration to use.
The --save_path argument specifies where to save the generated features. Please make sure there is enough disk space.
The --percentage argument specifies what percentage of the training data is loaded and passed through the model. To save time and disk space during local testing, it is best to start with a very small value (e.g. 1).
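Because every target token produces one stored pair, disk usage grows linearly with --percentage. The sketch below is a back-of-the-envelope estimate under assumptions not taken from the repo: keys are float32 vectors of dimension 1024 (which we believe is the d_model of facebook/wmt19-de-en; check the model config), values are int64 token ids, there is no compression or file overhead, and the corpus token count is a made-up placeholder.

```python
def feature_storage_bytes(num_target_tokens: int,
                          hidden_dim: int = 1024,
                          key_bytes: int = 4,
                          value_bytes: int = 8) -> int:
    """Estimate the raw size of (hidden_state, next_token) pairs.

    Assumptions (hypothetical, not from the repo): keys are float32
    vectors of size hidden_dim, values are int64 token ids, no overhead.
    """
    per_pair = hidden_dim * key_bytes + value_bytes
    return num_target_tokens * per_pair


def tokens_for_percentage(total_target_tokens: int, percentage: float) -> int:
    """Number of target tokens processed when only `percentage`% of the data is used."""
    return int(total_target_tokens * percentage / 100)


if __name__ == "__main__":
    # Hypothetical corpus size, chosen only to illustrate the arithmetic.
    total_tokens = 1_000_000_000
    for pct in (1, 10, 100):
        n = tokens_for_percentage(total_tokens, pct)
        gib = feature_storage_bytes(n) / 2**30
        print(f"--percentage {pct}: {n:,} tokens -> ~{gib:,.1f} GiB")
```

If the script stores keys in float16, halve key_bytes; the linear dependence on the token count stays the same.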
The script is adapted from the official Hugging Face example code here. We removed the evaluation dataset completely and only run a single forward-inference pass (one epoch) over the training set. All unused training code (optimizer setup, etc.) has been deleted.
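The pairing step can be sketched as follows. Under teacher forcing, the decoder state at position t is the one used to predict target token t, so each state is paired with the token that follows the prefix it has read. This is a minimal pure-Python sketch: the function name, the list-of-lists representation, and the padding id are hypothetical; the actual script works on PyTorch tensors produced by the model.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def extract_pairs(decoder_states: Sequence[Vector],
                  target_ids: Sequence[int],
                  pad_id: int) -> List[Tuple[Vector, int]]:
    """Pair each decoder hidden state with the token it should predict.

    decoder_states[t] is assumed to be the final-layer decoder state at
    position t, i.e. after reading the start token and the target tokens
    before position t (teacher forcing), so its value is target_ids[t].
    Padding positions are dropped.
    """
    if len(decoder_states) != len(target_ids):
        raise ValueError("expected one hidden state per target position")
    return [(state, tok)
            for state, tok in zip(decoder_states, target_ids)
            if tok != pad_id]


if __name__ == "__main__":
    # Toy 2-dimensional states; pad_id=1 is a hypothetical padding id.
    states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    print(extract_pairs(states, [42, 7, 1], pad_id=1))
```

Dropping padding matters because padded positions would otherwise add keys that point at a token the model never has to predict.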
2. Building the datastore
Although the paper describes the datastore as a set of key-value pairs, in practice it is a trained FAISS index (the keys are clustered during training).
This training step is essential: without it, the kNN search over the datastore becomes infeasibly slow.
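To see why clustering makes the search tractable, here is a toy inverted-file (IVF) index, the idea behind FAISS's IVF index types: k-means centroids are trained on the keys, every key is stored in the list of its nearest centroid, and a query only scans the nprobe closest lists instead of the whole datastore. This is a pure-Python illustration, not FAISS and not necessarily the index type this repo trains; all class and function names are hypothetical.

```python
import random
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vector], n_clusters: int, iters: int = 10,
           seed: int = 0) -> List[List[float]]:
    """Plain Lloyd's k-means; returns the centroids."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, n_clusters)]
    for _ in range(iters):
        buckets: List[List[Vector]] = [[] for _ in range(n_clusters)]
        for p in points:
            nearest = min(range(n_clusters), key=lambda c: sq_dist(p, centroids[c]))
            buckets[nearest].append(p)
        for c, members in enumerate(buckets):
            if members:  # keep the old centroid if a cluster becomes empty
                dim = len(members[0])
                centroids[c] = [sum(m[d] for m in members) / len(members)
                                for d in range(dim)]
    return centroids


class ToyIVFIndex:
    """Keys are bucketed by nearest centroid; search probes only a few buckets."""

    def __init__(self, n_lists: int, nprobe: int = 1):
        self.n_lists = n_lists
        self.nprobe = nprobe
        self.centroids: List[List[float]] = []
        self.lists: Dict[int, List[Tuple[Vector, int]]] = {}

    def train(self, keys: List[Vector]) -> None:
        """Learn the coarse clustering (the 'training' of the index)."""
        self.centroids = kmeans(keys, self.n_lists)
        self.lists = {c: [] for c in range(self.n_lists)}

    def _closest_lists(self, vec: Vector, n: int) -> List[int]:
        order = sorted(range(self.n_lists),
                       key=lambda c: sq_dist(vec, self.centroids[c]))
        return order[:n]

    def add(self, keys: List[Vector], values: List[int]) -> None:
        """Store each (key, value) pair in the list of its nearest centroid."""
        for key, value in zip(keys, values):
            self.lists[self._closest_lists(key, 1)[0]].append((key, value))

    def search(self, query: Vector, k: int) -> List[Tuple[float, int]]:
        """Return up to k (squared distance, value) pairs, nearest first."""
        candidates: List[Tuple[float, int]] = []
        for c in self._closest_lists(query, self.nprobe):
            candidates.extend((sq_dist(query, key), value)
                              for key, value in self.lists[c])
        return sorted(candidates)[:k]
```

With n_lists lists of roughly equal size, a query scans about nprobe / n_lists of the keys; raising nprobe trades speed for recall.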
We can train an index by:
python --feature_dir saved_gen --output_dir datastore_1
The --feature_dir argument points to the features generated in Step 1, and --output_dir is the directory where the datastore will be saved.
3. Evaluation with datastore
This is the key step of the kNN-MT pipeline. Since we use the Hugging Face transformers library for the pretrained Transformer model, we keep using it here.
We modified the part of the library responsible for beam search.
To run inference and evaluation on the validation set:
export PAIR=de-en
export DATA_DIR=data/$PAIR
export SAVE_DIR=data/$PAIR
export BS=8
export NUM_BEAMS=15
python facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt \
--reference_path $DATA_DIR/ \
--score_path $SAVE_DIR/test_bleu.json \
--bs $BS \
--task translation \
--num_beams $NUM_BEAMS \
--datastore_path datastore_1 \
--lambda_value 0.8 \
--k 64
The --datastore_path parameter should point to the datastore you saved in Step 2, or to a pretrained index (in this repo, datastore_1/ contains the index).
The --lambda_value parameter sets the interpolation weight: the generation score is weighted by lambda and the kNN score by (1 - lambda).
The --k parameter sets the number of nearest neighbors retrieved in the kNN search.
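The k neighbors retrieved at each decoding step still have to be turned into a distribution over the vocabulary. The kNN-MT paper does this with a softmax over negative distances scaled by a temperature, summing the probability mass of neighbors that share the same target token. A minimal sketch of that formulation follows; the function name and the default temperature are placeholders, and this repo's knn_score may be computed differently.

```python
import math
from typing import Dict, List, Tuple


def knn_distribution(neighbors: List[Tuple[float, int]],
                     temperature: float = 10.0) -> Dict[int, float]:
    """Turn k retrieved (distance, token_id) pairs into p_kNN(token).

    Each neighbor gets weight exp(-distance / temperature); the weights
    are normalized over the k neighbors and summed per token id, so
    tokens that are not among the neighbors get probability 0.
    The default temperature is a placeholder, not a tuned value.
    """
    if not neighbors:
        return {}
    logits = [-d / temperature for d, _ in neighbors]
    m = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    probs: Dict[int, float] = {}
    for w, (_, token) in zip(weights, neighbors):
        probs[token] = probs.get(token, 0.0) + w / total
    return probs
```

A larger k spreads the mass over more neighbors; a lower temperature concentrates it on the closest ones.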
The final score for each token is:
$$\text{final\_score} = \lambda \cdot \text{score}_{\text{gen}} + (1 - \lambda) \cdot \text{score}_{\text{kNN}}$$
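A small worked example of this interpolation with lambda = 0.8, as in the command above. It assumes, as in the kNN-MT paper, that both scores are probability distributions over the same vocabulary; if the implementation interpolates log-scores instead, the numbers differ. The three-word vocabulary and its probabilities are made up for illustration.

```python
from typing import Dict


def interpolate(p_gen: Dict[str, float], p_knn: Dict[str, float],
                lam: float) -> Dict[str, float]:
    """final = lam * p_gen + (1 - lam) * p_knn, token by token."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lambda must be in [0, 1]")
    vocab = set(p_gen) | set(p_knn)
    return {w: lam * p_gen.get(w, 0.0) + (1 - lam) * p_knn.get(w, 0.0)
            for w in vocab}


# Toy distributions (hypothetical values).
p_gen = {"house": 0.6, "home": 0.3, "building": 0.1}
p_knn = {"house": 0.0, "home": 1.0, "building": 0.0}
final = interpolate(p_gen, p_knn, lam=0.8)
# approximately {house: 0.48, home: 0.44, building: 0.08}
```

With lambda = 0.8 the generator's top choice ("house") survives; with lambda = 0.5 the scores become 0.30, 0.65, and 0.05, so "home", favored by the retrieved neighbors, would win instead.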
Please note that you might not see a big difference in BLEU between different lambda values. Due to limited computational resources, we built the datastore from only a very small fraction of the training data (1%), which is too little for any improvement to show. With enough resources, you should be able to see the improvement.
For k==64, lambda_value==0.8, you should get: {'bleu': 40.9568, 'n_obs': 2000, 'runtime': 5892, 'seconds_per_sample': 2.946}
For the baseline without kNN search, remove the --datastore_path, --k, and --lambda_value arguments:
python facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt \
--reference_path $DATA_DIR/ \
--score_path $SAVE_DIR/test_bleu.json \
--bs $BS \
--task translation \
--num_beams $NUM_BEAMS
The baseline result should be: {'bleu': 41.3159, 'n_obs': 2000, 'runtime': 143, 'seconds_per_sample': 0.0715}
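The two reported runs can be compared directly. The snippet below only re-derives numbers from the results quoted above: seconds_per_sample is runtime divided by n_obs, and dividing the two runtimes gives the slowdown caused by the kNN search in this setup.

```python
# Results copied from the two runs reported above.
knn_run = {"bleu": 40.9568, "n_obs": 2000, "runtime": 5892}
baseline = {"bleu": 41.3159, "n_obs": 2000, "runtime": 143}

for name, run in (("kNN-MT", knn_run), ("baseline", baseline)):
    # seconds_per_sample in the logs is runtime / n_obs
    print(f"{name}: {run['runtime'] / run['n_obs']:.4f} s/sample")

slowdown = knn_run["runtime"] / baseline["runtime"]
bleu_delta = knn_run["bleu"] - baseline["bleu"]
print(f"slowdown: {slowdown:.1f}x, BLEU change: {bleu_delta:+.4f}")
```

In this 1% datastore setting, kNN search makes decoding roughly 41 times slower while BLEU drops by about 0.36, in line with the note above that the datastore is too small to help.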