diff --git a/src/model/fearec.py b/src/model/fearec.py index ca3ebde..e4a724b 100644 --- a/src/model/fearec.py +++ b/src/model/fearec.py @@ -170,7 +170,7 @@ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False class FEARecBlock(nn.Module): def __init__(self, args, layer_num): super(FEARecBlock, self).__init__() - self.layer = FEARecLayer(args) + self.layer = FEARecLayer(args, layer_num) self.feed_forward = FeedForward(args) def forward(self, hidden_states, attention_mask):