diff --git a/firelang/models/_fireword.py b/firelang/models/_fireword.py index 25219ab..97dfb88 100644 --- a/firelang/models/_fireword.py +++ b/firelang/models/_fireword.py @@ -70,14 +70,15 @@ def detect_device(self) -> torch.device: return next(iter(self.parameters())).device -class FireWordConfig(dict): +class FireWordConfig: dim: int func: str measure: str - def __init__(self, **kwargs): - dict.__init__(self, **kwargs) - self.__dict__ = self + def __init__(self, dim: int, func: str, measure: str): + self.dim = dim + self.func = func + self.measure = measure class FireWord(FireEmbedding):