diff --git a/main.py b/main.py
old mode 100755
new mode 100644
diff --git a/src/bert_layers/configuration_bert.py b/src/bert_layers/configuration_bert.py
index 6fbdeb53..b9f1e6ce 100644
--- a/src/bert_layers/configuration_bert.py
+++ b/src/bert_layers/configuration_bert.py
@@ -58,7 +58,7 @@ def __init__(
         loss_kwargs: dict = {},
         mlp_dropout_prob: float = 0.0,
         mlp_in_bias: bool = False,
-        mlp_layer: str = "mlp",
+        mlp_layer: str = "glu_moe",
         mlp_out_bias: bool = False,
         norm_kwargs: dict = {},
         normalization: str = "rmsnorm",
@@ -97,6 +97,13 @@ def __init__(
         pad_logits: bool = False,
         compile_model: bool = False,
         masked_prediction: bool = False,
+        moe_num_experts: int = 8,
+        moe_top_k: int = 2,
+        moe_use_noisy_top_k: bool = True,
+        moe_capacity_factor: float = 1.25,
+        moe_compute_aux_loss: bool = True,
+        moe_load_balance_loss_weight: float = 0.01,
+        moe_router_z_loss_weight: float = 0.001,
         **kwargs,
     ):
         """
@@ -156,6 +163,13 @@ def __init__(
             pad_logits (bool): Pad logits after calculating the loss.
             compile_model (bool): Compile the subset of the model which can be compiled.
             masked_prediction (bool): Only pass the masked tokens through the final MLM layers.
+            moe_num_experts (int): Number of experts for Mixture of Experts layers.
+            moe_top_k (int): Number of top experts to select for each token in MoE.
+            moe_use_noisy_top_k (bool): Use noisy top-k gating for exploration during training.
+            moe_capacity_factor (float): Capacity factor for expert assignment in MoE.
+            moe_compute_aux_loss (bool): Whether to compute and add auxiliary losses for MoE.
+            moe_load_balance_loss_weight (float): Weight for the load balancing auxiliary loss.
+            moe_router_z_loss_weight (float): Weight for the router z-loss auxiliary loss.
             **kwargs: Additional keyword arguments.
         """
         super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
@@ -213,6 +227,13 @@ def __init__(
         self.pad_logits = pad_logits
         self.compile_model = compile_model
         self.masked_prediction = masked_prediction
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.moe_use_noisy_top_k = moe_use_noisy_top_k
+        self.moe_capacity_factor = moe_capacity_factor
+        self.moe_compute_aux_loss = moe_compute_aux_loss
+        self.moe_load_balance_loss_weight = moe_load_balance_loss_weight
+        self.moe_router_z_loss_weight = moe_router_z_loss_weight
 
         if loss_kwargs.get("return_z_loss", False):
             if loss_function != "fa_cross_entropy":
diff --git a/src/bert_layers/loss.py b/src/bert_layers/loss.py
index 8b0007cf..3f210a4c 100644
--- a/src/bert_layers/loss.py
+++ b/src/bert_layers/loss.py
@@ -2,7 +2,9 @@
 # License: Apache-2.0
 
 import inspect
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from .configuration_bert import FlexBertConfig
 
 try:
@@ -20,6 +22,79 @@
     LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss
 
 
+class MoELoadBalancingLoss(nn.Module):
+    """Computes the Switch Transformer auxiliary loss for load balancing.
+
+    Reference: https://arxiv.org/abs/2101.03961 (equations 4-6, page 7)
+
+    This loss encourages balanced token allocation across experts, avoiding
+    scenarios where some experts are overloaded while others are underutilized.
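+
+    Worked example (illustrative): under perfectly uniform routing, each expert
+    receives a token fraction of ``top_k / num_experts`` and a mean router
+    probability of ``1 / num_experts``, so the loss evaluates to
+    ``num_experts * num_experts * (top_k / num_experts) * (1 / num_experts) = top_k``,
+    its minimum (2.0 with the default ``top_k = 2``); any imbalance raises it.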
+ """ + + def __init__(self, num_experts: int, top_k: int = 2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + + def forward( + self, + router_logits: torch.Tensor, + expert_indices: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + router_logits: Router logits [batch_size, seq_len, num_experts] + expert_indices: Top-k expert indices [batch_size, seq_len, top_k] + + Returns: + load_balance_loss: Scalar loss value + """ + # Compute expert probabilities + expert_probs = F.softmax(router_logits, dim=-1) # [B, C, n_exp] + + # Equation (5): compute ratio of tokens allocated to each expert + with torch.no_grad(): + one_hot_indices = F.one_hot(expert_indices, num_classes=self.num_experts) # [B, C, K, n_exp] + one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, C, n_exp] + tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1)) # [n_exp] + + # Equation (6): compute ratio of router probability allocated to each expert + prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1)) # [n_exp] + + # Equation (4): scaled dot product between prob / token allocation vectors + load_balance_loss = self.num_experts * torch.sum(prob_per_expert * tokens_per_expert) + + return load_balance_loss + + +class MoERouterZLoss(nn.Module): + """Computes router z-loss for MoE models. + + Reference: https://arxiv.org/abs/2202.08906 (equation 5, page 7) + + This loss constrains the size of router logits to avoid numerical instability + during training. Large logits can lead to round-off errors in the softmax computation, + even in float32 precision. + """ + + def forward(self, router_logits: torch.Tensor) -> torch.Tensor: + """ + Args: + router_logits: Router logits [batch_size, seq_len, num_experts] + + Returns: + router_z_loss: Scalar loss value + """ + # Numerically stable computation: logsumexp is equivalent to log(sum(exp(x))) + # This avoids overflow issues from directly exponentiating large logits + router_z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0 # [B, C] + + # Average over all tokens + router_z_loss = torch.mean(router_z_loss) + + return router_z_loss + + def get_loss_fn(config: FlexBertConfig) -> nn.Module: try: loss_class = LOSS2CLS[config.loss_function] diff --git a/src/bert_layers/mlp.py b/src/bert_layers/mlp.py index 349d559b..f344008d 100644 --- a/src/bert_layers/mlp.py +++ b/src/bert_layers/mlp.py @@ -15,11 +15,13 @@ import torch import torch.nn as nn +import torch.nn.functional as F from .configuration_bert import FlexBertConfig from .activation import get_act_fn from .normalization import get_norm_layer from .initialization import ModuleType, init_weights +from .loss import MoELoadBalancingLoss, MoERouterZLoss class BertResidualGLU(nn.Module): @@ -190,10 +192,197 @@ def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor: return self.Wo(self.drop(self.act(input) * gate)) +class Router(nn.Module): + """Top-K router for selecting experts.""" + + def __init__( + self, + d: int, + n_exp: int, + top_k: int = 2, + use_noisy_top_k: bool = True, + capacity_factor: float = 1.25, + ): + super().__init__() + self.d = d + self.n_exp = n_exp + self.top_k = top_k + self.use_noisy_top_k = use_noisy_top_k + self.capacity_factor = capacity_factor + + # Router weights to compute logits for each expert + self.gate = nn.Linear(d, n_exp, bias=False) + + # Noise parameters for noisy top-k gating + if use_noisy_top_k: + self.w_noise = nn.Linear(d, n_exp, bias=False) + + def forward(self, x: torch.Tensor): + """ + Args: + x: [batch_size, seq_len, 
+    """
+
+    def forward(self, router_logits: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            router_logits: Router logits [batch_size, seq_len, num_experts]
+
+        Returns:
+            router_z_loss: Scalar loss value
+        """
+        # Numerically stable computation: logsumexp is equivalent to log(sum(exp(x)))
+        # but avoids overflow from directly exponentiating large logits
+        router_z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0  # [B, C]
+
+        # Average over all tokens
+        router_z_loss = torch.mean(router_z_loss)
+
+        return router_z_loss
+
+
 def get_loss_fn(config: FlexBertConfig) -> nn.Module:
     try:
         loss_class = LOSS2CLS[config.loss_function]
diff --git a/src/bert_layers/mlp.py b/src/bert_layers/mlp.py
index 349d559b..f344008d 100644
--- a/src/bert_layers/mlp.py
+++ b/src/bert_layers/mlp.py
@@ -15,11 +15,13 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from .configuration_bert import FlexBertConfig
 from .activation import get_act_fn
 from .normalization import get_norm_layer
 from .initialization import ModuleType, init_weights
+from .loss import MoELoadBalancingLoss, MoERouterZLoss
 
 
 class BertResidualGLU(nn.Module):
@@ -190,10 +192,197 @@ def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
         return self.Wo(self.drop(self.act(input) * gate))
 
 
+class Router(nn.Module):
+    """Top-k router for selecting experts."""
+
+    def __init__(
+        self,
+        d: int,
+        n_exp: int,
+        top_k: int = 2,
+        use_noisy_top_k: bool = True,
+        capacity_factor: float = 1.25,
+    ):
+        super().__init__()
+        self.d = d
+        self.n_exp = n_exp
+        self.top_k = top_k
+        self.use_noisy_top_k = use_noisy_top_k
+        self.capacity_factor = capacity_factor
+
+        # Router weights to compute logits for each expert
+        self.gate = nn.Linear(d, n_exp, bias=False)
+
+        # Noise parameters for noisy top-k gating
+        if use_noisy_top_k:
+            self.w_noise = nn.Linear(d, n_exp, bias=False)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x: Input hidden states [batch_size, seq_len, d]
+
+        Returns:
+            top_k_gates: Expert weights [batch_size * seq_len, top_k]
+            exp_mask: Gated expert assignment mask [batch_size * seq_len, n_exp, exp_capacity]
+            exp_batches: Per-expert token batches [n_exp, exp_capacity, d]
+        """
+        B, C, d = x.size()
+        num_tokens = B * C
+        x_flat = x.view(num_tokens, d)
+
+        # Compute router logits
+        logits = self.gate(x_flat)  # [num_tokens, n_exp]
+
+        # Add noise for exploration (optional, training only)
+        if self.use_noisy_top_k and self.training:
+            noise_stddev = F.softplus(self.w_noise(x_flat))
+            noise = torch.randn_like(logits) * noise_stddev
+            logits = logits + noise
+
+        # Select top-k experts per token
+        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
+        top_k_gates = F.softmax(top_k_logits, dim=-1)  # [num_tokens, top_k]
+
+        # Compute expert capacity
+        exp_capacity = int((num_tokens * self.top_k * self.capacity_factor) / self.n_exp)
+
+        # Create the expert assignment mask and per-expert token batches
+        exp_mask = torch.zeros(num_tokens, self.n_exp, exp_capacity, device=x.device)
+        exp_batches = torch.zeros(self.n_exp, exp_capacity, d, device=x.device)
+
+        # Count tokens assigned to each expert
+        expert_counts = torch.zeros(self.n_exp, dtype=torch.long, device=x.device)
+
+        # Assign tokens to experts (simple reference implementation; this Python
+        # loop is O(num_tokens * top_k) and is the main routing bottleneck)
+        for token_idx in range(num_tokens):
+            for k_idx in range(self.top_k):
+                expert_idx = top_k_indices[token_idx, k_idx]
+                if expert_counts[expert_idx] < exp_capacity:
+                    pos = expert_counts[expert_idx]
+                    exp_mask[token_idx, expert_idx, pos] = top_k_gates[token_idx, k_idx]
+                    exp_batches[expert_idx, pos] = x_flat[token_idx]
+                    expert_counts[expert_idx] += 1
+
+        return top_k_gates, exp_mask, exp_batches
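+
+
+# Capacity arithmetic (illustrative numbers, not from the config): routing
+# B * C = 4096 tokens with top_k = 2, capacity_factor = 1.25 and n_exp = 8 gives
+# exp_capacity = int(4096 * 2 * 1.25 / 8) = 1280 slots per expert. Tokens that
+# select an expert whose slots are already full are silently dropped for that
+# expert and contribute nothing to its output.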
+
+
+class FlexBertGLUMoE(FlexBertMLPBase):
+    """Mixture of Experts layer with GLU experts for FlexBERT."""
+
+    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
+        super().__init__(config=config, layer_id=layer_id)
+
+        self.n_exp = getattr(config, "moe_num_experts", 4)
+        self.top_k = getattr(config, "moe_top_k", 2)
+        self.use_noisy_top_k = getattr(config, "moe_use_noisy_top_k", True)
+        self.capacity_factor = getattr(config, "moe_capacity_factor", 1.25)
+        self.compute_aux_loss = getattr(config, "moe_compute_aux_loss", True)
+        self.load_balance_loss_weight = getattr(config, "moe_load_balance_loss_weight", 0.01)
+        self.router_z_loss_weight = getattr(config, "moe_router_z_loss_weight", 0.001)
+
+        self.router = Router(
+            d=config.hidden_size,
+            n_exp=self.n_exp,
+            top_k=self.top_k,
+            use_noisy_top_k=self.use_noisy_top_k,
+            capacity_factor=self.capacity_factor,
+        )
+
+        # GLU experts (each input projection maps to 2x intermediate size)
+        self.experts = nn.ModuleList([
+            nn.Sequential(
+                nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_in_bias),
+                nn.Identity(),  # Placeholder for chunking + activation (applied in forward)
+                nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity(),
+                nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias),
+            )
+            for _ in range(self.n_exp)
+        ])
+        self.act = get_act_fn(config.hidden_act)
+
+        # Initialize auxiliary loss modules
+        if self.compute_aux_loss:
+            self.load_balance_loss = MoELoadBalancingLoss(num_experts=self.n_exp, top_k=self.top_k)
+            self.router_z_loss = MoERouterZLoss()
+
+    def _init_weights(self, reset_params: bool = False):
+        init_weights(
+            self.config,
+            self.router.gate,
+            layer_dim=self.config.hidden_size,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.in_module,
+        )
+
+        for expert in self.experts:
+            for i, module in enumerate(expert):
+                if isinstance(module, nn.Linear):
+                    init_weights(
+                        self.config,
+                        module,
+                        layer_dim=self.config.hidden_size if i == 0 else self.config.intermediate_size,
+                        layer_id=self.layer_id,
+                        type_of_module=ModuleType.in_module if i == 0 else ModuleType.out_module,
+                    )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        original_shape = hidden_states.shape
+        if hidden_states.dim() == 2:
+            hidden_states = hidden_states.unsqueeze(0)
+
+        B, C, d = hidden_states.size()
+        num_tokens = B * C
+        x_flat = hidden_states.view(num_tokens, d)
+
+        # Noise-free router logits, kept for the auxiliary loss computation
+        router_logits = self.router.gate(x_flat)  # [num_tokens, n_exp]
+
+        # The gates are already baked into exp_mask, so the first return value is unused here
+        _, exp_mask, exp_batches = self.router(hidden_states)
+
+        # Top-k indices for the load balancing loss. Note: these are recomputed
+        # from the noise-free logits, so during training with noisy top-k they
+        # may occasionally differ from the experts actually routed to
+        _, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)  # [num_tokens, top_k]
+
+        # Apply the GLU experts to their token batches
+        exp_out = torch.zeros_like(exp_batches)
+        for i, expert in enumerate(self.experts):
+            x = expert[0](exp_batches[i])  # Input projection to 2x intermediate size
+            input, gate = x.chunk(2, dim=-1)  # Split into value and gate halves
+            x = self.act(input) * gate  # GLU activation
+            x = expert[2](x)  # Dropout (or Identity)
+            exp_out[i] = expert[3](x)  # Output projection back to hidden size
+
+        # Combine expert outputs, weighted by the gates stored in exp_mask:
+        # [num_tokens, n_exp * exp_capacity] @ [n_exp * exp_capacity, d] -> [num_tokens, d]
+        exp_weight_flat = exp_mask.view(num_tokens, -1)
+        exp_out_flat = exp_out.view(-1, d)
+        output = torch.matmul(exp_weight_flat, exp_out_flat)
+
+        # Compute auxiliary losses
+        self.aux_loss = None
+        if self.compute_aux_loss:
+            # Reshape for loss computation
+            router_logits_reshaped = router_logits.view(B, C, -1)
+            top_k_indices_reshaped = top_k_indices.view(B, C, -1)
+
+            # Load balancing loss
+            lb_loss = self.load_balance_loss(router_logits_reshaped, top_k_indices_reshaped)
+
+            # Router z-loss
+            z_loss = self.router_z_loss(router_logits_reshaped)
+
+            # Combine the auxiliary losses with their weights
+            self.aux_loss = self.load_balance_loss_weight * lb_loss + self.router_z_loss_weight * z_loss
+
+        return output.view(*original_shape)
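+
+
+# Shape walkthrough (illustrative, assuming hidden_size=768, intermediate_size=1152,
+# n_exp=4): hidden_states [B, C, 768] -> router -> exp_batches [4, cap, 768];
+# each expert maps [cap, 768] -> [cap, 2304] -> GLU -> [cap, 1152] -> [cap, 768];
+# the combine step is exp_mask [B*C, 4*cap] @ exp_out [4*cap, 768] -> [B*C, 768].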
{e}") + + diff --git a/src/bert_layers/model.py b/src/bert_layers/model.py index fd05e507..aeac1b01 100644 --- a/src/bert_layers/model.py +++ b/src/bert_layers/model.py @@ -407,6 +407,11 @@ def forward( loss_fct = nn.CrossEntropyLoss() masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx]) + + # Add MoE auxiliary losses if present + aux_loss = self._get_aux_loss() + if aux_loss is not None: + loss = loss + aux_loss assert input_ids is not None, "Coding error; please open an issue" batch, seqlen = input_ids.shape[:2] @@ -1021,6 +1026,20 @@ def __init__(self, config: FlexBertConfig): # Initialize weights and apply final processing self._init_weights(reset_params=False) + + def _get_aux_loss(self) -> Optional[torch.Tensor]: + """Collect auxiliary losses from all MoE layers in the model.""" + aux_loss = None + + # Traverse all modules to find FlexBertGLUMoE layers + for module in self.modules(): + if hasattr(module, 'aux_loss') and module.aux_loss is not None: + if aux_loss is None: + aux_loss = module.aux_loss + else: + aux_loss = aux_loss + module.aux_loss + + return aux_loss def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" @@ -1265,6 +1284,20 @@ def __init__(self, config: FlexBertConfig): # Initialize weights and apply final processing self._init_weights(reset_params=False) + def _get_aux_loss(self) -> Optional[torch.Tensor]: + """Collect auxiliary losses from all MoE layers in the model.""" + aux_loss = None + + # Traverse all modules to find FlexBertGLUMoE layers + for module in self.modules(): + if hasattr(module, 'aux_loss') and module.aux_loss is not None: + if aux_loss is None: + aux_loss = module.aux_loss + else: + aux_loss = aux_loss + module.aux_loss + + return aux_loss + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" if module: @@ -1352,6 +1385,11 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) + + # Add MoE auxiliary losses if present + aux_loss = self._get_aux_loss() + if aux_loss is not None: + loss = loss + aux_loss if not return_dict: output = (logits,) + output diff --git a/yamls/main/flex-bert-rope-base.yaml b/yamls/main/flex-bert-rope-base.yaml index d92a5e1d..b50130df 100644 --- a/yamls/main/flex-bert-rope-base.yaml +++ b/yamls/main/flex-bert-rope-base.yaml @@ -3,7 +3,7 @@ # Follow the instructions in the README to set up ./my-copy-c4 # Or point data paths to your remote C4 dataset -data_local: ./my-copy-c4 +data_local: ./c4_realnewslike data_remote: # If blank, files must be present in data_local max_seq_len: 512 @@ -112,18 +112,18 @@ optimizer: # algorithms: -max_duration: 286720000sp # Subsample the training data for ~275M samples +max_duration: 13799838sp # Subsample the training data for ~275M samples eval_interval: 2000ba -global_train_batch_size: 4096 +global_train_batch_size: 64 # System seed: 17 -device_train_microbatch_size: 128 +device_train_microbatch_size: 32 # device_train_microbatch_size: auto precision: amp_bf16 -global_eval_batch_size: 256 -device_eval_microbatch_size: 64 +global_eval_batch_size: 64 +device_eval_microbatch_size: 32 # Logging progress_bar: false @@ -139,18 
diff --git a/src/bert_layers/model.py b/src/bert_layers/model.py
index fd05e507..aeac1b01 100644
--- a/src/bert_layers/model.py
+++ b/src/bert_layers/model.py
@@ -407,6 +407,11 @@ def forward(
                 loss_fct = nn.CrossEntropyLoss()
                 masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
                 loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx])
+
+                # Add MoE auxiliary losses if present
+                aux_loss = self._get_aux_loss()
+                if aux_loss is not None:
+                    loss = loss + aux_loss
 
         assert input_ids is not None, "Coding error; please open an issue"
         batch, seqlen = input_ids.shape[:2]
@@ -1021,6 +1026,20 @@ def __init__(self, config: FlexBertConfig):
 
         # Initialize weights and apply final processing
         self._init_weights(reset_params=False)
+
+    def _get_aux_loss(self) -> Optional[torch.Tensor]:
+        """Collect auxiliary losses from all MoE layers in the model."""
+        aux_loss = None
+
+        # Traverse all modules to find FlexBertGLUMoE layers
+        for module in self.modules():
+            if hasattr(module, "aux_loss") and module.aux_loss is not None:
+                if aux_loss is None:
+                    aux_loss = module.aux_loss
+                else:
+                    aux_loss = aux_loss + module.aux_loss
+
+        return aux_loss
 
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
@@ -1265,6 +1284,20 @@ def __init__(self, config: FlexBertConfig):
         # Initialize weights and apply final processing
         self._init_weights(reset_params=False)
 
+    def _get_aux_loss(self) -> Optional[torch.Tensor]:
+        """Collect auxiliary losses from all MoE layers in the model."""
+        aux_loss = None
+
+        # Traverse all modules to find FlexBertGLUMoE layers
+        for module in self.modules():
+            if hasattr(module, "aux_loss") and module.aux_loss is not None:
+                if aux_loss is None:
+                    aux_loss = module.aux_loss
+                else:
+                    aux_loss = aux_loss + module.aux_loss
+
+        return aux_loss
+
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
         if module:
@@ -1352,6 +1385,11 @@ def forward(
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = nn.BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
+
+            # Add MoE auxiliary losses if present
+            aux_loss = self._get_aux_loss()
+            if aux_loss is not None:
+                loss = loss + aux_loss
 
         if not return_dict:
             output = (logits,) + output
diff --git a/yamls/main/flex-bert-rope-base.yaml b/yamls/main/flex-bert-rope-base.yaml
index d92a5e1d..b50130df 100644
--- a/yamls/main/flex-bert-rope-base.yaml
+++ b/yamls/main/flex-bert-rope-base.yaml
@@ -3,7 +3,7 @@
 # Follow the instructions in the README to set up ./my-copy-c4
 # Or point data paths to your remote C4 dataset
 
-data_local: ./my-copy-c4
+data_local: ./c4_realnewslike
 data_remote: # If blank, files must be present in data_local
 
 max_seq_len: 512
@@ -112,18 +112,18 @@ optimizer:
 
 # algorithms:
 
-max_duration: 286720000sp # Subsample the training data for ~275M samples
+max_duration: 13799838sp # Subsample the training data to ~13.8M samples
 eval_interval: 2000ba
-global_train_batch_size: 4096
+global_train_batch_size: 64
 
 # System
 seed: 17
-device_train_microbatch_size: 128
+device_train_microbatch_size: 32
 # device_train_microbatch_size: auto
 precision: amp_bf16
 
-global_eval_batch_size: 256
-device_eval_microbatch_size: 64
+global_eval_batch_size: 64
+device_eval_microbatch_size: 32
 
 # Logging
 progress_bar: false
@@ -139,18 +139,20 @@ callbacks:
   speed_monitor:
     window_size: 10
   lr_monitor: {}
+loggers:
+  wandb:
+    project: moebert
+    entity: # Fill this in with your W&B team/entity name
+    name: ${run_name}
+    group: rope-base-c4new
 
-# (Optional) W&B logging
-# loggers:
-#   wandb:
-#     project: # Fill this in
-#     entity: # Fill this in
-
-# (Optional) Checkpoint to local filesystem or remote object store
-# save_interval: 3500ba
-# save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
-# save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote)
+# Checkpoint to local filesystem or remote object store
+save_interval: 100000sp
+save_num_checkpoints_to_keep: 2
+save_folder: checkpoint/{run_name}
 
 # (Optional) Load from local filesystem or remote object store to
 # start from an existing model checkpoint;
diff --git a/yamls/main/moebert-rope-base-c4-realnewslike.yaml b/yamls/main/moebert-rope-base-c4-realnewslike.yaml
new file mode 100644
index 00000000..0eb6a81b
--- /dev/null
+++ b/yamls/main/moebert-rope-base-c4-realnewslike.yaml
@@ -0,0 +1,174 @@
+# MoEBERT with RoPE - Base configuration
+# Based on FlexBERT-RoPE-Base with Mixture of Experts enabled
+# Uses the C4 RealNewsLike dataset for testing
+
+# Data configuration - C4 RealNewsLike
+data_local: ./c4_realnewslike
+data_remote:
+
+max_seq_len: 512
+tokenizer_name: bert-base-uncased
+mlm_probability: 0.3
+
+# Run Name
+run_name: moebert-rope-base-13799838
+
+# Model
+model:
+  name: flex_bert
+  recompute_metric_loss: false
+  pretrained_model_name: ${tokenizer_name}
+  tokenizer_name: ${tokenizer_name}
+  model_config:
+    # Base architecture (ModernBERT-base)
+    num_attention_heads: 12
+    num_hidden_layers: 12
+    attention_layer: rope
+    attention_probs_dropout_prob: 0.0
+    attn_out_bias: false
+    attn_out_dropout_prob: 0.0
+    attn_qkv_bias: false
+    bert_layer: prenorm
+    embed_dropout_prob: 0.0
+    embed_norm: false
+    final_norm: true
+    embedding_layer: sans_pos
+    loss_function: fa_cross_entropy
+    loss_kwargs:
+      reduction: mean
+    mlp_dropout_prob: 0.0
+    mlp_in_bias: false
+    mlp_layer: glu_moe
+    mlp_out_bias: false
+    normalization: rmsnorm
+    norm_kwargs:
+      eps: 1e-6
+    padding: unpadded
+    sparse_prediction: false
+
+    # RoPE configuration
+    rotary_emb_dim: null # will be set to headdim by default
+    rotary_emb_base: 10000.0
+    rotary_emb_scale_base: null
+    rotary_emb_interleaved: false
+
+    # General settings
+    hidden_act: gelu
+    init_method: full_megatron
+    init_std: 0.02
+    init_cutoff_factor: 2.0
+    init_small_embedding: false
+    deterministic_fa2: false
+    initial_attention_layer: null
+    initial_bert_layer: null
+    initial_mlp_layer: null
+    num_initial_layers: 0
+    skip_first_prenorm: true
+
+    # Sparse attention settings
+    sliding_window: 128
+    global_attn_every_n_layers: 3
+    unpad_embeddings: true
+    pad_logits: false
+
+    # Mixture of Experts configuration
+    moe_num_experts: 4
+    moe_top_k: 2
+    moe_use_noisy_top_k: true
+    moe_capacity_factor: 1.25
+    moe_compute_aux_loss: true
+    moe_load_balance_loss_weight: 0.01
+    moe_router_z_loss_weight: 0.001
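+    # Capacity check for this config (illustrative): a training microbatch of
+    # device_train_microbatch_size * max_seq_len = 2 * 512 = 1024 tokens (an
+    # upper bound with unpadded sequences) gives each of the 4 experts
+    # int(1024 * 2 * 1.25 / 4) = 640 slots per layer.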
+
+# Dataloaders
+train_loader:
+  name: text
+  dataset:
+    local: ${data_local}
+    remote: ${data_remote}
+    split: train
+    tokenizer_name: ${tokenizer_name}
+    max_seq_len: ${max_seq_len}
+    shuffle: true
+    mlm_probability: ${mlm_probability}
+  drop_last: true
+  num_workers: 8
+
+eval_loader:
+  name: text
+  dataset:
+    local: ${data_local}
+    remote: ${data_remote}
+    split: val
+    tokenizer_name: ${tokenizer_name}
+    max_seq_len: ${max_seq_len}
+    shuffle: false
+    mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison
+  drop_last: false
+  num_workers: 8
+
+# Optimization
+scheduler:
+  name: linear_decay_with_warmup
+  t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration
+  alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration
+
+optimizer:
+  name: decoupled_adamw
+  lr: 5.0e-4 # Peak learning rate
+  betas:
+  - 0.9
+  - 0.98
+  eps: 1.0e-06
+  weight_decay: 1.0e-5 # Amount of weight decay regularization
+  filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases
+
+# Training parameters - reduced for testing
+max_duration: 13799838sp # Smaller sample count for testing
+eval_interval: 100ba
+global_train_batch_size: 8
+
+# System
+seed: 17
+device_train_microbatch_size: 2
+precision: amp_bf16
+
+global_eval_batch_size: 16
+device_eval_microbatch_size: 4
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+algorithms:
+  gradient_clipping:
+    clipping_type: norm
+    clipping_threshold: 1.0
+
+callbacks:
+  speed_monitor:
+    window_size: 10
+  lr_monitor: {}
+
+# W&B logging
+loggers:
+  wandb:
+    project: moebert
+    entity: # Fill this in with your W&B team/entity name
+    name: ${run_name}
+    group: rope-base-moe-c4
+
+# Checkpoint to local filesystem or remote object store
+save_interval: 100000sp
+save_num_checkpoints_to_keep: 2
+save_folder: checkpoint/{run_name}
+
+# (Optional) Load from local filesystem or remote object store to
+# start from an existing model checkpoint;
+# e.g. './ckpt/latest-rank{rank}.pt' (local), or
+# 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote)
+# load_path: null
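+
+# To launch (assumed entry point; matches the repo's main.py):
+#   composer main.py yamls/main/moebert-rope-base-c4-realnewslike.yaml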