Skip to content

Commit

Permalink
update doc for simpgcn
Browse files Browse the repository at this point in the history
  • Loading branch information
ChandlerBang committed Mar 11, 2021
1 parent 000f86d commit 877ef31
Show file tree
Hide file tree
Showing 15 changed files with 436 additions and 67 deletions.
81 changes: 50 additions & 31 deletions deeprobust/graph/defense/simpgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,60 @@


class SimPGCN(nn.Module):
"""SimP-GCN: Node similarity preserving graph convolutional networks.
https://arxiv.org/abs/2011.09643
Parameters
----------
nnodes : int
number of nodes in the input grpah
nfeat : int
size of input feature dimension
nhid : int
number of hidden units
nclass : int
size of output dimension
lambda_ : float
coefficients for SSL loss in SimP-GCN
gamma : float
coefficients for adaptive learnable self-loops
bias_init : float
bias init for the score
dropout : float
dropout rate for GCN
lr : float
learning rate for GCN
weight_decay : float
weight decay coefficient (l2 normalization) for GCN. When `with_relu` is True, `weight_decay` will be set to 0.
with_bias: bool
whether to include bias term in GCN weights.
device: str
'cpu' or 'cuda'.
Examples
--------
We can first load dataset and then train SimPGCN.
See the detailed hyper-parameter setting in https://github.com/ChandlerBang/SimP-GCN.
>>> from deeprobust.graph.data import PrePtbDataset, Dataset
>>> from deeprobust.graph.defense import SimPGCN
>>> # load clean graph data
>>> data = Dataset(root='/tmp/', name='cora', seed=15)
>>> adj, features, labels = data.adj, data.features, data.labels
>>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
>>> # load perturbed graph data
>>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora')
>>> perturbed_adj = perturbed_data.adj
>>> model = SimPGCN(nnodes=features.shape[0], nfeat=features.shape[1],
nhid=16, nclass=labels.max()+1, device='cuda')
>>> model = model.to('cuda')
>>> model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
>>> model.test(idx_test)
"""

def __init__(self, nnodes, nfeat, nhid, nclass, dropout=0.5, lr=0.01,
weight_decay=5e-4, lambda_=5, gamma=0.1, bias_init=0,
with_bias=True, device=None):
"""SimP-GCN: Node similarity preserving graph convolutional networks.
https://arxiv.org/abs/2011.09643
Parameters
----------
nnodes : int
number of nodes in the input grpah
nfeat : int
size of input feature dimension
nhid : int
number of hidden units
nclass : int
size of output dimension
lambda_ : float
coefficients for SSL loss in SimP-GCN
gamma : float
coefficients for adaptive learnable self-loops
bias_init : float
bias init for the score
dropout : float
dropout rate for GCN
lr : float
learning rate for GCN
weight_decay : float
weight decay coefficient (l2 normalization) for GCN. When `with_relu` is True, `weight_decay` will be set to 0.
with_bias: bool
whether to include bias term in GCN weights.
device: str
'cpu' or 'cuda'.
"""

super(SimPGCN, self).__init__()

assert device is not None, "Please specify 'device'!"
Expand Down
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/_build/doctrees/source/deeprobust.graph.defense.doctree
Binary file not shown.
2 changes: 2 additions & 0 deletions docs/_build/html/_modules/deeprobust/graph/defense/gcn.html
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ <h1>Source code for deeprobust.graph.defense.gcn</h1><div class="highlight"><pre
<span class="bp">self</span><span class="o">.</span><span class="n">adj_norm</span> <span class="o">=</span> <span class="n">utils</span><span class="o">.</span><span class="n">normalize_adj_tensor</span><span class="p">(</span><span class="n">adj</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">adj_norm</span><span class="p">)</span></div></div>



</pre></div>

</div>
Expand Down

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions docs/_build/html/_modules/deeprobust/graph/defense/r_gcn.html
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,17 @@ <h1>Source code for deeprobust.graph.defense.r_gcn</h1><div class="highlight"><p
<span class="s2">&quot;accuracy= </span><span class="si">{:.4f}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">acc_test</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span>
<span class="k">return</span> <span class="n">acc_test</span><span class="o">.</span><span class="n">item</span><span class="p">()</span></div>

<div class="viewcode-block" id="RGCN.predict"><a class="viewcode-back" href="../../../../source/deeprobust.graph.defense.html#deeprobust.graph.defense.r_gcn.RGCN.predict">[docs]</a> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> torch.FloatTensor</span>
<span class="sd"> output (log probabilities) of RGCN</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="bp">self</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">forward</span><span class="p">()</span></div>

<span class="k">def</span> <span class="nf">_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="n">miu1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gc1</span><span class="o">.</span><span class="n">miu</span>
Expand All @@ -490,6 +501,30 @@ <h1>Source code for deeprobust.graph.defense.r_gcn</h1><div class="highlight"><p
<span class="n">D_power</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diag</span><span class="p">(</span><span class="n">D_power</span><span class="p">)</span>
<span class="k">return</span> <span class="n">D_power</span> <span class="o">@</span> <span class="n">A</span> <span class="o">@</span> <span class="n">D_power</span></div>

<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>

<span class="kn">from</span> <span class="nn">deeprobust.graph.data</span> <span class="kn">import</span> <span class="n">PrePtbDataset</span><span class="p">,</span> <span class="n">Dataset</span>
<span class="c1"># load clean graph data</span>
<span class="n">dataset_str</span> <span class="o">=</span> <span class="s1">&#39;pubmed&#39;</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;/tmp/&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">dataset_str</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">15</span><span class="p">)</span>
<span class="n">adj</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">adj</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">labels</span>
<span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span> <span class="n">idx_test</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_train</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_val</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">idx_test</span>
<span class="c1"># load perturbed graph data</span>
<span class="n">perturbed_data</span> <span class="o">=</span> <span class="n">PrePtbDataset</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s1">&#39;/tmp/&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">dataset_str</span><span class="p">)</span>
<span class="n">perturbed_adj</span> <span class="o">=</span> <span class="n">perturbed_data</span><span class="o">.</span><span class="n">adj</span>

