Skip to content

Commit

Permalink
[Minor] Keep existing behavior as the default
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jun 3, 2024
1 parent 48024f8 commit 0983f40
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pyvene/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(self, n, init_orth=True):
# we don't need init if the saved checkpoint has a nice
# starting point already.
# you can also study this if you want, but it is our focus.
# if init_orth:
# torch.nn.init.orthogonal_(weight)
if init_orth:
torch.nn.init.orthogonal_(weight)
self.weight = torch.nn.Parameter(weight, requires_grad=True)

def forward(self, x):
Expand All @@ -33,11 +33,12 @@ def forward(self, x):
class LowRankRotateLayer(torch.nn.Module):
"""A linear transformation with orthogonal initialization."""

def __init__(self, n, m):
def __init__(self, n, m, init_orth=True):
super().__init__()
# n > m
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
# torch.nn.init.orthogonal_(self.weight)
if init_orth:
torch.nn.init.orthogonal_(self.weight)

def forward(self, x):
return torch.matmul(x.to(self.weight.dtype), self.weight)
Expand All @@ -46,11 +47,12 @@ def forward(self, x):
class SubspaceLowRankRotateLayer(torch.nn.Module):
"""A linear transformation with orthogonal initialization with subspace."""

def __init__(self, n, m):
def __init__(self, n, m, init_orth=True):
super().__init__()
# n > m
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
# torch.nn.init.orthogonal_(self.weight)
if init_orth:
torch.nn.init.orthogonal_(self.weight)

def forward(self, x, l, r):
return torch.matmul(x.to(self.weight.dtype), self.weight[:, l:r])

0 comments on commit 0983f40

Please sign in to comment.