Skip to content

Commit

Permalink
Fixed a bug in save / load of a FireWord model.
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Mar 7, 2023
1 parent ef84a0c commit 7161799
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions firelang/models/_fireword.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def loss_skipgram(
return loss

@staticmethod
def from_pretrained(dirpath) -> FireWord:
def from_pretrained(dirpath, strict: bool = True) -> FireWord:
dirpath = os.path.abspath(dirpath)
if not os.path.exists(dirpath):
raise FileNotFoundError(f"Directory not found at {dirpath}")
Expand All @@ -263,7 +263,7 @@ def from_pretrained(dirpath) -> FireWord:
# state_dict
word = FireWord(config=config, vocab=vocab)
state_dict = torch.load(f"{dirpath}/pytorch_model.bin")
word.load_state_dict(state_dict)
word.load_state_dict(state_dict, strict=strict)
return word

def save(self, dirpath):
Expand All @@ -275,7 +275,7 @@ def save(self, dirpath):

# config
with open(f"{dirpath}/config.json", "wt") as f:
json.dump(self.config, f)
json.dump(self.config.__dict__, f)

# vocab
self.vocab.to_json(f"{dirpath}/vocab.json")
Expand Down

0 comments on commit 7161799

Please sign in to comment.