Skip to content

Commit

Permalink
Adds attention animation
Browse files Browse the repository at this point in the history
  • Loading branch information
andyljones committed Jan 10, 2021
1 parent 7fb2aa1 commit 8424ab6
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 34 deletions.
6 changes: 3 additions & 3 deletions experiments/attn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def forward(self, x, b):
attn = torch.softmax(dots, -2)
vals = torch.einsum('bph,bphd->bhd', attn, v)

return F.relu(self.final(vals.view(B, H*Dx))), attn
return F.relu(self.final(vals.view(B, H*Dx))), attn.detach()

class ReZeroAttn(nn.Module):

Expand All @@ -164,7 +164,7 @@ def forward(self, x, b):

class AttnModel(nn.Module):

def __init__(self, Head, boardsize, D, n_layers=8):
def __init__(self, Head, boardsize, D, n_layers=8, n_heads=1):
super().__init__()

pos = positions(boardsize)
Expand All @@ -176,7 +176,7 @@ def __init__(self, Head, boardsize, D, n_layers=8):
self.D = D
layers = []
for _ in range(n_layers):
layers.append(ReZeroAttn(D, D_prep))
layers.append(ReZeroAttn(D, D_prep, H=n_heads))
self.layers = nn.ModuleList(layers)

self.head = Head(D, pos.shape[-1])
Expand Down
46 changes: 36 additions & 10 deletions experiments/attn/victory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from torch.nn import functional as F
from boardlaw.main import worldfunc, agentfunc
from boardlaw.hex import Hex
from boardlaw.hex import plot_board
import matplotlib.pyplot as plt
from rebar.recording import Encoder


def gamegen(worlds):
while True:
Expand Down Expand Up @@ -45,10 +49,7 @@ def terminal_actions(worlds):
terminal[mask, a] = transitions.terminal
return terminal.float()

def plot(i, obs, targets, outputs, attns):
import matplotlib.pyplot as plt
from boardlaw.hex import plot_board

def plot_soln(i, obs, targets, outputs, attns=None):
boardsize = obs.shape[-2]

fig, axes = plt.subplots(2, 2)
Expand All @@ -68,10 +69,35 @@ def plot(i, obs, targets, outputs, attns):
ax.set_title('board')

ax = axes[1, 1]
attn = attns.detach().cpu().numpy()[i].max(0).max(-1)
attn = attn.reshape(boardsize, boardsize)
plot_board(plt.cm.viridis(attn/attn.max()), ax=ax)
ax.set_title('attn')
if attns is None:
ax.axis('off')
else:
attn = attns.detach().cpu().numpy()[i].max(0).max(-1)
attn = attn.reshape(boardsize, boardsize)
plot_board(plt.cm.viridis(attn/attn.max()), ax=ax)
ax.set_title('attn')

def animate(i, obs, attns):
    """Encode sample `i`'s attention weights as a video and show it in the notebook.

    `attns[i]` is read as a 3-D array. Its middle axis is folded into a
    `boardsize x boardsize` grid, where `boardsize` is the integer square root
    of that axis' length. The first axis then indexes video frames and the
    last axis indexes the panels within a frame.
    (Presumably frames are layers and panels are heads -- confirm against the
    model that produces `attns`.)

    Every frame has `cols + 1` panels: the board built from `obs[i]` on the
    left, followed by one viridis-coloured attention grid per column. After
    the final row has been drawn it is drawn three more times, so at 1 fps
    the video rests on it before ending.

    Args:
        i: index of the sample to draw.
        obs: batch of observations (tensor); `obs[i]` is handed to `color_obs`.
        attns: batch of attention weights (tensor), indexed by sample first.
    """
    board = obs[i].detach().cpu().numpy()
    weights = attns[i].detach().cpu().numpy()

    # Side length of the board, recovered from the flattened position axis.
    side = int(weights.shape[1]**.5)
    # Move the position axis last, then unflatten it into a 2-D grid.
    grids = weights.transpose(0, 2, 1).reshape(
        weights.shape[0], weights.shape[2], side, side)

    n_rows, n_cols = grids.shape[:2]

    # One frame per row, plus three extra frames clamped to the last row.
    frame_rows = [min(k, n_rows - 1) for k in range(n_rows + 3)]

    with Encoder(fps=1) as enc:
        for row in frame_rows:
            fig, axes = plt.subplots(1, n_cols + 1, squeeze=False)
            fig.set_size_inches(8*(n_cols + 1), 8)

            panels = axes[0]
            plot_board(color_obs(board), ax=panels[0])
            for col, ax in enumerate(panels[1:]):
                plot_board(plt.cm.viridis(grids[row, col]), ax=ax)

            enc(fig)
            # Close each figure once encoded so they don't accumulate.
            plt.close(fig)

    # The encoder's value is only set when the `with` block exits.
    enc.notebook()

def run(D=32, B=8*1024, T=5000, device='cuda'):

