
Commit

replaced lib transformer model with local implementation; only mkdir checkpoint folder for one of the devices for distributed training
Richard5678 committed Jan 6, 2025
1 parent 733c22a commit d7d8fbc
Showing 1 changed file with 193 additions and 171 deletions.
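
The second half of this change follows the standard pattern for distributed jobs: only rank 0 creates the timestamped checkpoint directory and every rank synchronizes on a barrier before using it. A minimal sketch of that pattern, assuming torch.distributed is already initialized (the helper name and the optional broadcast step are illustrative, not part of this commit):

import os
from datetime import datetime

import torch.distributed as dist


def make_checkpoint_dir():
    # Rank 0 creates the folder; every other rank waits at the barrier.
    checkpoint_dir = None
    if dist.get_rank() == 0:
        checkpoint_dir = f"checkpoints_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        os.makedirs(checkpoint_dir, exist_ok=True)
    dist.barrier()  # directory exists before any rank proceeds
    # Optionally share the path so every rank agrees on it (e.g. for resuming).
    holder = [checkpoint_dir]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]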
364 changes: 193 additions & 171 deletions examples/machine_translation_distributed.py
@@ -22,6 +22,8 @@
from torch.distributed.elastic.multiprocessing.errors import record
from torch.nn.parallel import DistributedDataParallel as DDP

from mytransformers.models import TransformerEncoderDecoder

import os
import warnings
import signal
@@ -33,63 +35,63 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class TransformerEncoderDecoder(nn.Module):
def __init__(
self,
vocab_size_en,
vocab_size_zh,
embed_dim,
num_heads,
num_layers,
dim_feedforward,
dropout,
):
super(TransformerEncoderDecoder, self).__init__()
self.embedding_en = nn.Embedding(vocab_size_en, embed_dim)
self.embedding_zh = nn.Embedding(vocab_size_zh, embed_dim)
self.transformer = nn.Transformer(
d_model=embed_dim,
nhead=num_heads,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation="relu",
batch_first=True,
)
self.fc_out = nn.Linear(embed_dim, vocab_size_zh)

def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
src_emb = self.embedding_en(src) * math.sqrt(self.transformer.d_model)
tgt_emb = self.embedding_zh(tgt) * math.sqrt(self.transformer.d_model)

# Ensure masks are of the same type
src_key_padding_mask = (
src_key_padding_mask.to(torch.bool)
if src_key_padding_mask is not None
else None
)
tgt_key_padding_mask = (
tgt_key_padding_mask.to(torch.bool)
if tgt_key_padding_mask is not None
else None
)

subsequent_mask = torch.triu(
torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool), diagonal=1
).to(device)

output = self.transformer(
src_emb,
tgt_emb,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
tgt_mask=subsequent_mask,
)
return self.fc_out(output)

def load_model(self, path):
self.load_state_dict(torch.load(path))
# class TransformerEncoderDecoder(nn.Module):
# def __init__(
# self,
# vocab_size_en,
# vocab_size_zh,
# embed_dim,
# num_heads,
# num_layers,
# dim_feedforward,
# dropout,
# ):
# super(TransformerEncoderDecoder, self).__init__()
# self.embedding_en = nn.Embedding(vocab_size_en, embed_dim)
# self.embedding_zh = nn.Embedding(vocab_size_zh, embed_dim)
# self.transformer = nn.Transformer(
# d_model=embed_dim,
# nhead=num_heads,
# num_encoder_layers=num_layers,
# num_decoder_layers=num_layers,
# dim_feedforward=dim_feedforward,
# dropout=dropout,
# activation="relu",
# batch_first=True,
# )
# self.fc_out = nn.Linear(embed_dim, vocab_size_zh)

# def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
# src_emb = self.embedding_en(src) * math.sqrt(self.transformer.d_model)
# tgt_emb = self.embedding_zh(tgt) * math.sqrt(self.transformer.d_model)

# # Ensure masks are of the same type
# src_key_padding_mask = (
# src_key_padding_mask.to(torch.bool)
# if src_key_padding_mask is not None
# else None
# )
# tgt_key_padding_mask = (
# tgt_key_padding_mask.to(torch.bool)
# if tgt_key_padding_mask is not None
# else None
# )

# subsequent_mask = torch.triu(
# torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool), diagonal=1
# ).to(device)

# output = self.transformer(
# src_emb,
# tgt_emb,
# src_key_padding_mask=src_key_padding_mask,
# tgt_key_padding_mask=tgt_key_padding_mask,
# tgt_mask=subsequent_mask,
# )
# return self.fc_out(output)

# def load_model(self, path):
# self.load_state_dict(torch.load(path))
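# (Note on the class above: the boolean causal mask built with torch.triu is
# equivalent to nn.Transformer.generate_square_subsequent_mask(tgt.size(1)), which
# returns an additive float mask with -inf above the diagonal, and nn.Transformer
# accepts either form. Also, nn.Transformer does not add positional encodings and
# none were added here; the replacement model's max_seq_length argument suggests
# it handles positions internally, though its implementation is not shown in this
# diff.)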


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -195,6 +197,8 @@ def train_model(model, train_loader, epochs=50, d_model=512, warmup_steps=4000,

epoch_loss += loss.item()
global_step += 1 # Increment the global step

# break

epoch_losses.append(epoch_loss / len(train_loader))
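
# The train_model signature above (d_model=512, warmup_steps=4000) matches the
# defaults of the inverse square root ("Noam") schedule from "Attention Is All
# You Need". Assuming the elided body uses that rule, the per-step learning rate
# would be computed along the lines of
#     lr = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
# i.e. a linear warmup for warmup_steps steps followed by 1/sqrt(step) decay.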

@@ -288,128 +292,146 @@ def signal_handler(sig, frame):
@record
def main():

setup_distributed()

# Create a timestamped directory for checkpoints
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
checkpoint_dir = f"checkpoints_{timestamp}"
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"Checkpoint directory created at: {checkpoint_dir}")

# Load data

# Setup distributed training
setup_distributed()
rank = dist.get_rank()
world_size = dist.get_world_size()

# Only the process with rank 0 creates the checkpoint directory
if rank == 0:
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
checkpoint_dir = f"checkpoints_{timestamp}"
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"Checkpoint directory created at: {checkpoint_dir}")
else:
checkpoint_dir = None # Other ranks do not create the directory

# Ensure all processes wait until the checkpoint directory is created
dist.barrier()

# Load data
cache_dir = "~/data"
os.environ["HF_DATASETS_CACHE"] = cache_dir
dataset = load_dataset(
"iwslt2017", "iwslt2017-en-zh", cache_dir=cache_dir, trust_remote_code=True
)
# dataset["train"] = dataset["train"].select(range(10))
# dataset["test"] = dataset["test"].select(range(10))
print(f"train dataset size: {len(dataset['train'])}")

# Tokenize data
tokenizer_en = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer_zh = AutoTokenizer.from_pretrained("bert-base-chinese")

# max_seq_length = get_max_seq_length(percentile=100, plot=True)
# print(f"max_seq_length: {max_seq_length}")
# return
max_seq_length = 104  # 99th percentile

# Tokenize data using batch processing with progress bar
train_encodings_en = tokenizer_en.batch_encode_plus(
[
item["translation"]["en"]
for item in tqdm(dataset["train"], desc="Tokenizing English Train Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
train_encodings_zh = tokenizer_zh.batch_encode_plus(
[
item["translation"]["zh"]
for item in tqdm(dataset["train"], desc="Tokenizing Chinese Train Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
test_encodings_en = tokenizer_en.batch_encode_plus(
[
item["translation"]["en"]
for item in tqdm(dataset["test"], desc="Tokenizing English Test Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
test_encodings_zh = tokenizer_zh.batch_encode_plus(
[
item["translation"]["zh"]
for item in tqdm(dataset["test"], desc="Tokenizing Chinese Test Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
print("tokenization done")

cache_dir = "~/data"
os.environ["HF_DATASETS_CACHE"] = cache_dir
dataset = load_dataset(
"iwslt2017", "iwslt2017-en-zh", cache_dir=cache_dir, trust_remote_code=True
)
# dataset["train"] = dataset["train"].select(range(10))
# dataset["test"] = dataset["test"].select(range(10))
print(f"train dataset size: {len(dataset['train'])}")
train_dataset = CustomDataset(train_encodings_en, train_encodings_zh)
test_dataset = CustomDataset(test_encodings_en, test_encodings_zh)

# Tokenize data
tokenizer_en = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer_zh = AutoTokenizer.from_pretrained("bert-base-chinese")
# Modify DataLoader to use DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(
train_dataset, batch_size=8, shuffle=False, num_workers=4, sampler=train_sampler
)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
test_loader = DataLoader(
test_dataset, batch_size=8, shuffle=False, num_workers=4, sampler=test_sampler
)
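# With DistributedSampler, shuffle=False on the DataLoader is intentional: the
# sampler both partitions the data across ranks and shuffles it. For a different
# shuffle order each epoch, the training loop should call
# train_sampler.set_epoch(epoch) at the start of every epoch; whether train_model
# does so is not visible in this diff.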

# max_seq_length = get_max_seq_length(percentile=100, plot=True)
# print(f"max_seq_length: {max_seq_length}")
# return
max_seq_length = 104  # 99th percentile

# Tokenize data using batch processing with progress bar
train_encodings_en = tokenizer_en.batch_encode_plus(
[
item["translation"]["en"]
for item in tqdm(dataset["train"], desc="Tokenizing English Train Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
train_encodings_zh = tokenizer_zh.batch_encode_plus(
[
item["translation"]["zh"]
for item in tqdm(dataset["train"], desc="Tokenizing Chinese Train Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
test_encodings_en = tokenizer_en.batch_encode_plus(
[
item["translation"]["en"]
for item in tqdm(dataset["test"], desc="Tokenizing English Test Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
test_encodings_zh = tokenizer_zh.batch_encode_plus(
[
item["translation"]["zh"]
for item in tqdm(dataset["test"], desc="Tokenizing Chinese Test Data")
],
padding="max_length",
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
)
print("tokenization done")
# Create model
# model = TransformerEncoderDecoder(
# vocab_size_en=tokenizer_en.vocab_size,
# vocab_size_zh=tokenizer_zh.vocab_size,
# embed_dim=768,
# num_heads=16,
# num_layers=3,
# dim_feedforward=768,
# dropout=0.1,
# ).to(device)
# model = TransformerEncoderDecoder(
# vocab_size_en=tokenizer_en.vocab_size,
# vocab_size_zh=tokenizer_zh.vocab_size,
# embed_dim=512,
# num_heads=8,
# num_layers=6,
# dim_feedforward=512,
# dropout=0.1,
# ).to(device)
# model.load_model('model_2024-12-19_07:41:27.pth') # 50 epochs


model = TransformerEncoderDecoder(
embed_dim=512,
num_heads=8,
vocab_size=tokenizer_en.vocab_size,
num_labels=tokenizer_zh.vocab_size,
max_seq_length=max_seq_length,
num_layers=6,
hidden_dim=512,
).to(device)

train_dataset = CustomDataset(train_encodings_en, train_encodings_zh)
test_dataset = CustomDataset(test_encodings_en, test_encodings_zh)
# Wrap the model with DistributedDataParallel
model = DDP(model, device_ids=[torch.cuda.current_device()])
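# DDP with device_ids=[torch.cuda.current_device()] assumes each process has
# already selected its own GPU, typically via torch.cuda.set_device(local_rank)
# inside setup_distributed(), whose body is not shown in this diff.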

# Modify DataLoader to use DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(
train_dataset, batch_size=8, shuffle=False, num_workers=4, sampler=train_sampler
)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
test_loader = DataLoader(
test_dataset, batch_size=8, shuffle=False, num_workers=4, sampler=test_sampler
)
# Train model
train_model(model, train_loader, start_epoch=0, checkpoint_dir=checkpoint_dir)

# Create model
# model = TransformerEncoderDecoder(
# vocab_size_en=tokenizer_en.vocab_size,
# vocab_size_zh=tokenizer_zh.vocab_size,
# embed_dim=768,
# num_heads=16,
# num_layers=3,
# dim_feedforward=768,
# dropout=0.1,
# ).to(device)
model = TransformerEncoderDecoder(
vocab_size_en=tokenizer_en.vocab_size,
vocab_size_zh=tokenizer_zh.vocab_size,
embed_dim=512,
num_heads=8,
num_layers=6,
dim_feedforward=512,
dropout=0.1,
).to(device)
model.load_model('model_2024-12-19_07:41:27.pth') # 50 epochs

# Wrap the model with DistributedDataParallel
model = DDP(model, device_ids=[torch.cuda.current_device()])

# Train model
train_model(model, train_loader, start_epoch=50, checkpoint_dir=checkpoint_dir)

# Save model
if dist.get_rank() == 0:
model_checkpoint_path = os.path.join(checkpoint_dir, f"model_{timestamp}.pth")
torch.save(model.module.state_dict(), model_checkpoint_path)
print(f"Final model saved at: {model_checkpoint_path}")
# Save model
if dist.get_rank() == 0:
model_checkpoint_path = os.path.join(checkpoint_dir, f"model_{timestamp}.pth")
torch.save(model.module.state_dict(), model_checkpoint_path)
print(f"Final model saved at: {model_checkpoint_path}")

# Evaluate model
evaluate_model(model, test_loader)
# Evaluate model
evaluate_model(model, test_loader)

cleanup_distributed()
cleanup_distributed()


def get_max_seq_length(percentile=95, plot=False):
