Skip to content

Commit

Permalink
[Minor] Remove ortho init for DAS
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jun 3, 2024
1 parent f92a379 commit 48024f8
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pyvene/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(self, n, init_orth=True):
# We don't need this init if the saved checkpoint already provides a good
# starting point.
# You can also study this if you want, but it is not our focus.
if init_orth:
torch.nn.init.orthogonal_(weight)
# if init_orth:
# torch.nn.init.orthogonal_(weight)
self.weight = torch.nn.Parameter(weight, requires_grad=True)

def forward(self, x):
Expand All @@ -37,7 +37,7 @@ def __init__(self, n, m):
super().__init__()
# n > m
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
torch.nn.init.orthogonal_(self.weight)
# torch.nn.init.orthogonal_(self.weight)

# Project x with the layer's weight matrix: x is first cast to the weight's
# dtype so matmul does not see mismatched dtypes, then right-multiplied by
# self.weight (allocated as an (n, m) parameter in __init__).
def forward(self, x):
return torch.matmul(x.to(self.weight.dtype), self.weight)
Expand All @@ -50,7 +50,7 @@ def __init__(self, n, m):
super().__init__()
# n > m
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
torch.nn.init.orthogonal_(self.weight)
# torch.nn.init.orthogonal_(self.weight)

# Project x using only a column slice of the weight matrix: x is cast to the
# weight's dtype, then right-multiplied by columns l (inclusive) through
# r (exclusive) of self.weight, i.e. self.weight[:, l:r].
def forward(self, x, l, r):
return torch.matmul(x.to(self.weight.dtype), self.weight[:, l:r])

0 comments on commit 48024f8

Please sign in to comment.