Skip to content

Commit

Permalink
Add safetensors support when reading checkpoints
Browse files — browse the repository at this point in the history
  • Loading branch information
honglu2875 committed Jan 26, 2024
1 parent 748e259 commit 7236f6e
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 9 deletions.
11 changes: 6 additions & 5 deletions aria/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def build_to_file(
def combine_datasets_from_file(cls, *args: str, output_path: str):
"""Function for concatenating jsonl files, checking for duplicates"""
logger = setup_logger()

for input_path in args:
assert os.path.isfile(input_path), f"{input_path} doesn't exist"

Expand All @@ -216,7 +216,6 @@ def combine_datasets_from_file(cls, *args: str, output_path: str):
f"{len(hashes)} unique midi_dicts and {dupe_cnt} duplicates so far"
)


logger.info(
f"Found {len(hashes)} unique midi_dicts and {dupe_cnt} duplicates"
)
Expand Down Expand Up @@ -603,6 +602,7 @@ def get_seqs(
if not any(proc.is_alive() for proc in workers):
break


def reservoir(_iterable: Iterable, k: int):
_reservoir = []
for entry in _iterable:
Expand All @@ -613,10 +613,11 @@ def reservoir(_iterable: Iterable, k: int):
random.shuffle(_reservoir)
yield from _reservoir
_reservoir = []

if _reservoir != []:
yield from _reservoir


class PretrainingDataset(TrainingDataset):
def __init__(self, dir_path: str, tokenizer: Tokenizer):
super().__init__(tokenizer=tokenizer)
Expand Down Expand Up @@ -734,7 +735,7 @@ def _build_epoch(_save_path, _midi_dataset):
while len(buffer) >= max_seq_len:
writer.write(buffer[:max_seq_len])
buffer = buffer[max_seq_len:]

_idx += 1
if _idx % 250 == 0:
logger.info(f"Finished processing {_idx}")
Expand Down Expand Up @@ -776,7 +777,7 @@ def _build_epoch(_save_path, _midi_dataset):
)
for idx in range(num_epochs):
logger.info(f"Building epoch {idx}/{num_epochs - 1}...")

# Reload the dataset on each iter
if midi_dataset_path:
midi_dataset = jsonlines.open(midi_dataset_path, "r")
Expand Down
7 changes: 4 additions & 3 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils for data/MIDI processing."""

import hashlib
import json
import re
Expand Down Expand Up @@ -318,9 +319,9 @@ def _extract_track_data(track: mido.MidiTrack):
if len(notes_to_close) > 0 and len(notes_to_keep) > 0:
# Note-on on the same tick but we already closed
# some previous notes -> it will continue, keep it.
last_note_on[
(message.note, message.channel)
] = notes_to_keep
last_note_on[(message.note, message.channel)] = (
notes_to_keep
)
else:
# Remove the last note on for this instrument
del last_note_on[(message.note, message.channel)]
Expand Down
1 change: 1 addition & 0 deletions aria/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Includes (PyTorch) transformer model and config classes."""

from dataclasses import dataclass
from typing import Optional, Union

Expand Down
11 changes: 10 additions & 1 deletion aria/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,16 @@ def sample(args):
)

ckpt_path = _get_ckpt_path(args.c) # let user input path if not provided
model_state = torch.load(ckpt_path, map_location=device)
if ckpt_path.endswith("safetensors"):
try:
from safetensors.torch import load_file
except ImportError as e:
raise ImportError(
f"Please install safetensors in order to read from the checkpoint: {ckpt_path}"
) from e
model_state = load_file(ckpt_path)
else:
model_state = torch.load(ckpt_path, map_location=device)
model_name = _get_model_name(
args.m, model_state
) # infer model name if not provided
Expand Down
1 change: 1 addition & 0 deletions aria/sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains generation/sampling code"""

# This file contains code from https://github.com/facebookresearch/llama which
# is available under the following license:

Expand Down

0 comments on commit 7236f6e

Please sign in to comment.