From ef84a0cf13d48406728f48ff8ef7a39d28ce4ec7 Mon Sep 17 00:00:00 2001 From: duxin Date: Thu, 8 Dec 2022 05:51:05 +0900 Subject: [PATCH] updated FireWordConfig --- firelang/models/_fireword.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/firelang/models/_fireword.py b/firelang/models/_fireword.py index 25219ab..97dfb88 100644 --- a/firelang/models/_fireword.py +++ b/firelang/models/_fireword.py @@ -70,14 +70,15 @@ def detect_device(self) -> torch.device: return next(iter(self.parameters())).device -class FireWordConfig(dict): +class FireWordConfig: dim: int func: str measure: str - def __init__(self, **kwargs): - dict.__init__(self, **kwargs) - self.__dict__ = self + def __init__(self, dim: int, func: str, measure: str): + self.dim = dim + self.func = func + self.measure = measure class FireWord(FireEmbedding):