ReRank Transformer #22

Draft
wants to merge 27 commits into master
Changes from 1 commit
Commits (27)
829fee5
adding initial implementation of rhs transformer to repo
andyhuang-kumo Aug 8, 2024
12a83bd
adding rhs transformer to the benchmark script
andyhuang-kumo Aug 9, 2024
a3a9779
rhs transformer upload
andyhuang-kumo Aug 12, 2024
988781b
updating tr
andyhuang-kumo Aug 12, 2024
e5faa22
running code
andyhuang-kumo Aug 12, 2024
5e904ea
running code
andyhuang-kumo Aug 12, 2024
d6ff169
adding transformer changes
andyhuang-kumo Aug 13, 2024
4884b83
permute the index, the rhs and then reverse it
andyhuang-kumo Aug 13, 2024
08041cb
removing none to replace with None
andyhuang-kumo Aug 13, 2024
29c70ad
add time fuse encoder to extract time pe
andyhuang-kumo Aug 14, 2024
f58b528
update hyperparameter options
andyhuang-kumo Aug 14, 2024
8f4966c
adding rerank_transformer
andyhuang-kumo Aug 19, 2024
90a8817
setting zeros for not used logits in rerank transformer
andyhuang-kumo Aug 19, 2024
1eb769c
adding reranker transformer
andyhuang-kumo Aug 21, 2024
ea58dd0
Merge branch 'master' into rhs_tr
andyhuang-kumo Aug 21, 2024
75445ce
updating RHS transformer code
andyhuang-kumo Aug 21, 2024
7380a17
updating rerank_transformer
andyhuang-kumo Aug 23, 2024
7622b94
push current version
andyhuang-kumo Aug 23, 2024
dc73f7c
converging to the follow implementation
andyhuang-kumo Aug 23, 2024
38f9cf4
training both from epoch 0, potentially try better ways
andyhuang-kumo Aug 23, 2024
8e846dd
for transformer, not training with nodes whose prediction is not within
andyhuang-kumo Aug 25, 2024
41a38cf
commiting before clean up
andyhuang-kumo Aug 26, 2024
a181f45
semi working version of rerank transformer, still needs work for
andyhuang-kumo Aug 27, 2024
b86641f
somewhat stable version of the idea
andyhuang-kumo Aug 28, 2024
66db12f
adding test for max_map
andyhuang-kumo Aug 30, 2024
dfd559f
adding how we are testing the rerank upper bound
andyhuang-kumo Aug 30, 2024
342b90a
commit
andyhuang-kumo Aug 30, 2024
permute the index, the rhs and then reverse it
andyhuang-kumo committed Aug 13, 2024
commit 4884b837b081c0f0fd363a93c6528851949511fc
hybridgnn/nn/models/transformer.py (15 changes: 13 additions & 2 deletions)

@@ -63,7 +63,7 @@ def reset_parameters(self):
         self.fc.reset_parameters()


-    def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor:
+    def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor:
         r"""Returns the attended to rhs embeddings
         """
         rhs_embed = self.lin(rhs_embed)
@@ -73,16 +73,27 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor:
             torch.arange(rhs_embed.size(0), device=rhs_embed.device))

         # #! if we sort the index, we need to sort the rhs_embed
-        # sorted_index, _ = torch.sort(index)
+        sorted_index, sorted_idx = torch.sort(index, stable=True)
+        index = index[sorted_idx]
+        rhs_embed = rhs_embed[sorted_idx]
+        reverse = self.inverse_permutation(sorted_idx)
+        # assert torch.equal(index, sorted_index)

         x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size)
         for block in self.blocks:
             x = block(x, x)
         x = x[mask]
         x = x.view(-1, self.hidden_channels)
+        x = x[reverse]
+        # x = x.gather(1, sorted_idx.argsort(1))

         return self.fc(x)

+    def inverse_permutation(self,perm):
+        inv = torch.empty_like(perm)
+        inv[perm] = torch.arange(perm.size(0), device=perm.device)
+        return inv
+

 class RotaryPositionalEmbeddings(torch.nn.Module):
     def __init__(self, channels, base=10000):
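For reference, here is a minimal standalone sketch (not part of the PR) of the permute-then-reverse pattern this commit introduces. It assumes torch and torch_geometric are installed, and the rhs_embed/index/batch_size values are toy inputs I made up; the transformer blocks are replaced by an identity step, so the checks only verify that sorting the index, densifying with to_dense_batch, flattening, and applying the inverse permutation restores the original row order. The sort appears to be needed because to_dense_batch expects the batch vector to be sorted.

# Standalone sketch of the permute-then-reverse pattern above (not part of the PR).
# Assumes torch and torch_geometric; rhs_embed, index, and batch_size are toy values.
import torch
from torch_geometric.utils import to_dense_batch


def inverse_permutation(perm: torch.Tensor) -> torch.Tensor:
    # Same idea as the helper added in the diff: inv[perm[i]] = i.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv


rhs_embed = torch.arange(12.0).view(6, 2)   # [num_rhs, channels]
index = torch.tensor([2, 0, 1, 0, 2, 1])    # seed/group id per rhs row (unsorted)
batch_size = 3

# Sort rows by group id so to_dense_batch sees a sorted batch vector.
sorted_index, sorted_idx = torch.sort(index, stable=True)
reverse = inverse_permutation(sorted_idx)

x, mask = to_dense_batch(rhs_embed[sorted_idx], sorted_index,
                         batch_size=batch_size)  # [batch_size, max_group_size, channels]

# A transformer block would attend within each group here; identity for this sketch.
out = x[mask]          # back to [num_rhs, channels], still in sorted order
out = out[reverse]     # undo the sort

assert torch.equal(out, rhs_embed)                 # original row order restored
assert torch.equal(reverse, sorted_idx.argsort())  # matches the gather/argsort alternative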