Expand All @@ -87,8 +113,8 @@ def run(D=32, B=8*1024, T=5000, device='cuda'):
targets = terminal_actions(b.worlds)
outputs = model(b.worlds.obs)

infs = torch.full_like(targets, np.inf)
loss = -outputs.reshape(B, -1).where(targets == 1., infs).min(-1).values.mean()
infs = torch.full_like(targets, -np.inf)
loss = -outputs.reshape(B, -1).where(targets == 1., infs).max(-1).values.mean()

opt.zero_grad()
loss.backward()
Expand Down
49 changes: 31 additions & 18 deletions main.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"outputs": [],
"source": [
"device = 'cuda'\n",
"D = 32\n",
"D = 64\n",
"B = 8*1024\n",
"T = 5000\n",
"T = 500\n",
"\n",
"worlds = Hex.initial(B, boardsize=7)\n",
"\n",
Expand All @@ -32,8 +32,8 @@
" targets = terminal_actions(b.worlds)\n",
" outputs = model(b.worlds.obs)\n",
"\n",
" infs = torch.full_like(targets, np.inf)\n",
" loss = -outputs.reshape(B, -1).where(targets == 1., infs).min(-1).values.mean()\n",
" infs = torch.full_like(targets, -np.inf)\n",
" loss = -outputs.reshape(B, -1).where(targets == 1., infs).max(-1).values.mean()\n",
"\n",
" opt.zero_grad()\n",
" loss.backward()\n",
Expand All @@ -53,7 +53,8 @@
"targets = terminal_actions(b.worlds)\n",
"outputs = model(b.worlds.obs)\n",
"\n",
"loss = -outputs.reshape(B, -1).mul(targets).sum(-1)\n",
"infs = torch.full_like(targets, -np.inf)\n",
"loss = -outputs.reshape(B, -1).where(targets == 1., infs).max(-1).values\n",
"\n",
"bad = loss.argsort().cpu().numpy()[::-1]"
]
Expand All @@ -64,7 +65,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot(bad[-1], b.worlds.obs, targets, outputs, model.attns)"
"plot(bad[10], b.worlds.obs, targets, outputs, model.attns)"
]
},
{
Expand All @@ -73,7 +74,30 @@
"metadata": {},
"outputs": [],
"source": [
"loss.detach().cpu().numpy()[bad]"
"from rebar.recording import Encoder\n",
"\n",
"i = 11\n",
"obs = b.worlds.obs[i].detach().cpu().numpy()\n",
"attn = model.attns[i].detach().cpu().numpy()\n",
"boardsize = int(attn.shape[1]**.5)\n",
"attn = attn.transpose(0, 2, 1).reshape(attn.shape[0], attn.shape[2], boardsize, boardsize)\n",
"\n",
"rows, cols = attn.shape[:2]\n",
"\n",
"\n",
"with Encoder(fps=1) as enc:\n",
" for r in range(rows+3):\n",
" r = min(r, rows-1)\n",
" fig, axes = plt.subplots(1, cols+1, squeeze=False)\n",
" fig.set_size_inches(8*(cols+1), 8)\n",
"\n",
" plot_board(color_obs(obs), ax=axes[0, 0])\n",
" for c in range(cols):\n",
" colors = plt.cm.viridis(attn[r, c])\n",
" plot_board(colors, ax=axes[0, c+1])\n",
" enc(fig)\n",
" plt.close(fig)\n",
"enc.notebook()"
]
}
],
Expand All @@ -82,17 +106,6 @@
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
Expand Down
9 changes: 6 additions & 3 deletions rebar/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __call__(self, arr):
if isinstance(arr, plt.Figure):
fig = arr
arr = array(fig)
fig.gcf()

if not self._initialized:
self._initialize(arr)
Expand All @@ -102,18 +101,22 @@ def __exit__(self, type, value, traceback):
self._container.mux(self._stream.encode())
self._container.close()
self.value = self._content.getvalue()

def notebook(self):
    """Display the encoded video inline in a notebook.

    Forwards `self.value` to the module-level `notebook` function (inside
    this method the bare name resolves to that function, not to the method
    itself). `self.value` is assigned in `__exit__`, so call this only after
    the `with` block has finished.
    """
    return notebook(self.value)


def html_tag(video, height=None, **kwargs):
video = video.value if isinstance(video, Encoder) else video
style = f'style="height: {height}px"' if height else ''
style = f'style="height: {height}px"' if height else 'style="height: 100%; width: 100%"'
b64 = base64.b64encode(video).decode('utf-8')
return f"""
<video controls autoplay loop {style}>
<source type="video/mp4" src="data:video/mp4;base64,{b64}">
Your browser does not support the video tag.
</video>"""

def notebook(video, height=640):
def notebook(video, height=None):
from IPython.display import display, HTML
return display(HTML(html_tag(video, height)))

Expand Down

0 comments on commit 8424ab6

Please sign in to comment.