Description
Hey there,
I attempted to re-run the new v3.0.x of IGFold with openmm on my system last night. After updating and upgrading the packages, I tried to run the notebook and hit the following error, thrown from `AntiBERTyRunner.py`:
File "/xxx/yyy/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
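For context, here is a minimal sketch that reproduces the same error for me; the shapes and names are stand-ins, not the real AntiBERTy variables:

```python
import torch

emb = torch.randn(9, 120, 512)                           # detached to the CPU
mask = torch.ones(120, dtype=torch.long, device="cuda")  # still on the GPU (assumes a CUDA machine)

# Indexing a CPU tensor with a GPU boolean mask raises the same error:
# emb[:, mask == 1]  # RuntimeError: indices should be either on cpu or
#                    # on the same device as the indexed tensor (cpu)

# Moving the mask onto the indexed tensor's device avoids it:
trimmed = emb[:, (mask == 1).cpu()]
print(trimmed.shape)  # torch.Size([9, 120, 512])
```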
To resolve this, I checked which devices the variables `embeddings` and `attention_mask` live on. Both are created on the GPU, but only `embeddings` is detached from the GPU to the CPU, so a CPU tensor ends up being indexed by a GPU mask. So, I made the following change:
- Detached both to the CPU, and converted both into lists.
```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
embeddings = embeddings.detach().cpu().tolist()
for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```
It threw the following error:
```
  File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
TypeError: list indices must be integers or slices, not tuple
```
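In hindsight this makes sense: `.tolist()` turns the tensors into nested Python lists, and plain lists do not accept NumPy-style tuple indices. A small illustration (values are arbitrary):

```python
import torch

t = torch.arange(6).reshape(2, 3)
keep = torch.tensor([True, False, True])

print(t[:, keep])      # tensors accept a (slice, mask) tuple index -> tensor([[0, 2], [3, 5]])
# t.tolist()[:, keep]  # TypeError: list indices must be integers or slices, not tuple
```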
To understand the core problem, I printed `embeddings` and `attention_mask` to inspect them:
```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")
embeddings = embeddings.detach().cpu().tolist()
for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```
Details
```
embeddings: tensor([[[[ 2.8550e-01, -9.1305e-01, -3.2084e-01,  ..., -2.6094e-01,
           -1.1316e-01,  6.2532e-02],
          ...
          [ 1.0556e+00,  9.7592e-01, -1.2148e+00,  ..., -4.7689e-02,
           -1.4709e-02,  2.9145e-01]]]], device='cuda:0') length:torch.Size([2, 9, 120, 512])
attention_mask: tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]]) length:torch.Size([2, 120])
```
(full numeric dump truncated: `embeddings` reports `device='cuda:0'` with size [2, 9, 120, 512], `attention_mask` prints on the CPU with size [2, 120], and the second row of the mask ends in eleven 0s, i.e. 109 real tokens)
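From these shapes, my reading (which may be wrong) is that the loop processes one sequence at a time and uses the mask to drop the padding positions along the length dimension. A sketch with random stand-in values:

```python
import torch

# Shapes as printed above: [batch=2, layers=9, padded_len=120, hidden=512]
emb = torch.randn(2, 9, 120, 512)
mask = torch.ones(2, 120, dtype=torch.long)
mask[1, 109:] = 0  # second sequence: 109 real tokens, 11 pads

# embeddings[i] is [9, 120, 512]; `[:, a == 1]` keeps only the masked
# positions along dim 1, so the padded sequence shrinks:
print(emb[0][:, mask[0] == 1].shape)  # torch.Size([9, 120, 512])
print(emb[1][:, mask[1] == 1].shape)  # torch.Size([9, 109, 512])
```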
Next, I kept both as tensors but moved both to the CPU:
```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")
embeddings = embeddings.detach().cpu()
for i, a in enumerate(attention_mask.detach().cpu()):
    embeddings[i] = embeddings[i][:, a == 1]
```
This got past the device error, but the in-place assignment then threw a tensor dimension mismatch error, which in hindsight is expected:
File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: The expanded size of the tensor (120) must match the existing size (109) at non-singleton dimension 1. Target sizes: [9, 120, 512]. Tensor sizes: [9, 109, 512]
This is because `embeddings` has size [2, 9, 120, 512] while `attention_mask` has size [2, 120]: masking the second sequence keeps only its 109 real positions, so the resulting [9, 109, 512] tensor can no longer be assigned back into its [9, 120, 512] slot.
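My current guess at a fix (untested, and possibly not what the authors intend) is a drop-in variant of the snippet above that moves the mask onto the embeddings' device and collects the ragged per-sequence results in a list instead of assigning them back in place:

```python
# gather embeddings (my guess at a fix, untested)
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
mask = attention_mask.to(embeddings.device)
# each trimmed sequence has a different length, so collect a list of
# [layers, real_len, hidden] tensors rather than writing back in place
embeddings = [e[:, m == 1] for e, m in zip(embeddings, mask)]
```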
Still, what is the end goal of this snippet, and why does it throw these errors? Please help me resolve this.