
issues in AntiBERTyRunner.py #1

Open
@Elmiar0642

Description

Hey there,

I attempted to re-run the new v3.0.x of IGFold with openmm on my system last night. After updating and upgrading the packages, I ran the notebook and hit the following error, thrown from AntiBERTyRunner.py:

File "/xxx/yyy/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
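
The mismatch is easy to reproduce in isolation (a minimal sketch assuming a CUDA device is available; `emb` and `a` are stand-ins for one row of `embeddings` and `attention_mask`):

```python
import torch

# one sequence's embeddings already moved to the CPU,
# while the attention mask is still on cuda:0
emb = torch.randn(9, 120, 512)
a = torch.ones(120, dtype=torch.long, device="cuda")

# indexing a CPU tensor with a CUDA boolean mask raises:
# RuntimeError: indices should be either on cpu or on the
# same device as the indexed tensor (cpu)
emb[:, a == 1]
```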

To resolve this, I checked which device the variables embeddings and attention_mask live on. Both are created on the GPU, but only embeddings is detached and moved to the CPU. So I made the following change:

  • Detached both to the CPU and converted both into Python lists with .tolist().

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

It threw the following error:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
TypeError: list indices must be integers or slices, not tuple
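
In hindsight this TypeError makes sense: after .tolist(), each embeddings[i] is a plain nested Python list, and a list only accepts an integer or a slice as an index; the tuple index `(:, a == 1)` is tensor/ndarray syntax. A tiny sketch:

```python
# plain Python lists reject the tuple-style indexing that tensors support
row = [[0.1, 0.2], [0.3, 0.4]]
row[0]        # fine: integer index
row[0:1]      # fine: slice
try:
    row[:, [True, False]]   # a tuple index, i.e. row[(slice(None), [True, False])]
except TypeError as e:
    print(e)  # list indices must be integers or slices, not tuple
```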

To understand the core problem, I inspected what embeddings and attention_mask actually contain:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

Details

```
embeddings: tensor([[[[ 2.8550e-01, -9.1305e-01, -3.2084e-01,  ..., -2.6094e-01,
           -1.1316e-01,  6.2532e-02],
          ...,
          [ 1.0556e+00,  9.7592e-01, -1.2148e+00,  ..., -4.7689e-02,
           -1.4709e-02,  2.9145e-01]]]], device='cuda:0') length:torch.Size([2, 9, 120, 512])
attention_mask: tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]]) length:torch.Size([2, 120])
```

The printout shows that embeddings lives on cuda:0 with shape [2, 9, 120, 512], and that the second row of attention_mask ends in eleven padding zeros (109 real positions out of 120). I then made this change, keeping both as CPU tensors:
```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu()

for i, a in enumerate(attention_mask.detach().cpu()):
    embeddings[i] = embeddings[i][:, a == 1]
```

This keeps both as tensors on the CPU, but the assignment back into embeddings[i] then throws a tensor dimension mismatch error:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: The expanded size of the tensor (120) must match the existing size (109) at non-singleton dimension 1.  Target sizes: [9, 120, 512].  Tensor sizes: [9, 109, 512]

This is because embeddings has size [2, 9, 120, 512] whereas attention_mask has size [2, 120]: masking trims the second sequence from 120 to 109 positions, and the resulting [9, 109, 512] tensor no longer fits back into the fixed [9, 120, 512] slot of the stacked tensor.
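
The mismatch can be reproduced in isolation (a minimal sketch; `emb` and `mask` are stand-ins for the real tensors, with eleven trailing zeros mirroring the printed mask):

```python
import torch

emb = torch.randn(2, 9, 120, 512)            # CPU tensor, as in the last attempt
mask = torch.ones(2, 120, dtype=torch.long)
mask[1, -11:] = 0                            # 109 real positions, 11 padding zeros

trimmed = emb[1][:, mask[1] == 1]
print(trimmed.shape)                         # torch.Size([9, 109, 512])

try:
    emb[1] = trimmed                         # [9, 109, 512] into a [9, 120, 512] slot
except RuntimeError as e:
    print(e)  # The expanded size of the tensor (120) must match the existing size (109) ...
```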

What is the end goal of this snippet? Why does it throw these errors? Please help me resolve this.
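
For what it's worth, the loop looks like it is meant to strip the padded positions from each sequence's embedding. Since the trimmed sequences then have different lengths, they cannot stay stacked in one tensor. Here is a sketch of what I imagine a fix might look like (names as in embed(); I may well be missing the intended behavior), keeping embeddings as a Python list of per-sequence tensors and moving the mask to the same device:

```python
# sketch of a possible fix, not the official one: replace the end of embed()
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
embeddings = list(embeddings.detach().cpu())    # list of per-sequence [9, 120, 512] tensors
attention_mask = attention_mask.detach().cpu()  # same device as embeddings

for i, a in enumerate(attention_mask):
    # keep only the non-padding positions; lengths may now differ per sequence
    embeddings[i] = embeddings[i][:, a == 1]    # e.g. [9, 120, 512] -> [9, 109, 512]
```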
