switch_base.gin

# Switch Transformer Base model.
#
# Based on the original Switch Transformer (https://arxiv.org/abs/2101.03961).
#
# Note that unlike the original Switch Transformer, this T5X version does not
# use any jitter noise in the router.
#
# Provides MODEL and NUM_EXPERTS.
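#
# Illustrative usage (a sketch, not part of this file: the run config name
# and model dir are assumptions):
#
#   python -m t5x.train \
#     --gin_file=flaxformer/t5x/configs/moe/models/switch_base.gin \
#     --gin_file=your_run_config.gin \
#     --gin.MODEL_DIR="'/tmp/switch_base'"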

from __gin__ import dynamic_registration

from flaxformer.architectures.moe import moe_architecture
from flaxformer.architectures.moe import moe_layers
from flaxformer.architectures.moe import routing
from flaxformer.components import dense
import seqio
from t5x import adafactor

ARCHITECTURE = %gin.REQUIRED

include 'flaxformer/t5x/configs/moe/models/tokens_choose_base.gin'
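
# The include above pulls in the shared token-choice MoE defaults; the
# bindings below override values defined there, since later gin bindings
# take precedence.
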
# Architecture overrides
MLP_DIM = 3072

# MoE overrides
NUM_EXPERTS = 128
# Replace every other MLP sublayer with an MoE sublayer.
NUM_ENCODER_SPARSE_LAYERS = 6
NUM_DECODER_SPARSE_LAYERS = 6
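# (T5 Base has 12 encoder and 12 decoder layers, so 6 sparse layers per
# stack alternate dense and MoE MLP sublayers.)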
TRAIN_EXPERT_CAPACITY_FACTOR = 1.25
EVAL_EXPERT_CAPACITY_FACTOR = 2.
NUM_SELECTED_EXPERTS = 1 # Switch routing
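# The load-balancing auxiliary loss from the Switch Transformer paper is
# weighted by AUX_LOSS_FACTOR; the router z-loss (introduced in ST-MoE,
# https://arxiv.org/abs/2202.08906) is disabled here.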
AUX_LOSS_FACTOR = 0.01
ROUTER_Z_LOSS_FACTOR = 0.0
GROUP_SIZE = 8192
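
# Rough expert-capacity arithmetic for these settings (exact rounding is up
# to the implementation): each expert processes about
#   GROUP_SIZE * CAPACITY_FACTOR / NUM_EXPERTS
# tokens per group, i.e. 8192 * 1.25 / 128 = 80 at train time and
# 8192 * 2 / 128 = 128 at eval time. Tokens routed beyond an expert's
# capacity are dropped and pass through via the residual connection.
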
# Switch Transformer Base uses relu activations.
dense.MlpBlock.activations = ('relu',)
expert/dense.MlpBlock.activations = ('relu',)

# Switch Transformer Base re-uses the token embedder to compute output logits.
moe_architecture.SparseDecoder.output_logits_factory = None

# Switch Transformer doesn't use BPR in encoder (although most sparse encoders
# generally see a boost from it).
sparse_encoder/routing.TokensChooseMaskedRouter.batch_prioritized_routing = False
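# (Batch prioritized routing fills expert buffers in order of router
# probability, so overflow drops the lowest-confidence tokens first; see
# https://arxiv.org/abs/2106.05974.)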