<span class="c1"># train defense model</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">RGCN</span><span class="p">(</span><span class="n">nnodes</span><span class="o">=</span><span class="n">perturbed_adj</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">nfeat</span><span class="o">=</span><span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
<span class="n">nclass</span><span class="o">=</span><span class="n">labels</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">nhid</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">perturbed_adj</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">idx_train</span><span class="p">,</span> <span class="n">idx_val</span><span class="p">,</span>
<span class="n">train_iters</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">test</span><span class="p">(</span><span class="n">idx_test</span><span class="p">)</span>

<span class="n">prediction_1</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">prediction_1</span><span class="p">)</span>
<span class="c1"># prediction_2 = model.predict(features, perturbed_adj)</span>
<span class="c1"># assert (prediction_1 != prediction_2).sum() == 0</span>

</pre></div>

</div>
Expand Down
6 changes: 2 additions & 4 deletions docs/_build/html/_modules/deeprobust/image/attack/pgd.html
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,8 @@ <h1>Source code for deeprobust.image.attack.pgd</h1><div class="highlight"><pre>
<span class="n">X_pgd</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">data</span> <span class="o">+</span> <span class="n">eta</span>

<span class="c1">#X_pgd = (torch.clamp(X_pgd * std + mean, clip_min, clip_max) - mean) / std</span>

<span class="n">X_pgd</span><span class="p">[:,</span><span class="mi">0</span><span class="p">,:,:]</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">X_pgd</span><span class="p">[:,</span><span class="mi">0</span><span class="p">,:,:]</span> <span class="o">*</span> <span class="n">std</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">mean</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">clip_min</span><span class="p">,</span> <span class="n">clip_max</span><span class="p">)</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">/</span> <span class="n">std</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">X_pgd</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,:,:]</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">X_pgd</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,:,:]</span> <span class="o">*</span> <span class="n">std</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">mean</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">clip_min</span><span class="p">,</span> <span class="n">clip_max</span><span class="p">)</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="o">/</span> <span class="n">std</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">X_pgd</span><span class="p">[:,</span><span class="mi">2</span><span class="p">,:,:]</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">X_pgd</span><span class="p">[:,</span><span class="mi">2</span><span class="p">,:,:]</span> <span class="o">*</span> <span class="n">std</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">mean</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">clip_min</span><span class="p">,</span> <span class="n">clip_max</span><span class="p">)</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="o">/</span> <span class="n">std</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="k">for</span> <span class="n">ind</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">X_pgd</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="n">X_pgd</span><span class="p">[:,</span><span class="n">ind</span><span class="p">,:,:]</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">X_pgd</span><span class="p">[:,</span><span class="n">ind</span><span class="p">,:,:]</span> <span class="o">*</span> <span class="n">std</span><span class="p">[</span><span class="n">ind</span><span class="p">]</span> <span class="o">+</span> <span class="n">mean</span><span class="p">[</span><span class="n">ind</span><span class="p">],</span> <span class="n">clip_min</span><span class="p">,</span> <span class="n">clip_max</span><span class="p">)</span> <span class="o">-</span> <span class="n">mean</span><span class="p">[</span><span class="n">ind</span><span class="p">])</span> <span class="o">/</span> <span class="n">std</span><span class="p">[</span><span class="n">ind</span><span class="p">]</span>

<span class="n">X_pgd</span> <span class="o">=</span> <span class="n">X_pgd</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">X_pgd</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span>
Expand Down
Loading

0 comments on commit 877ef31

Please sign in to comment.