
Change so that deepspeed is not required to be installed when only performing inference on sparse transformer
wilson1yan committed Jun 10, 2021
1 parent 3ca4c06 commit 7d1f20e
Showing 1 changed file with 7 additions and 6 deletions.
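
The fix moves the deepspeed import out of SparseAttention's constructor and into get_ops(), which is now only invoked on the training path, so inference-only users can construct the module without deepspeed installed. A minimal sketch of the deferred-import pattern (illustrative only; OptionalOpModule is a hypothetical stand-in, not a class from videogpt):

    import torch.nn as nn

    class OptionalOpModule(nn.Module):
        def get_ops(self):
            # Import the optional dependency on first use instead of in
            # __init__, so merely constructing the module never requires it.
            try:
                from deepspeed.ops.sparse_attention import MatMul, Softmax
            except ImportError:
                raise Exception('Error importing deepspeed. Please install '
                                'using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`')
            return MatMul, Softmax

        def forward(self, x):
            if self.training:
                MatMul, Softmax = self.get_ops()  # only training touches deepspeed
            return x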
videogpt/attention.py (7 additions & 6 deletions)
@@ -146,10 +146,6 @@ def __init__(self, shape, dim_q, dim_kv, n_head, n_layer,
             assert not causal, 'causal axial attention is not supported'
             self.attn = AxialAttention(len(shape), **attn_kwargs)
         elif attn_type == 'sparse':
-            try:
-                from deepspeed.ops.sparse_attention import MatMul, Softmax
-            except:
-                raise Exception('Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`')
             self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs)
 
         self.cache = None
@@ -264,15 +260,17 @@ def __init__(self, shape, n_head, causal, num_local_blocks=4, block=32,
         self.sparsity_config = StridedSparsityConfig(shape=shape, n_head=n_head,
                                                      causal=causal, block=block,
                                                      num_local_blocks=num_local_blocks)
-        self.get_ops()
 
         if self.shape not in SparseAttention.block_layout:
             SparseAttention.block_layout[self.shape] = self.sparsity_config.make_layout()
         if causal and self.shape not in SparseAttention.attn_mask:
             SparseAttention.attn_mask[self.shape] = self.sparsity_config.make_sparse_attn_mask()
 
     def get_ops(self):
-        from deepspeed.ops.sparse_attention import MatMul, Softmax
+        try:
+            from deepspeed.ops.sparse_attention import MatMul, Softmax
+        except:
+            raise Exception('Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`')
         if self.shape not in SparseAttention.ops:
             sparsity_layout = self.sparsity_config.make_layout()
             sparse_dot_sdd_nt = MatMul(sparsity_layout,
@@ -295,6 +293,9 @@ def get_ops(self):
         return SparseAttention.ops[self.shape]
 
     def forward(self, q, k, v, decode_step, decode_idx):
+        if self.training and self.shape not in SparseAttention.ops:
+            self.get_ops()
+
         SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[self.shape].to(q)
         if self.causal:
             SparseAttention.attn_mask[self.shape] = SparseAttention.attn_mask[self.shape].to(q).type_as(q)
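
With this change, only the training path eagerly triggers the deepspeed import: forward calls get_ops() lazily when self.training is set, while eval-mode step decoding proceeds with the precomputed block layout and attention mask. A hypothetical usage sketch (tensor shapes and argument values are illustrative, based only on the constructor signature visible in the diff):

    attn = SparseAttention(shape=(4, 16, 16), n_head=8, causal=True)
    attn.eval()    # decode-path inference: deepspeed is never imported
    # out = attn(q, k, v, decode_step, decode_idx)

    attn.train()   # the first training forward calls get_ops(), which raises
                   # with install instructions if deepspeed is missing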