From 2dd9a989bd9298b6b7a5b97d382ac152dda27844 Mon Sep 17 00:00:00 2001 From: heyo66 <137907563+heyo66@users.noreply.github.com> Date: Wed, 26 Nov 2025 18:03:59 +0100 Subject: [PATCH 1/4] Add files via upload --- configuration_bert.py | 290 +++++++ loss.py | 105 +++ mlp.py | 405 ++++++++++ model.py | 1722 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 2522 insertions(+) create mode 100644 configuration_bert.py create mode 100644 loss.py create mode 100644 mlp.py create mode 100644 model.py diff --git a/configuration_bert.py b/configuration_bert.py new file mode 100644 index 00000000..b9f1e6ce --- /dev/null +++ b/configuration_bert.py @@ -0,0 +1,290 @@ +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +import warnings + +from transformers import BertConfig as TransformersBertConfig + + +class BertConfig(TransformersBertConfig): + def __init__( + self, + alibi_starting_size: int = 512, + normalization: str = "layernorm", + attention_probs_dropout_prob: float = 0.0, + head_pred_act: str = "gelu", + deterministic_fa2: bool = False, + allow_embedding_resizing: bool = False, + **kwargs, + ): + """Configuration class for MosaicBert. + + Args: + alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to + create when initializing the model. You should be able to ignore this parameter in most cases. + Defaults to 512. + attention_probs_dropout_prob (float): By default, turn off attention dropout in MosaicBERT + Note that the custom Triton Flash Attention with ALiBi implementation does not support droput. + However, Flash Attention 2 supports ALiBi and dropout https://github.com/Dao-AILab/flash-attention + embed_dropout_prob (float): Dropout probability for the embedding layer. + attn_out_dropout_prob (float): Dropout probability for the attention output layer. + mlp_dropout_prob (float): Dropout probability for the MLP layer. + allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size. 
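+
+        Example (illustrative sketch; the import path follows the `src.bert_layers` layout used
+        elsewhere in this PR, and the values below are placeholders rather than recommended settings):
+
+            ```python
+            from src.bert_layers.configuration_bert import BertConfig
+
+            config = BertConfig(
+                vocab_size=30522,
+                hidden_size=768,
+                num_hidden_layers=12,
+                num_attention_heads=12,
+                alibi_starting_size=1024,       # pre-build the ALiBi bias for sequences up to 1024 tokens
+                deterministic_fa2=True,         # trade speed for reproducible Flash Attention 2 kernels
+                allow_embedding_resizing=True,  # grow embeddings if the tokenizer vocab is larger
+            )
+            ```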
+ """ + super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) + self.alibi_starting_size = alibi_starting_size + self.normalization = normalization + self.head_pred_act = head_pred_act + self.deterministic_fa2 = deterministic_fa2 + self.allow_embedding_resizing = allow_embedding_resizing + + +class FlexBertConfig(TransformersBertConfig): + def __init__( + self, + attention_layer: str = "base", + attention_probs_dropout_prob: float = 0.0, + attn_out_bias: bool = False, + attn_out_dropout_prob: float = 0.0, + attn_qkv_bias: bool = False, + bert_layer: str = "prenorm", + decoder_bias: bool = True, + embed_dropout_prob: float = 0.0, + embed_norm: bool = True, + final_norm: bool = False, + embedding_layer: str = "absolute_pos", + encoder_layer: str = "base", + loss_function: str = "cross_entropy", + loss_kwargs: dict = {}, + mlp_dropout_prob: float = 0.0, + mlp_in_bias: bool = False, + mlp_layer: str = "glu_moe", + mlp_out_bias: bool = False, + norm_kwargs: dict = {}, + normalization: str = "rmsnorm", + padding: str = "unpadded", + head_class_act: str = "silu", + head_class_bias: bool = False, + head_class_dropout: float = 0.0, + head_class_norm: str = False, + head_pred_act: str = "silu", + head_pred_bias: bool = False, + head_pred_dropout: float = 0.0, + head_pred_norm: bool = True, + pooling_type: str = "cls", + rotary_emb_dim: int | None = None, + rotary_emb_base: float = 10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved: bool = False, + use_fa2: bool = True, + use_sdpa_attn_mask: bool = False, + allow_embedding_resizing: bool = False, + init_method: str = "default", + init_std: float = 0.02, + init_cutoff_factor: float = 2.0, + init_small_embedding: bool = False, + initial_attention_layer: str | None = None, + initial_bert_layer: str | None = None, + initial_mlp_layer: str | None = None, + num_initial_layers: int = 1, + skip_first_prenorm: bool = False, + deterministic_fa2: bool = False, + sliding_window: int = -1, + global_attn_every_n_layers: int = -1, + local_attn_rotary_emb_base: float = -1, + local_attn_rotary_emb_dim: int | None = None, + unpad_embeddings: bool = False, + pad_logits: bool = False, + compile_model: bool = False, + masked_prediction: bool = False, + moe_num_experts: int = 8, + moe_top_k: int = 2, + moe_use_noisy_top_k: bool = True, + moe_capacity_factor: float = 1.25, + moe_compute_aux_loss: bool = True, + moe_load_balance_loss_weight: float = 0.01, + moe_router_z_loss_weight: float = 0.001, + **kwargs, + ): + """ + Args: + attention_layer (str): Attention layer type. + attention_probs_dropout_prob (float): Dropout probability for attention probabilities. + attn_out_bias (bool): use bias in attention output projection. + attn_out_dropout_prob (float): Dropout probability for attention output. + attn_qkv_bias (bool): use bias for query, key, value linear layer(s). + bert_layer (str): BERT layer type. + decoder_bias (bool): use bias in decoder linear layer. + embed_dropout_prob (float): Dropout probability for embeddings. + embed_norm (bool): Normalize embedding output. + final_norm (bool): Add normalization after the final encoder layer and before head. + embedding_layer (str): Embedding layer type. + encoder_layer (str): Encoder layer type. + loss_function (str): Loss function to use. + loss_kwargs (dict): Keyword arguments for loss function. + mlp_dropout_prob (float): Dropout probability for MLP layers. + mlp_in_bias (bool): Use bias in MLP input linear layer. + mlp_layer (str): MLP layer type. 
+ mlp_out_bias (bool): Use bias in MLP output linear layer. + norm_kwargs (dict): Keyword arguments for normalization layers. + normalization (str): Normalization type. + padding (str): Unpad inputs. Best with `use_fa2=True`. + head_class_act (str): Activation function for classification head. + head_class_bias (bool): Use bias in classification head linear layer(s). + head_class_dropout (float): Dropout probability for classification head. + head_class_norm (str): Normalization type for classification head. + head_pred_act (str): Activation function for prediction head. + head_pred_bias (bool): Use bias in prediction head linear layer(s). + head_pred_dropout (float): Dropout probability for prediction head. + head_pred_norm (bool): Normalize prediction head output. + pooling_type (str): Pooling type. + rotary_emb_dim (int | None): Rotary embedding dimension. + rotary_emb_base (float): Rotary embedding base. + rotary_emb_scale_base (float): Rotary embedding scale base. + rotary_emb_interleaved (bool): Use interleaved rotary embeddings. + use_fa2 (bool): Use FlashAttention2. Requires flash_attn package. + use_sdpa_attn_mask (bool): Pass a mask to SDPA. This will prevent SDPA from using the PyTorch FA2 kernel. + allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size. + init_method (str): Model layers initialization method. + init_std (float): Standard deviation for initialization. Used for normal and full_megatron init. + init_cutoff_factor (float): Cutoff factor for initialization. Used for normal and full_megatron init. + init_small_embedding (bool): Initialize embeddings with RWKV small init. + initial_attention_layer (str | None): Replace first `num_initial_layers` attention_layer instance with this layer. + initial_bert_layer (str | None): Replace first `num_initial_layers` bert_layer instance with this layer. + initial_mlp_layer (str | None): Replace first `num_initial_layers` mlp_layer instance with this layer. + num_initial_layers (int): Number of initial layers to set via `initial_attention_layer`, `initial_bert_layer`, and `initial_mlp_layer`. + skip_first_prenorm (bool): Skip pre-normalization for the first bert layer. Requires `embed_norm=True`. + deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode. This is slower then the default non-deterministic mode. + sliding_window (int): Use sliding window attention with window size `n`. -1 to disable. Window size split between the left and right context. Only supports FA2. + global_attn_every_n_layers (int): Use global attention every `n` layers and sliding window for the rest. -1 to disable. + local_attn_rotary_emb_base (float): Rotary embedding base for local attention. -1 to disable and use `rotary_emb_base` for all layers. + local_attn_rotary_emb_dim (int | None): Rotary embedding dimension for local attention. None to disable and use `rotary_emb_dim` for all layers. + unpad_embeddings (bool): Unpad inputs before the embedding layer. + pad_logits (bool): Pad logits after the calculating the loss. + compile_model (bool): Compile the subset of the model which can be compiled. + masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers + moe_num_experts (int): Number of experts for Mixture of Experts layers. + moe_top_k (int): Number of top experts to select for each token in MoE. + moe_use_noisy_top_k (bool): Use noisy top-k gating for exploration during training. 
+ moe_capacity_factor (float): Capacity factor for expert assignment in MoE. + moe_compute_aux_loss (bool): Whether to compute and add auxiliary losses for MoE. + moe_load_balance_loss_weight (float): Weight for the load balancing auxiliary loss. + moe_router_z_loss_weight (float): Weight for the router z-loss auxiliary loss. + **kwargs: Additional keyword arguments. + """ + super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) + self.attention_layer = attention_layer + self.attn_out_bias = attn_out_bias + self.attn_out_dropout_prob = attn_out_dropout_prob + self.attn_qkv_bias = attn_qkv_bias + self.bert_layer = bert_layer + self.decoder_bias = decoder_bias + self.embed_dropout_prob = embed_dropout_prob + self.embed_norm = embed_norm + self.final_norm = final_norm + self.embedding_layer = embedding_layer + self.encoder_layer = encoder_layer + self.loss_function = loss_function + self.loss_kwargs = loss_kwargs + self.mlp_dropout_prob = mlp_dropout_prob + self.mlp_in_bias = mlp_in_bias + self.mlp_layer = mlp_layer + self.mlp_out_bias = mlp_out_bias + self.norm_kwargs = norm_kwargs + self.normalization = normalization + self.padding = padding + self.head_class_act = head_class_act + self.head_class_bias = head_class_bias + self.head_class_dropout = head_class_dropout + self.head_class_norm = head_class_norm + self.head_pred_act = head_pred_act + self.head_pred_bias = head_pred_bias + self.head_pred_dropout = head_pred_dropout + self.head_pred_norm = head_pred_norm + self.pooling_type = pooling_type + self.rotary_emb_dim = rotary_emb_dim + self.rotary_emb_base = rotary_emb_base + self.rotary_emb_scale_base = rotary_emb_scale_base + self.rotary_emb_interleaved = rotary_emb_interleaved + self.use_fa2 = use_fa2 + self.use_sdpa_attn_mask = use_sdpa_attn_mask + self.allow_embedding_resizing = allow_embedding_resizing + self.init_method = init_method + self.init_std = init_std + self.init_cutoff_factor = init_cutoff_factor + self.init_small_embedding = init_small_embedding + self.initial_attention_layer = initial_attention_layer + self.initial_bert_layer = initial_bert_layer + self.initial_mlp_layer = initial_mlp_layer + self.num_initial_layers = num_initial_layers + self.skip_first_prenorm = skip_first_prenorm + self.deterministic_fa2 = deterministic_fa2 + self.sliding_window = sliding_window + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attn_rotary_emb_base = local_attn_rotary_emb_base + self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim + self.unpad_embeddings = unpad_embeddings + self.pad_logits = pad_logits + self.compile_model = compile_model + self.masked_prediction = masked_prediction + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_use_noisy_top_k = moe_use_noisy_top_k + self.moe_capacity_factor = moe_capacity_factor + self.moe_compute_aux_loss = moe_compute_aux_loss + self.moe_load_balance_loss_weight = moe_load_balance_loss_weight + self.moe_router_z_loss_weight = moe_router_z_loss_weight + + if loss_kwargs.get("return_z_loss", False): + if loss_function != "fa_cross_entropy": + raise ValueError("loss_function must be 'fa_cross_entropy' when return_z_loss is True") + if loss_kwargs.get("lse_square_scale", 0) <= 0: + raise ValueError( + "lse_square_scale must be passed to `loss_kwargs` and must be greater than 0 for z_loss" + ) + if loss_kwargs.get("inplace_backward", False): + self.loss_kwargs["inplace_backward"] = False + warnings.warn("`inplace_backward=True` will cause incorrect 
metrics. Automatically setting to False.") + + if global_attn_every_n_layers > 0 and (self.num_hidden_layers - 1) % global_attn_every_n_layers != 0: + raise ValueError( + f"{global_attn_every_n_layers=} must be a divisor of one less than {self.num_hidden_layers=}" + ) + + if self.sliding_window != -1: + if not self.use_fa2: + raise ValueError("Sliding window attention is only supported with FlashAttention2") + if self.sliding_window % 2 != 0 and self.sliding_window % 64 != 0: + raise ValueError( + f"Sliding window must be an even number and divisible by 64: {self.sliding_window=} {self.sliding_window % 64} {self.sliding_window % 2}" + ) + else: + if self.global_attn_every_n_layers != -1: + raise ValueError("global_attn_every_n_layers must be -1 when sliding_window is disabled") + if self.local_attn_rotary_emb_base != -1: + raise ValueError("local_attn_rotary_emb_base must be -1 when sliding_window is disabled") + if self.local_attn_rotary_emb_dim is not None: + raise ValueError("local_attn_rotary_emb_dim must be None when sliding_window is disabled") + + if self.unpad_embeddings and self.padding != "unpadded": + warnings.warn( + "`unpad_embeddings=True` requires `padding='unpadded'`. Automatically setting `padding='unpadded'`." + ) + self.padding = "unpadded" + if self.pad_logits and not self.unpad_embeddings: + raise ValueError("`pad_logits=True` requires `unpad_embeddings=True`") + if self.unpad_embeddings and self.embedding_layer == "absolute_pos": + raise ValueError(f"{self.unpad_embeddings=} is incompatible with {self.embedding_layer=}") + + +PADDING = ["unpadded", "padded"] + + +def maybe_add_padding(config: FlexBertConfig, config_option: str) -> str: + if config.padding not in PADDING: + raise ValueError(f"Invalid padding type: {config.padding}, must be one of {PADDING}") + + if not any(config_option.startswith(pad + "_") for pad in PADDING): + config_option = f"{config.padding}_{config_option}" + + return config_option diff --git a/loss.py b/loss.py new file mode 100644 index 00000000..3f210a4c --- /dev/null +++ b/loss.py @@ -0,0 +1,105 @@ +# Copyright 2024 onwards Answer.AI, LightOn, and contributors +# License: Apache-2.0 + +import inspect +import torch +import torch.nn as nn +import torch.nn.functional as F +from .configuration_bert import FlexBertConfig + +try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss +except ImportError: + CrossEntropyLoss = None + +LOSS2CLS = { + "cross_entropy": nn.CrossEntropyLoss, + "binary_cross_entropy": nn.BCEWithLogitsLoss, + "mean_squared_error": nn.MSELoss, +} + +if CrossEntropyLoss is not None: + LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss + + +class MoELoadBalancingLoss(nn.Module): + """Computes Switch Transformer auxiliary loss for load balancing. + + Reference: https://arxiv.org/abs/2101.03961 (equations 4-6, page 7) + + This loss encourages balanced token allocation across experts to avoid + scenarios where some experts are overloaded while others are underutilized. 
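+
+    Example (shape sketch; the batch and sequence sizes are arbitrary placeholders):
+
+        ```python
+        loss_fn = MoELoadBalancingLoss(num_experts=8, top_k=2)
+        router_logits = torch.randn(4, 128, 8)                            # [batch, seq, n_exp]
+        expert_indices = torch.topk(router_logits, k=2, dim=-1).indices   # [batch, seq, top_k]
+        aux_loss = loss_fn(router_logits, expert_indices)                 # scalar auxiliary loss
+        ```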
+ """ + + def __init__(self, num_experts: int, top_k: int = 2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + + def forward( + self, + router_logits: torch.Tensor, + expert_indices: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + router_logits: Router logits [batch_size, seq_len, num_experts] + expert_indices: Top-k expert indices [batch_size, seq_len, top_k] + + Returns: + load_balance_loss: Scalar loss value + """ + # Compute expert probabilities + expert_probs = F.softmax(router_logits, dim=-1) # [B, C, n_exp] + + # Equation (5): compute ratio of tokens allocated to each expert + with torch.no_grad(): + one_hot_indices = F.one_hot(expert_indices, num_classes=self.num_experts) # [B, C, K, n_exp] + one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, C, n_exp] + tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1)) # [n_exp] + + # Equation (6): compute ratio of router probability allocated to each expert + prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1)) # [n_exp] + + # Equation (4): scaled dot product between prob / token allocation vectors + load_balance_loss = self.num_experts * torch.sum(prob_per_expert * tokens_per_expert) + + return load_balance_loss + + +class MoERouterZLoss(nn.Module): + """Computes router z-loss for MoE models. + + Reference: https://arxiv.org/abs/2202.08906 (equation 5, page 7) + + This loss constrains the size of router logits to avoid numerical instability + during training. Large logits can lead to round-off errors in the softmax computation, + even in float32 precision. + """ + + def forward(self, router_logits: torch.Tensor) -> torch.Tensor: + """ + Args: + router_logits: Router logits [batch_size, seq_len, num_experts] + + Returns: + router_z_loss: Scalar loss value + """ + # Numerically stable computation: logsumexp is equivalent to log(sum(exp(x))) + # This avoids overflow issues from directly exponentiating large logits + router_z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0 # [B, C] + + # Average over all tokens + router_z_loss = torch.mean(router_z_loss) + + return router_z_loss + + +def get_loss_fn(config: FlexBertConfig) -> nn.Module: + try: + loss_class = LOSS2CLS[config.loss_function] + signature = inspect.signature(loss_class) + loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters} + return loss_class(**loss_kwargs) + except KeyError: + raise ValueError(f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}.") diff --git a/mlp.py b/mlp.py new file mode 100644 index 00000000..f344008d --- /dev/null +++ b/mlp.py @@ -0,0 +1,405 @@ +# Copyright 2024 onwards Answer.AI, LightOn, and contributors +# License: Apache-2.0 + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .configuration_bert import FlexBertConfig +from .activation import get_act_fn +from .normalization import get_norm_layer +from .initialization import ModuleType, init_weights +from .loss import MoELoadBalancingLoss, MoERouterZLoss + + +class BertResidualGLU(nn.Module): + """Applies the FFN at the end of each Mosaic BERT layer. 
+ + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but + introduces Gated Linear Units. + + Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a + standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with + `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed + with the `config.intermediate_size=3072`. + However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased + parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`. + """ + + def __init__( + self, + config, + ): + super().__init__() + self.config = config + self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False) + self.act = get_act_fn(config.hidden_act) + self.wo = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.layernorm = get_norm_layer(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Compute new hidden states from current hidden states. + + Args: + hidden_states (torch.Tensor): The (unpadded) hidden states from + the attention layer [nnz, dim]. + """ + residual_connection = hidden_states + # compute the activation + hidden_states = self.gated_layers(hidden_states) + gated = hidden_states[:, : self.config.intermediate_size] + non_gated = hidden_states[:, self.config.intermediate_size :] + hidden_states = self.act(gated) * non_gated + hidden_states = self.dropout(hidden_states) + # multiply by the second matrix + hidden_states = self.wo(hidden_states) + # add the residual connection and post-LN + hidden_states = self.layernorm(hidden_states + residual_connection) + return hidden_states + + +class FlexBertMLPBase(nn.Module): + """A FlexBERT MLP base class for type hints.""" + + def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + def _init_weights(self, reset_params: bool = False): + raise NotImplementedError("This is a base class and should not be used directly.") + + def reset_parameters(self): + self._init_weights(reset_params=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("This is a base class and should not be used directly.") + + +class FlexBertMLP(FlexBertMLPBase): + """Applies the MLP at the end of each FlexBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
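+
+    Example (shape sketch; `nnz`, the number of unpadded tokens, is a placeholder value here):
+
+        ```python
+        mlp = FlexBertMLP(config, layer_id=0)
+        hidden_states = torch.randn(nnz, config.hidden_size)
+        out = mlp(hidden_states)   # Wo(drop(act(Wi(hidden_states)))), shape [nnz, hidden_size]
+        ```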
+ """ + + def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): + super().__init__(config=config, layer_id=layer_id) + self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias) + self.act = get_act_fn(config.hidden_act) + self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity() + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias) + + def _init_weights(self, reset_params: bool = False): + init_weights( + self.config, + self.Wi, + layer_dim=self.config.hidden_size, + layer_id=None, + type_of_module=ModuleType.in_module, + ) + init_weights( + self.config, + self.Wo, + layer_dim=self.config.intermediate_size, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Compute new hidden states from current hidden states. + + Args: + hidden_states (torch.Tensor): The (unpadded) hidden states from + the attention layer [nnz, dim]. + """ + return self.Wo(self.drop(self.act(self.Wi(hidden_states)))) + + +class FlexBertGLU(FlexBertMLPBase): + """Applies the GLU at the end of each FlexBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. + """ + + def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): + super().__init__(config=config, layer_id=layer_id) + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_in_bias) + self.act = get_act_fn(config.hidden_act) + self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity() + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias) + + def _init_weights(self, reset_params: bool = False): + init_weights( + self.config, + self.Wi, + layer_dim=self.config.hidden_size, + layer_id=None, + type_of_module=ModuleType.in_module, + ) + init_weights( + self.config, + self.Wo, + layer_dim=self.config.intermediate_size, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class FlexBertParallelGLU(FlexBertMLPBase): + """Applies the GLU at the end of each FlexBERT layer using intermediate_ff computed in parallel of the attention. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
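+
+    Example (shape sketch; `nnz` is a placeholder for the number of unpadded tokens). Because the input
+    projection is computed alongside the attention projections, this layer consumes the pre-computed
+    `intermediate_ff` of width `2 * intermediate_size` instead of the raw hidden states:
+
+        ```python
+        glu = FlexBertParallelGLU(config, layer_id=0)
+        intermediate_ff = torch.randn(nnz, 2 * config.intermediate_size)
+        out = glu(intermediate_ff)   # chunk -> act(input) * gate -> Wo, shape [nnz, hidden_size]
+        ```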
+ """ + + def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): + super().__init__(config=config, layer_id=layer_id) + self.act = get_act_fn(config.hidden_act) + self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity() + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias) + + def _init_weights(self, reset_params: bool = False): + init_weights( + self.config, + self.Wo, + layer_dim=self.config.intermediate_size, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + + def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor: + input, gate = intermediate_ff.chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class Router(nn.Module): + """Top-K router for selecting experts.""" + + def __init__( + self, + d: int, + n_exp: int, + top_k: int = 2, + use_noisy_top_k: bool = True, + capacity_factor: float = 1.25, + ): + super().__init__() + self.d = d + self.n_exp = n_exp + self.top_k = top_k + self.use_noisy_top_k = use_noisy_top_k + self.capacity_factor = capacity_factor + + # Router weights to compute logits for each expert + self.gate = nn.Linear(d, n_exp, bias=False) + + # Noise parameters for noisy top-k gating + if use_noisy_top_k: + self.w_noise = nn.Linear(d, n_exp, bias=False) + + def forward(self, x: torch.Tensor): + """ + Args: + x: [batch_size, seq_len, d] + Returns: + exp_weight: Expert weights [batch_size * seq_len, top_k] + exp_mask: Expert mask [batch_size * seq_len, n_exp, exp_capacity] + exp_batches: Token assignments [n_exp, exp_capacity, d] + """ + B, C, d = x.size() + num_tokens = B * C + x_flat = x.view(num_tokens, d) + + # Compute router logits + logits = self.gate(x_flat) # [num_tokens, n_exp] + + # Add noise for exploration (optional) + if self.use_noisy_top_k and self.training: + noise_stddev = F.softplus(self.w_noise(x_flat)) + noise = torch.randn_like(logits) * noise_stddev + logits = logits + noise + + # Select top-k experts + top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1) + top_k_gates = F.softmax(top_k_logits, dim=-1) # [num_tokens, top_k] + + # Compute expert capacity + exp_capacity = int((num_tokens * self.top_k * self.capacity_factor) / self.n_exp) + + # Create expert assignment mask and batches + exp_mask = torch.zeros(num_tokens, self.n_exp, exp_capacity, device=x.device) + exp_batches = torch.zeros(self.n_exp, exp_capacity, d, device=x.device) + + # Count tokens assigned to each expert + expert_counts = torch.zeros(self.n_exp, dtype=torch.long, device=x.device) + + # Assign tokens to experts + for token_idx in range(num_tokens): + for k_idx in range(self.top_k): + expert_idx = top_k_indices[token_idx, k_idx] + if expert_counts[expert_idx] < exp_capacity: + pos = expert_counts[expert_idx] + exp_mask[token_idx, expert_idx, pos] = top_k_gates[token_idx, k_idx] + exp_batches[expert_idx, pos] = x_flat[token_idx] + expert_counts[expert_idx] += 1 + + return top_k_gates, exp_mask, exp_batches + + + + + +class FlexBertGLUMoE(FlexBertMLPBase): + """Mixture of Experts with GLU activation for FlexBERT.""" + + def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): + super().__init__(config=config, layer_id=layer_id) + + self.n_exp = getattr(config, 'moe_num_experts', 4) + self.top_k = getattr(config, 'moe_top_k', 2) + self.use_noisy_top_k = getattr(config, 'moe_use_noisy_top_k', True) + self.capacity_factor = getattr(config, 'moe_capacity_factor', 1.25) + self.compute_aux_loss = 
getattr(config, 'moe_compute_aux_loss', True) + self.load_balance_loss_weight = getattr(config, 'moe_load_balance_loss_weight', 0.01) + self.router_z_loss_weight = getattr(config, 'moe_router_z_loss_weight', 0.001) + + self.router = Router( + d=config.hidden_size, + n_exp=self.n_exp, + top_k=self.top_k, + use_noisy_top_k=self.use_noisy_top_k, + capacity_factor=self.capacity_factor, + ) + + # GLU experts (each projects to 2x intermediate size) + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_in_bias), + nn.Identity(), # Placeholder for chunking + activation + nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity(), + nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias), + ) + for _ in range(self.n_exp) + ]) + self.act = get_act_fn(config.hidden_act) + + # Initialize auxiliary loss modules + if self.compute_aux_loss: + self.load_balance_loss = MoELoadBalancingLoss(num_experts=self.n_exp, top_k=self.top_k) + self.router_z_loss = MoERouterZLoss() + + def _init_weights(self, reset_params: bool = False): + init_weights( + self.config, + self.router.gate, + layer_dim=self.config.hidden_size, + layer_id=self.layer_id, + type_of_module=ModuleType.in_module, + ) + + for expert in self.experts: + for i, module in enumerate(expert): + if isinstance(module, nn.Linear): + init_weights( + self.config, + module, + layer_dim=self.config.hidden_size if i == 0 else self.config.intermediate_size, + layer_id=self.layer_id, + type_of_module=ModuleType.in_module if i == 0 else ModuleType.out_module, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + original_shape = hidden_states.shape + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + + B, C, d = hidden_states.size() + num_tokens = B * C + x_flat = hidden_states.view(num_tokens, d) + + # Compute router logits for auxiliary loss calculation + router_logits = self.router.gate(x_flat) # [num_tokens, n_exp] + + exp_weight, exp_mask, exp_batches = self.router(hidden_states) + + # Extract top-k indices from router for load balancing loss + _, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1) # [num_tokens, top_k] + + # Apply GLU experts + exp_out = torch.zeros_like(exp_batches) + for i, expert in enumerate(self.experts): + x = expert[0](exp_batches[i]) # Linear projection + input, gate = x.chunk(2, dim=-1) # Split for GLU + x = self.act(input) * gate # GLU activation + x = expert[2](x) # Dropout + exp_out[i] = expert[3](x) # Output projection + + exp_weight_flat = exp_mask.view(num_tokens, -1) + exp_out_flat = exp_out.view(-1, d) + output = torch.matmul(exp_weight_flat, exp_out_flat) + + # Compute auxiliary losses + self.aux_loss = None + if self.compute_aux_loss: + # Reshape for loss computation + router_logits_reshaped = router_logits.view(B, C, -1) + top_k_indices_reshaped = top_k_indices.view(B, C, -1) + + # Compute load balancing loss + lb_loss = self.load_balance_loss(router_logits_reshaped, top_k_indices_reshaped) + + # Compute router z-loss + z_loss = self.router_z_loss(router_logits_reshaped) + + # Combine auxiliary losses with weights + self.aux_loss = self.load_balance_loss_weight * lb_loss + self.router_z_loss_weight * z_loss + + return output.view(*original_shape) + + +# Update the MLP registry +MLP2CLS = { + "mlp": FlexBertMLP, + "glu": FlexBertGLU, + "parallel_glu": FlexBertParallelGLU, + "glu_moe": FlexBertGLUMoE, +} + + +def get_mlp_layer(config: 
FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase: + try: + mlp_layer = ( + config.initial_mlp_layer + if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None + else config.mlp_layer + ) + return MLP2CLS[mlp_layer](config, layer_id=layer_id) + except KeyError as e: + if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None: + raise ValueError( + f"Invalid MLP layer type: {config.initial_mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}" + ) + else: + raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}") + + diff --git a/model.py b/model.py new file mode 100644 index 00000000..aeac1b01 --- /dev/null +++ b/model.py @@ -0,0 +1,1722 @@ +# Copyright 2024 onwards Answer.AI, LightOn, and contributors +# License: Apache-2.0 + +# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation) +# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT + +# Copyright 2022 Jonas Geiping +# License: MIT + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + +"""Implements Mosaic BERT, with an eye towards the Hugging Face API. + +Mosaic BERT improves performance over Hugging Face BERT through the following: + +1. ALiBi. This architectural change removes positional embeddings and instead encodes positional +information through attention biases based on query-key position distance. It improves the effectiveness +of training with shorter sequence lengths by enabling extrapolation to longer sequences. + +2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer +to improve overall expressiveness, providing better convergence properties. + +3. Flash Attention. The MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically +improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that +supports attention biases, which allows us to use Flash Attention with ALiBi. + +4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT +implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation +and improve speed. It does this without changing how the user interfaces with the model, thereby +preserving the simple API of standard implementations. + + +Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence +classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases. + +See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage +of the core Mosaic BERT classes. 
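+
+Example (illustrative sketch; assumes `BertConfig` from the configuration module added in this PR, and
+the config values below are placeholders rather than tuned settings):
+
+    ```python
+    config = BertConfig(
+        vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, num_labels=2
+    )
+    mlm_model = BertForMaskedLM(config)
+    cls_model = BertForSequenceClassification(config)
+    ```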
+""" + +import logging +import os +import sys +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +# Add folder root to path to allow us to use relative imports regardless of what directory the script is run from +sys.path.append(os.path.dirname(os.path.realpath(__file__))) + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present +from transformers.modeling_outputs import ( + MaskedLMOutput, + ModelOutput, + MultipleChoiceModelOutput, + SequenceClassifierOutput, +) +from transformers.models.bert.modeling_bert import BertPreTrainedModel + +from bert_padding import index_put_first_axis + +from src.bert_layers.activation import get_act_fn +from src.bert_layers.attention import ( + FlexBertPaddedAttention, + FlexBertPaddedParallelAttention, + FlexBertPaddedRopeAttention, + FlexBertPaddedRopeParallelAttention, + FlexBertUnpadAttention, + FlexBertUnpadParallelAttention, + FlexBertUnpadRopeAttention, + FlexBertUnpadRopeParallelAttention, +) +from src.bert_layers.configuration_bert import FlexBertConfig +from src.bert_layers.embeddings import ( + BertAlibiEmbeddings, + FlexBertAbsoluteEmbeddings, + FlexBertCompiledSansPositionEmbeddings, + FlexBertSansPositionEmbeddings, + get_embedding_layer, +) +from src.bert_layers.initialization import ( + ModuleType, + TileLinear, + TileMode, + init_weights, + tile_embedding, + tile_linear, + tile_norm, +) +from src.bert_layers.layers import ( + BertAlibiEncoder, + BertPooler, + BertPredictionHeadTransform, + FlexBertCompileUnpadPreNormLayer, + FlexBertPaddedEncoder, + FlexBertPaddedParallelPreNormLayer, + FlexBertPaddedPostNormLayer, + FlexBertPaddedPreNormLayer, + FlexBertUnpadEncoder, + FlexBertUnpadParallelPreNormLayer, + FlexBertUnpadPostNormLayer, + FlexBertUnpadPreNormLayer, + get_encoder_layer, +) +from src.bert_layers.loss import get_loss_fn +from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU +from src.bert_layers.normalization import get_norm_layer +from src.bert_layers.padding import pad_input, unpad_input + +logger = logging.getLogger(__name__) + + +def _count_parameters(model: nn.Module, trainable: bool = True) -> int: + if trainable: + return sum(p.numel() for p in model.parameters() if p.requires_grad) + else: + return sum(p.numel() for p in model.parameters()) + + +class BertModel(BertPreTrainedModel): + """Overall BERT model. + + Args: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 
+ + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controlled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLS`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + model = BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__( + self, + config, + add_pooling_layer: bool = True, + ): + super(BertModel, self).__init__(config) + self.embeddings = BertAlibiEmbeddings(config) + self.encoder = BertAlibiEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_all_encoded_layers: Optional[bool] = False, + masked_tokens_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + + subset_mask = [] + first_col_mask = [] + + if masked_tokens_mask is None: + subset_mask = None + else: + first_col_mask = torch.zeros_like(masked_tokens_mask) + first_col_mask[:, 0] = True + subset_mask = masked_tokens_mask | first_col_mask + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + subset_mask=subset_mask, + ) + + if masked_tokens_mask is None: + sequence_output = encoder_outputs[-1] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + else: + # TD [2022-03-01]: the indexing here is very tricky. 
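+            # Sketch of what happens below: `subset_mask` flags the masked tokens plus the first (CLS)
+            # position of each sequence, and the encoder only returned hidden states for that subset.
+            # `subset_idx` is the subset mask restricted to unpadded tokens, so indexing
+            # `masked_tokens_mask[attention_mask_bool]` / `first_col_mask[attention_mask_bool]` with it
+            # selects, out of the subset outputs, the MLM positions and the CLS positions respectively.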
+ attention_mask_bool = attention_mask.bool() + subset_idx = subset_mask[attention_mask_bool] # type: ignore + sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]] + if self.pooler is not None: + pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]] + pooled_output = self.pooler(pool_input, pool=False) + else: + pooled_output = None + + if not output_all_encoded_layers: + encoder_outputs = sequence_output + + if self.pooler is not None: + return encoder_outputs, pooled_output + + return encoder_outputs, None + + +################### +# Bert Heads +################### +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0)) + self.decoder.weight = bert_model_embedding_weights + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super().__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +##################### +# Various Bert models +##################### + + +class BertForPreTraining(BertPreTrainedModel): + # TBD: Coming in Future Commit + pass + + +class BertLMHeadModel(BertPreTrainedModel): + # TBD: Coming in Future Commit + pass + + +class BertForMaskedLM(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + warnings.warn( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_composer( + cls, + pretrained_checkpoint, + state_dict=None, + cache_dir=None, + from_tf=False, + config=None, + *inputs, + **kwargs, + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + # labels should be a `torch.LongTensor` of shape + # `(batch_size, sequence_length)`. These are used for computing the + # masked language modeling loss. 
+ # + # Indices should be in `[-100, 0, ..., config.vocab_size]` (see + # `input_ids` docstring) Tokens with indices set to `-100` are ignored + # (masked), the loss is only computed for the tokens with labels in `[0, + # ..., config.vocab_size]` + # + # Prediction scores are only computed for masked tokens and the (bs, + # seqlen) dimensions are flattened + if (input_ids is not None) == (inputs_embeds is not None): + raise ValueError("Must specify either input_ids or input_embeds!") + + if labels is None: + masked_tokens_mask = None + else: + masked_tokens_mask = labels > 0 + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + masked_tokens_mask=masked_tokens_mask, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + loss = None + if labels is not None: + # Compute loss + loss_fct = nn.CrossEntropyLoss() + masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() + loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx]) + + # Add MoE auxiliary losses if present + aux_loss = self._get_aux_loss() + if aux_loss is not None: + loss = loss + aux_loss + + assert input_ids is not None, "Coding error; please open an issue" + batch, seqlen = input_ids.shape[:2] + prediction_scores = rearrange( + index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen), + "(b s) d -> b s d", + b=batch, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=None, + attentions=None, + ) + + def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat( + [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], + dim=-1, + ) + dummy_token = torch.full( + (effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +class BertForNextSentencePrediction(BertPreTrainedModel): + # TBD: Push in future commit + pass + + +class BertForSequenceClassification(BertPreTrainedModel): + """Bert Model transformer with a sequence classification/regression head. + + This head is just a linear layer on top of the pooled output. Used for, + e.g., GLUE tasks. 
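+
+    Example (illustrative sketch; the toy tensors mirror the `BertModel` docstring example and the
+    config values are placeholders):
+
+        ```python
+        config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12,
+                            num_attention_heads=12, num_labels=2)
+        model = BertForSequenceClassification(config)
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        labels = torch.LongTensor([1, 0])
+        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
+        loss, logits = outputs.loss, outputs.logits
+        ```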
+ """ + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_composer( + cls, + pretrained_checkpoint, + state_dict=None, + cache_dir=None, + from_tf=False, + config=None, + *inputs, + **kwargs, + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. + # Indices should be in `[0, ..., config.num_labels - 1]`. + # If `config.num_labels == 1` a regression loss is computed + # (mean-square loss). If `config.num_labels > 1` a classification loss + # is computed (cross-entropy). 
+ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + # Compute loss + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + + +class BertForMultipleChoice(BertPreTrainedModel): + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """ + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + + # In multiple choice tasks, all choices are submitted in a batch, and + # we compute a logit for each option independently. The logits are then + # normalized in the forward pass to get a probability distribution over + # the choices. 
+ self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_composer( + cls, + pretrained_checkpoint, + state_dict=None, + cache_dir=None, + from_tf=False, + config=None, + *inputs, + **kwargs, + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=None, + attentions=None, + ) + + +class BertForTokenClassification(BertPreTrainedModel): + # TBD: Push in future commit + pass + + +class BertForQuestionAnswering(BertPreTrainedModel): + """Bert Model with a span classification head. + + This is used for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden states' output to compute `span start logits` + and `span end logits`). 
+ """ + + # TBD: Push in future commit + + +################### +# FlexBert Heads +################### + + +class FlexBertPredictionHead(nn.Module): + def __init__(self, config: FlexBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_pred_bias) + self.act = get_act_fn(config.head_pred_act) if config.head_pred_act else nn.Identity() + self.norm = ( + get_norm_layer(config, compiled_norm=config.compile_model) if config.head_pred_norm else nn.Identity() + ) + + def _init_weights(self, reset_params: bool = False): + if reset_params: + self.norm.reset_parameters() + init_weights(self.config, self.dense, layer_dim=self.config.hidden_size, type_of_module=ModuleType.in_module) + + def reset_parameters(self): + self._init_weights(reset_params=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +class FlexBertPoolingHead(nn.Module): + def __init__(self, config: FlexBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_class_bias) + self.act = get_act_fn(config.head_class_act) if config.head_class_act else nn.Identity() + self.norm = get_norm_layer(config) if config.head_class_norm else nn.Identity() + self.drop = torch.nn.Dropout(config.head_class_dropout) if config.head_class_dropout > 0 else nn.Identity() + self.pooling_type = config.pooling_type + + def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: + if pool: + if self.pooling_type == "cls": + output = hidden_states[:, 0] + elif self.pooling_type == "mean": + output = hidden_states.mean(dim=1) + elif self.pooling_type == "max": + output = hidden_states.max(dim=1)[0] + else: + output = hidden_states + + return self.drop(self.norm(self.act(self.dense(output)))) + + def _init_weights(self, reset_params: bool = False): + init_weights(self.config, self.dense, self.config.hidden_size, type_of_module=ModuleType.out_module) + if reset_params and hasattr(self.norm, "reset_parameters"): + self.norm.reset_parameters() + + def reset_parameters(self): + self._init_weights(reset_params=True) + + +################### +# FlexBert Models +################### + + +@dataclass +class MaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + indices: Optional[torch.LongTensor] = None + cu_seqlens: Optional[torch.LongTensor] = None + max_seqlen: Optional[int] = None + batch_size: Optional[int] = None + seq_len: Optional[int] = None + labels: Optional[torch.LongTensor] = None + + +@dataclass +class MaskedLMOutputZLoss(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + ce_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Cross entropy loss. + z_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Z loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + indices (`torch.LongTensor` of shape `(batch_size,)`): + Indices of the tokens to be masked. + """ + + loss: Optional[torch.FloatTensor] = None + ce_loss: Optional[torch.FloatTensor] = None + z_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + indices: Optional[torch.LongTensor] = None + cu_seqlens: Optional[torch.LongTensor] = None + max_seqlen: Optional[int] = None + batch_size: Optional[int] = None + seq_len: Optional[int] = None + labels: Optional[torch.LongTensor] = None + + +class FlexBertPreTrainedModel(BertPreTrainedModel): + """ + An abstract class to handle custom weights initialization of modules + """ + + def _init_module_weights(self, module: nn.Module): + """ + Custom weight init of modules using src.bert_layers.initialization.init_weights + Currently only supports init of embedding modules + """ + assert isinstance(module, nn.Module) + if isinstance(module, nn.Embedding): + init_weights(self.config, module, type_of_module=ModuleType.emb) + else: + raise NotImplementedError("Custom weight init for the given module is not supported") + + +class FlexBertModel(FlexBertPreTrainedModel): + """Overall BERT model. 
+
+ Args:
+ config: a FlexBertConfig class instance with the configuration to build a new model
+
+ Inputs:
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or [nnz] when the
+ inputs are unpadded) with the word token indices in the vocabulary.
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+ input sequence length in the current batch. It's the mask that we typically use for attention when
+ a batch has varying length sentences. Defaults to a mask of ones when not provided.
+ `position_ids`: an optional torch.LongTensor with position indices that is forwarded to the
+ embedding layer.
+ `indices`, `cu_seqlens`, `max_seqlen`: optional unpadding metadata (non-padding token indices,
+ cumulative sequence lengths, and maximum sequence length) used when running with unpadded inputs.
+
+ Outputs:
+ `encoder_outputs`: the hidden states of the final encoder layer, a torch.FloatTensor of size
+ [batch_size, sequence_length, hidden_size] (or [nnz, hidden_size] when unpadded), optionally
+ passed through a final normalization layer when `config.final_norm` is set.
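+
+ Example usage with a FlexBertConfig (an illustrative sketch, not a tested recipe; it assumes a
+ padded, non-FlashAttention configuration so that no unpadding metadata is required, and keeps
+ all other options at their defaults):
+ ```python
+ import torch
+
+ config = FlexBertConfig(
+ vocab_size=32000,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ padding="padded",
+ use_fa2=False,
+ )
+ model = FlexBertModel(config)
+
+ # Already been converted into WordPiece token ids
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+ attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+ hidden_states = model(input_ids, attention_mask=attention_mask)
+ ```
+ The legacy example below, which uses the original BERT-style API, is kept for reference.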
+ + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + model = BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config: FlexBertConfig): + super().__init__(config) + self.embeddings = get_embedding_layer(config) + self.encoder = get_encoder_layer(config) + if config.final_norm: + # if we use prenorm attention we need to add a final norm + self.final_norm = get_norm_layer(config) + else: + self.final_norm = None + self.unpad_embeddings = config.unpad_embeddings + + def post_init(self): + self._init_weights(reset_params=False) + self._backward_compatibility_gradient_checkpointing() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + embedding_output = self.embeddings(input_ids, position_ids) + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=attention_mask, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + if self.final_norm is not None: + encoder_outputs = self.final_norm(encoder_outputs) + return encoder_outputs + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + self._init_module_weights(module) + else: + assert isinstance(reset_params, bool) + self.embeddings._init_weights(reset_params=reset_params) + self.encoder._init_weights(reset_params=reset_params) + + if reset_params and self.config.final_norm: + self.final_norm.reset_parameters() + + def reset_parameters(self): + self._init_weights(reset_params=True) + + def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int: + """Returns the number of parameters in the model. + + Args: + count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. + trainable: only count trainable parameters. 
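+
+ Example (an illustrative sketch; assumes `model` is an already-constructed FlexBertModel):
+ ```python
+ total_params = model.get_number_parameters()
+ encoder_only_params = model.get_number_parameters(count_embeddings=False)
+ ```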
+ """ + params = sum([_count_parameters(layer, trainable) for layer in self.encoder.layers]) + if count_embeddings: + params += _count_parameters(self.embeddings, trainable) + if hasattr(self.embeddings, "position_embeddings"): + params -= _count_parameters(self.embeddings.position_embeddings, trainable) + return params + + +class FlexBertForMaskedLM(FlexBertPreTrainedModel): + def __init__(self, config: FlexBertConfig): + super().__init__(config) + self.bert = FlexBertModel(config) + self.head = FlexBertPredictionHead(config) + + if config.tie_word_embeddings: + decoder_weights = self.bert.embeddings.tok_embeddings.weight + else: + decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight + self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) + self.decoder.weight = decoder_weights + + self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config) + self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy" + self.return_z_loss = config.loss_kwargs.get("return_z_loss", False) + self.unpad_embeddings = config.unpad_embeddings + self.pad_logits = config.pad_logits + self.compile_model = config.compile_model + self.masked_prediction = config.masked_prediction + + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _get_aux_loss(self) -> Optional[torch.Tensor]: + """Collect auxiliary losses from all MoE layers in the model.""" + aux_loss = None + + # Traverse all modules to find FlexBertGLUMoE layers + for module in self.modules(): + if hasattr(module, 'aux_loss') and module.aux_loss is not None: + if aux_loss is None: + aux_loss = module.aux_loss + else: + aux_loss = aux_loss + module.aux_loss + + return aux_loss + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + self._init_module_weights(module) + else: + assert isinstance(reset_params, bool) + self.bert._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + + # Output weights. 
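+ # When `tie_word_embeddings` is set, `self.decoder.weight` aliases the token
+ # embedding matrix, which `self.bert._init_weights` has already initialized,
+ # so only the untied decoder needs an explicit init here.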
+ if not self.config.tie_word_embeddings: + init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out) + + @classmethod + def from_composer( + cls, + pretrained_checkpoint, + state_dict=None, + cache_dir=None, + from_tf=False, + config=None, + *inputs, + **kwargs, + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("FlexBERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings): + self.decoder = new_embeddings + + @torch.no_grad() + def unpad_inputs( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor + ): + return unpad_input(input_ids, attention_mask, position_ids, labels) + + @torch.no_grad() + def pad_inputs( + self, + inputs: torch.Tensor, + indices: torch.Tensor, + batch_size: int, + seqlen: int, + labels: Optional[torch.Tensor] = None, + ignore_index: int = -100, + ): + return pad_input( + inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index + ) + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + # labels should be a `torch.LongTensor` of shape + # `(batch_size, sequence_length)`. These are used for computing the + # masked language modeling loss. 
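+ #
+ # When the inputs are unpadded, `labels` are unpadded together with
+ # `input_ids` below so that their flattened shapes stay aligned.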
+ # + # Indices should be in `[-100, 0, ..., config.vocab_size]` (see + # `input_ids` docstring) Tokens with indices set to `-100` are ignored + # (masked), the loss is only computed for the tokens with labels in `[0, + # ..., config.vocab_size]` + # + # Prediction scores are only computed for masked tokens and the (bs, + # seqlen) dimensions are flattened + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None): + batch_size, seq_len = input_ids.shape[:2] + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs( + input_ids, attention_mask, position_ids, labels + ) + + output = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + if self.masked_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + output = output.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.loss_fn.ignore_index + output = output[mask_tokens] + labels = labels[mask_tokens] + + if self.compile_model: + logits = self.compiled_head(output) + else: + logits = self.decoder(self.head(output)) + + loss = None + if labels is not None: + if not self.masked_prediction: + labels = labels.view(-1) + logits = logits.view(labels.shape[0], -1) + + if self.return_z_loss: + loss, z_loss = self.loss_fn(logits, labels) + if self.pad_logits: + return MaskedLMOutputZLoss( + loss=loss, + ce_loss=loss.detach().clone() - z_loss, + z_loss=z_loss, + logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0], + hidden_states=None, + attentions=None, + ) + else: + return MaskedLMOutputZLoss( + loss=loss, + ce_loss=loss.detach().clone() - z_loss, + z_loss=z_loss, + logits=logits, + hidden_states=None, + attentions=None, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + labels=labels, + ) + else: + loss = self.loss_fn(logits, labels) + + if self.pad_logits: + return MaskedLMOutput( + loss=loss, + logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0], + hidden_states=None, + attentions=None, + ) + else: + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + labels=labels, + ) + + def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat( + [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], + dim=-1, + ) + dummy_token = torch.full( + (effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + def get_number_parameters( + self, count_embeddings: bool = True, count_decoder: bool = False, trainable: bool = True + ) -> int: + """Returns the number of parameters in the model. + + Args: + count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. 
+ count_decoder: count the parameters in the decoder layer if weights are not tied. + trainable: only count trainable parameters. + """ + params = self.bert.get_number_parameters(count_embeddings, trainable) + params += _count_parameters(self.head, trainable) + if count_decoder and not self.config.tie_word_embeddings: + params += _count_parameters(self.decoder, trainable) + return params + + +class FlexBertForSequenceClassification(FlexBertPreTrainedModel): + """Bert Model transformer with a sequence classification/regression head. + + This head is just a linear layer on top of the pooled output. Used for, + e.g., GLUE tasks. + """ + + def __init__(self, config: FlexBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = FlexBertModel(config) + self.head = FlexBertPoolingHead(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _get_aux_loss(self) -> Optional[torch.Tensor]: + """Collect auxiliary losses from all MoE layers in the model.""" + aux_loss = None + + # Traverse all modules to find FlexBertGLUMoE layers + for module in self.modules(): + if hasattr(module, 'aux_loss') and module.aux_loss is not None: + if aux_loss is None: + aux_loss = module.aux_loss + else: + aux_loss = aux_loss + module.aux_loss + + return aux_loss + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + self._init_module_weights(module) + else: + assert isinstance(reset_params, bool) + self.bert._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out) + + @classmethod + def from_composer( + cls, + pretrained_checkpoint, + state_dict=None, + cache_dir=None, + from_tf=False, + config=None, + *inputs, + **kwargs, + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. + # Indices should be in `[0, ..., config.num_labels - 1]`. + # If `config.num_labels == 1` a regression loss is computed + # (mean-square loss). If `config.num_labels > 1` a classification loss + # is computed (cross-entropy). 
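+ #
+ # When `labels` are provided, `problem_type` is inferred below from
+ # `num_labels` and the label dtype, and any auxiliary losses collected from
+ # the model's MoE layers are added to the task loss.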
+ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + pooled_output = self.head(output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + # Compute loss + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + # Add MoE auxiliary losses if present + aux_loss = self._get_aux_loss() + if aux_loss is not None: + loss = loss + aux_loss + + if not return_dict: + output = (logits,) + output + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + + def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int: + """Returns the number of parameters in the model. + + Args: + count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. + trainable: only count trainable parameters. + """ + params = self.bert.get_number_parameters(count_embeddings, trainable) + params += _count_parameters(self.head, trainable) + params += _count_parameters(self.classifier, trainable) + return params + + +class FlexBertForMultipleChoice(FlexBertPreTrainedModel): + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """ + + def __init__(self, config: FlexBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = FlexBertModel(config) + self.head = FlexBertPoolingHead(config) + + # In multiple choice tasks, all choices are submitted in a batch, and + # we compute a logit for each option independently. The logits are then + # normalized in the forward pass to get a probability distribution over + # the choices. 
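+ # For example, `input_ids` of shape `[batch_size, num_choices, seq_len]` is
+ # flattened to `[batch_size * num_choices, seq_len]`, each flattened row is
+ # scored with a single logit, and the logits are reshaped back to
+ # `[batch_size, num_choices]` before the cross-entropy loss.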
+ self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self._init_weights(reset_params=False) + + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): + assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" + if module: + self._init_module_weights(module) + else: + assert isinstance(reset_params, bool) + self.bert._init_weights(reset_params=reset_params) + self.head._init_weights(reset_params=reset_params) + init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out) + + @classmethod + def from_composer( + cls, + pretrained_checkpoint, + state_dict=None, + cache_dir=None, + from_tf=False, + config=None, + *inputs, + **kwargs, + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. + # Indices should be in `[0, ..., config.num_labels - 1]`. + # If `config.num_labels == 1` a regression loss is computed + # (mean-square loss). If `config.num_labels > 1` a classification loss + # is computed (cross-entropy). + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + + output = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + pooled_output = self.head(output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + output + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=None, + attentions=None, + ) + + def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int: + """Returns the number of parameters in the model. + + Args: + count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. 
+ trainable: only count trainable parameters. + """ + params = self.bert.get_number_parameters(count_embeddings, trainable) + params += _count_parameters(self.head, trainable) + params += _count_parameters(self.classifier, trainable) + return params + + +def init_model_from_pretrained( + pretrained_model: FlexBertModel, + new_model: FlexBertModel, + mode: Union[str, TileMode] = TileMode.tile_weights_from_middle, +): + """ + Initialize the new model from the pretrained model. + + This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`. + The new model must have the same or more layers and the same or larger dimensions than the pretrained model. + + Args: + pretrained_model (FlexBertModel): The smaller, pre-trained model + new_model (FlexBertModel): The larger model to be initialized + mode (Union[str, TileMode]): The Phi-style weight tiling mode to use + + This function assumes that the new_model has more layers and a larger hidden size + than the pretrained_model, but the same vocabulary size. + """ + + # Tile embeddings + assert isinstance( + new_model.embeddings, type(pretrained_model.embeddings) + ), f"Pretrained and new_model layers must be the same type, got {type(new_model.embeddings)} and {type(pretrained_model.embeddings)}" + assert isinstance( + new_model.embeddings, + (FlexBertAbsoluteEmbeddings, FlexBertSansPositionEmbeddings, FlexBertCompiledSansPositionEmbeddings), + ), f"Unsupported embedding layer type: {type(new_model.embeddings)}" + + tile_embedding(pretrained_model.embeddings.tok_embeddings, new_model.embeddings.tok_embeddings, mode=mode) + if isinstance(pretrained_model.embeddings, FlexBertAbsoluteEmbeddings): + tile_embedding(pretrained_model.embeddings.pos_embeddings, new_model.embeddings.pos_embeddings, mode=mode) + + if hasattr(pretrained_model.embeddings, "norm"): + tile_norm(pretrained_model.embeddings.norm, new_model.embeddings.norm, mode=mode) + + # Tile encoder layers + assert isinstance( + pretrained_model.encoder, (FlexBertUnpadEncoder, FlexBertPaddedEncoder) + ), f"Unsupported encoder layer type: {type(pretrained_model.encoder)}" + assert isinstance( + new_model.encoder, type(pretrained_model.encoder) + ), f"Pretrained and new_model encoder layers must be the same type, got {type(new_model.encoder)} and {type(pretrained_model.encoder)}" + + # Calculate the layer mapping + pretrained_layers = len(pretrained_model.encoder.layers) + new_layers = len(new_model.encoder.layers) + layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)] + + # Initialize layers + for new_model_idx, pretrained_idx in enumerate(layer_mapping): + new_model_layer = new_model.encoder.layers[new_model_idx] + pretrained_layer = pretrained_model.encoder.layers[pretrained_idx] + + # first tile the PreNorm/PostNorm layers + assert isinstance( + new_model_layer, type(pretrained_layer) + ), f"Pretrained and new_model prenorm/postnorm layers must be the same type, got {type(new_model_layer)} and {type(pretrained_layer)}" + assert isinstance( + new_model_layer, + ( + FlexBertUnpadPreNormLayer, + FlexBertCompileUnpadPreNormLayer, + FlexBertUnpadParallelPreNormLayer, + FlexBertUnpadPostNormLayer, + FlexBertPaddedPreNormLayer, + FlexBertPaddedParallelPreNormLayer, + FlexBertPaddedPostNormLayer, + ), + ), f"Unsupported prenorm/postnorm layer type: {type(new_model_layer)}" + + # First tile the normalization layers + if hasattr(pretrained_layer, "attn_norm"): + tile_norm(pretrained_layer.attn_norm, new_model_layer.attn_norm, mode=mode) + if 
hasattr(pretrained_layer, "norm"): + tile_norm(pretrained_layer.norm, new_model_layer.norm, mode=mode) + if hasattr(pretrained_layer, "mlp_norm"): + tile_norm(pretrained_layer.mlp_norm, new_model_layer.mlp_norm, mode=mode) + + # Then tile the attention & mlp layers + assert isinstance( + new_model_layer.attn, type(pretrained_layer.attn) + ), f"Pretrained and new_model attention layers must be the same type, got {type(new_model_layer.attn)} and {type(pretrained_layer.attn)}" + + # first try the parallel attention layers + if isinstance(pretrained_layer, (FlexBertUnpadParallelPreNormLayer, FlexBertPaddedParallelPreNormLayer)): + assert isinstance( + pretrained_layer.attn, + ( + FlexBertUnpadParallelAttention, + FlexBertPaddedParallelAttention, + FlexBertUnpadRopeParallelAttention, + FlexBertPaddedRopeParallelAttention, + ), + ), f"Parallel prenorm layer must have parallel attention layer: {type(pretrained_layer.attn)}" + if not isinstance(pretrained_layer.mlp, (FlexBertParallelGLU)): + raise ValueError(f"Parallel prenorm layer must have parallel MLP layer: {type(pretrained_layer.mlp)}") + tile_linear( + pretrained_layer.Wqkvff, + new_model_layer.Wqkvff, + linear_type=TileLinear.wqkvff, + mode=mode, + pretrained_attn_size=pretrained_layer.attn_size, + pretrained_mlp_size=pretrained_layer.mlp_size, + new_attn_size=new_model_layer.attn_size, + new_mlp_size=new_model_layer.mlp_size, + wqkvff_is_glu=True, + ) + + # then try the fused attention layers + elif isinstance( + pretrained_layer.attn, + ( + FlexBertUnpadAttention, + FlexBertPaddedAttention, + FlexBertUnpadRopeAttention, + FlexBertPaddedRopeAttention, + ), + ): + tile_linear(pretrained_layer.attn.Wqkv, new_model_layer.attn.Wqkv, linear_type=TileLinear.wqkv, mode=mode) + else: + raise ValueError(f"Unsupported attention layer type: {type(pretrained_layer.attn)}") + + # finally, tile the attention output layer + tile_linear(pretrained_layer.attn.Wo, new_model_layer.attn.Wo, linear_type=TileLinear.default, mode=mode) + + # tile the mlp layer if the model is not using parallel attention layers + if not isinstance(pretrained_layer.mlp, (FlexBertMLP, FlexBertGLU, FlexBertParallelGLU)): + raise ValueError(f"Unsupported MLP layer type: {type(pretrained_layer.mlp)}") + assert isinstance( + new_model_layer.mlp, type(pretrained_layer.mlp) + ), f"Pretrained and new_model mlp layers must be the same type, got {type(new_model_layer.mlp)} and {type(pretrained_layer.mlp)}" + + # already tiled the parallel glu layer if it exists, so only need to handle mlp & glu Wi + if isinstance(pretrained_layer.mlp, FlexBertGLU): + tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.glu, mode=mode) + elif isinstance(pretrained_layer.mlp, FlexBertMLP): + tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.default, mode=mode) + # tile the output for both ParallelGLU and MLP/GLU + tile_linear(pretrained_layer.mlp.Wo, new_model_layer.mlp.Wo, linear_type=TileLinear.default, mode=mode) + + +def init_mlm_model_from_pretrained( + config: FlexBertConfig, + pretrained_model: FlexBertForMaskedLM, + new_model: FlexBertForMaskedLM, + mode: Union[str, TileMode] = TileMode.tile_weights_from_middle, +): + """ + Initialize the new model from the pretrained model. + + This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`. + The new model must have the same or more layers and the same or larger dimensions than the pretrained model. 
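+
+ A minimal illustrative sketch (assuming `small_config` and `large_config` are compatible
+ FlexBertConfig instances that share the same vocabulary size, with the pretrained weights
+ already loaded into `small_model`):
+ ```python
+ small_model = FlexBertForMaskedLM(small_config)
+ large_model = FlexBertForMaskedLM(large_config)
+ init_mlm_model_from_pretrained(large_config, small_model, large_model)
+ ```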
+ + Args: + config (FlexBertConfig): The configuration of the new_model + pretrained_model (FlexBertForMaskedLM): The smaller, pre-trained model + new_model (FlexBertForMaskedLM): The larger model to be initialized from the pretrained model + mode (Union[str, TileMode]): The Phi-style weight tiling mode to use + + This function assumes that the new_model has more layers and a larger hidden size + than the pretrained_model, but the same vocabulary size. + """ + init_model_from_pretrained(pretrained_model.bert, new_model.bert, mode=mode) + + # TODO: uncomment this when the repo is turned into a pip installable package + # if not isinstance(pretrained_model.head, FlexBertPredictionHead): + # raise ValueError(f"Pretrained model must have a prediction head: {type(pretrained_model.head)}") + # if not isinstance(new_model.head, FlexBertPredictionHead): + # raise ValueError(f"New model must have a prediction head: {type(new_model.head)}") + + # tile the prediction head + tile_linear(pretrained_model.head.dense, new_model.head.dense, linear_type=TileLinear.default, mode=mode) + tile_norm(pretrained_model.head.norm, new_model.head.norm, mode=mode) + + # setup weight tying + if config.tie_word_embeddings: + new_model.decoder.weight = new_model.bert.embeddings.tok_embeddings.weight + tile_linear( + pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True + ) + else: + tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode) From d05272c39d47b69c3b14311093ccbaef34577ac0 Mon Sep 17 00:00:00 2001 From: heyo66 Date: Thu, 11 Dec 2025 18:31:27 +0100 Subject: [PATCH 2/4] Refactor BERT model implementation and remove redundant files --- configuration_bert.py | 290 ----- loss.py | 105 -- main.py | 0 mlp.py | 405 ------ model.py | 1722 ------------------------- src/bert_layers/configuration_bert.py | 23 +- src/bert_layers/loss.py | 75 ++ src/bert_layers/mlp.py | 191 +++ src/bert_layers/model.py | 38 + 9 files changed, 326 insertions(+), 2523 deletions(-) delete mode 100644 configuration_bert.py delete mode 100644 loss.py mode change 100755 => 100644 main.py delete mode 100644 mlp.py delete mode 100644 model.py diff --git a/configuration_bert.py b/configuration_bert.py deleted file mode 100644 index b9f1e6ce..00000000 --- a/configuration_bert.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright 2022 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -import warnings - -from transformers import BertConfig as TransformersBertConfig - - -class BertConfig(TransformersBertConfig): - def __init__( - self, - alibi_starting_size: int = 512, - normalization: str = "layernorm", - attention_probs_dropout_prob: float = 0.0, - head_pred_act: str = "gelu", - deterministic_fa2: bool = False, - allow_embedding_resizing: bool = False, - **kwargs, - ): - """Configuration class for MosaicBert. - - Args: - alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to - create when initializing the model. You should be able to ignore this parameter in most cases. - Defaults to 512. - attention_probs_dropout_prob (float): By default, turn off attention dropout in MosaicBERT - Note that the custom Triton Flash Attention with ALiBi implementation does not support droput. - However, Flash Attention 2 supports ALiBi and dropout https://github.com/Dao-AILab/flash-attention - embed_dropout_prob (float): Dropout probability for the embedding layer. 
- attn_out_dropout_prob (float): Dropout probability for the attention output layer. - mlp_dropout_prob (float): Dropout probability for the MLP layer. - allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size. - """ - super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) - self.alibi_starting_size = alibi_starting_size - self.normalization = normalization - self.head_pred_act = head_pred_act - self.deterministic_fa2 = deterministic_fa2 - self.allow_embedding_resizing = allow_embedding_resizing - - -class FlexBertConfig(TransformersBertConfig): - def __init__( - self, - attention_layer: str = "base", - attention_probs_dropout_prob: float = 0.0, - attn_out_bias: bool = False, - attn_out_dropout_prob: float = 0.0, - attn_qkv_bias: bool = False, - bert_layer: str = "prenorm", - decoder_bias: bool = True, - embed_dropout_prob: float = 0.0, - embed_norm: bool = True, - final_norm: bool = False, - embedding_layer: str = "absolute_pos", - encoder_layer: str = "base", - loss_function: str = "cross_entropy", - loss_kwargs: dict = {}, - mlp_dropout_prob: float = 0.0, - mlp_in_bias: bool = False, - mlp_layer: str = "glu_moe", - mlp_out_bias: bool = False, - norm_kwargs: dict = {}, - normalization: str = "rmsnorm", - padding: str = "unpadded", - head_class_act: str = "silu", - head_class_bias: bool = False, - head_class_dropout: float = 0.0, - head_class_norm: str = False, - head_pred_act: str = "silu", - head_pred_bias: bool = False, - head_pred_dropout: float = 0.0, - head_pred_norm: bool = True, - pooling_type: str = "cls", - rotary_emb_dim: int | None = None, - rotary_emb_base: float = 10000.0, - rotary_emb_scale_base=None, - rotary_emb_interleaved: bool = False, - use_fa2: bool = True, - use_sdpa_attn_mask: bool = False, - allow_embedding_resizing: bool = False, - init_method: str = "default", - init_std: float = 0.02, - init_cutoff_factor: float = 2.0, - init_small_embedding: bool = False, - initial_attention_layer: str | None = None, - initial_bert_layer: str | None = None, - initial_mlp_layer: str | None = None, - num_initial_layers: int = 1, - skip_first_prenorm: bool = False, - deterministic_fa2: bool = False, - sliding_window: int = -1, - global_attn_every_n_layers: int = -1, - local_attn_rotary_emb_base: float = -1, - local_attn_rotary_emb_dim: int | None = None, - unpad_embeddings: bool = False, - pad_logits: bool = False, - compile_model: bool = False, - masked_prediction: bool = False, - moe_num_experts: int = 8, - moe_top_k: int = 2, - moe_use_noisy_top_k: bool = True, - moe_capacity_factor: float = 1.25, - moe_compute_aux_loss: bool = True, - moe_load_balance_loss_weight: float = 0.01, - moe_router_z_loss_weight: float = 0.001, - **kwargs, - ): - """ - Args: - attention_layer (str): Attention layer type. - attention_probs_dropout_prob (float): Dropout probability for attention probabilities. - attn_out_bias (bool): use bias in attention output projection. - attn_out_dropout_prob (float): Dropout probability for attention output. - attn_qkv_bias (bool): use bias for query, key, value linear layer(s). - bert_layer (str): BERT layer type. - decoder_bias (bool): use bias in decoder linear layer. - embed_dropout_prob (float): Dropout probability for embeddings. - embed_norm (bool): Normalize embedding output. - final_norm (bool): Add normalization after the final encoder layer and before head. - embedding_layer (str): Embedding layer type. - encoder_layer (str): Encoder layer type. 
- loss_function (str): Loss function to use. - loss_kwargs (dict): Keyword arguments for loss function. - mlp_dropout_prob (float): Dropout probability for MLP layers. - mlp_in_bias (bool): Use bias in MLP input linear layer. - mlp_layer (str): MLP layer type. - mlp_out_bias (bool): Use bias in MLP output linear layer. - norm_kwargs (dict): Keyword arguments for normalization layers. - normalization (str): Normalization type. - padding (str): Unpad inputs. Best with `use_fa2=True`. - head_class_act (str): Activation function for classification head. - head_class_bias (bool): Use bias in classification head linear layer(s). - head_class_dropout (float): Dropout probability for classification head. - head_class_norm (str): Normalization type for classification head. - head_pred_act (str): Activation function for prediction head. - head_pred_bias (bool): Use bias in prediction head linear layer(s). - head_pred_dropout (float): Dropout probability for prediction head. - head_pred_norm (bool): Normalize prediction head output. - pooling_type (str): Pooling type. - rotary_emb_dim (int | None): Rotary embedding dimension. - rotary_emb_base (float): Rotary embedding base. - rotary_emb_scale_base (float): Rotary embedding scale base. - rotary_emb_interleaved (bool): Use interleaved rotary embeddings. - use_fa2 (bool): Use FlashAttention2. Requires flash_attn package. - use_sdpa_attn_mask (bool): Pass a mask to SDPA. This will prevent SDPA from using the PyTorch FA2 kernel. - allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size. - init_method (str): Model layers initialization method. - init_std (float): Standard deviation for initialization. Used for normal and full_megatron init. - init_cutoff_factor (float): Cutoff factor for initialization. Used for normal and full_megatron init. - init_small_embedding (bool): Initialize embeddings with RWKV small init. - initial_attention_layer (str | None): Replace first `num_initial_layers` attention_layer instance with this layer. - initial_bert_layer (str | None): Replace first `num_initial_layers` bert_layer instance with this layer. - initial_mlp_layer (str | None): Replace first `num_initial_layers` mlp_layer instance with this layer. - num_initial_layers (int): Number of initial layers to set via `initial_attention_layer`, `initial_bert_layer`, and `initial_mlp_layer`. - skip_first_prenorm (bool): Skip pre-normalization for the first bert layer. Requires `embed_norm=True`. - deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode. This is slower then the default non-deterministic mode. - sliding_window (int): Use sliding window attention with window size `n`. -1 to disable. Window size split between the left and right context. Only supports FA2. - global_attn_every_n_layers (int): Use global attention every `n` layers and sliding window for the rest. -1 to disable. - local_attn_rotary_emb_base (float): Rotary embedding base for local attention. -1 to disable and use `rotary_emb_base` for all layers. - local_attn_rotary_emb_dim (int | None): Rotary embedding dimension for local attention. None to disable and use `rotary_emb_dim` for all layers. - unpad_embeddings (bool): Unpad inputs before the embedding layer. - pad_logits (bool): Pad logits after the calculating the loss. - compile_model (bool): Compile the subset of the model which can be compiled. 
- masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers - moe_num_experts (int): Number of experts for Mixture of Experts layers. - moe_top_k (int): Number of top experts to select for each token in MoE. - moe_use_noisy_top_k (bool): Use noisy top-k gating for exploration during training. - moe_capacity_factor (float): Capacity factor for expert assignment in MoE. - moe_compute_aux_loss (bool): Whether to compute and add auxiliary losses for MoE. - moe_load_balance_loss_weight (float): Weight for the load balancing auxiliary loss. - moe_router_z_loss_weight (float): Weight for the router z-loss auxiliary loss. - **kwargs: Additional keyword arguments. - """ - super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) - self.attention_layer = attention_layer - self.attn_out_bias = attn_out_bias - self.attn_out_dropout_prob = attn_out_dropout_prob - self.attn_qkv_bias = attn_qkv_bias - self.bert_layer = bert_layer - self.decoder_bias = decoder_bias - self.embed_dropout_prob = embed_dropout_prob - self.embed_norm = embed_norm - self.final_norm = final_norm - self.embedding_layer = embedding_layer - self.encoder_layer = encoder_layer - self.loss_function = loss_function - self.loss_kwargs = loss_kwargs - self.mlp_dropout_prob = mlp_dropout_prob - self.mlp_in_bias = mlp_in_bias - self.mlp_layer = mlp_layer - self.mlp_out_bias = mlp_out_bias - self.norm_kwargs = norm_kwargs - self.normalization = normalization - self.padding = padding - self.head_class_act = head_class_act - self.head_class_bias = head_class_bias - self.head_class_dropout = head_class_dropout - self.head_class_norm = head_class_norm - self.head_pred_act = head_pred_act - self.head_pred_bias = head_pred_bias - self.head_pred_dropout = head_pred_dropout - self.head_pred_norm = head_pred_norm - self.pooling_type = pooling_type - self.rotary_emb_dim = rotary_emb_dim - self.rotary_emb_base = rotary_emb_base - self.rotary_emb_scale_base = rotary_emb_scale_base - self.rotary_emb_interleaved = rotary_emb_interleaved - self.use_fa2 = use_fa2 - self.use_sdpa_attn_mask = use_sdpa_attn_mask - self.allow_embedding_resizing = allow_embedding_resizing - self.init_method = init_method - self.init_std = init_std - self.init_cutoff_factor = init_cutoff_factor - self.init_small_embedding = init_small_embedding - self.initial_attention_layer = initial_attention_layer - self.initial_bert_layer = initial_bert_layer - self.initial_mlp_layer = initial_mlp_layer - self.num_initial_layers = num_initial_layers - self.skip_first_prenorm = skip_first_prenorm - self.deterministic_fa2 = deterministic_fa2 - self.sliding_window = sliding_window - self.global_attn_every_n_layers = global_attn_every_n_layers - self.local_attn_rotary_emb_base = local_attn_rotary_emb_base - self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim - self.unpad_embeddings = unpad_embeddings - self.pad_logits = pad_logits - self.compile_model = compile_model - self.masked_prediction = masked_prediction - self.moe_num_experts = moe_num_experts - self.moe_top_k = moe_top_k - self.moe_use_noisy_top_k = moe_use_noisy_top_k - self.moe_capacity_factor = moe_capacity_factor - self.moe_compute_aux_loss = moe_compute_aux_loss - self.moe_load_balance_loss_weight = moe_load_balance_loss_weight - self.moe_router_z_loss_weight = moe_router_z_loss_weight - - if loss_kwargs.get("return_z_loss", False): - if loss_function != "fa_cross_entropy": - raise ValueError("loss_function must be 'fa_cross_entropy' when return_z_loss is True") - if 
loss_kwargs.get("lse_square_scale", 0) <= 0: - raise ValueError( - "lse_square_scale must be passed to `loss_kwargs` and must be greater than 0 for z_loss" - ) - if loss_kwargs.get("inplace_backward", False): - self.loss_kwargs["inplace_backward"] = False - warnings.warn("`inplace_backward=True` will cause incorrect metrics. Automatically setting to False.") - - if global_attn_every_n_layers > 0 and (self.num_hidden_layers - 1) % global_attn_every_n_layers != 0: - raise ValueError( - f"{global_attn_every_n_layers=} must be a divisor of one less than {self.num_hidden_layers=}" - ) - - if self.sliding_window != -1: - if not self.use_fa2: - raise ValueError("Sliding window attention is only supported with FlashAttention2") - if self.sliding_window % 2 != 0 and self.sliding_window % 64 != 0: - raise ValueError( - f"Sliding window must be an even number and divisible by 64: {self.sliding_window=} {self.sliding_window % 64} {self.sliding_window % 2}" - ) - else: - if self.global_attn_every_n_layers != -1: - raise ValueError("global_attn_every_n_layers must be -1 when sliding_window is disabled") - if self.local_attn_rotary_emb_base != -1: - raise ValueError("local_attn_rotary_emb_base must be -1 when sliding_window is disabled") - if self.local_attn_rotary_emb_dim is not None: - raise ValueError("local_attn_rotary_emb_dim must be None when sliding_window is disabled") - - if self.unpad_embeddings and self.padding != "unpadded": - warnings.warn( - "`unpad_embeddings=True` requires `padding='unpadded'`. Automatically setting `padding='unpadded'`." - ) - self.padding = "unpadded" - if self.pad_logits and not self.unpad_embeddings: - raise ValueError("`pad_logits=True` requires `unpad_embeddings=True`") - if self.unpad_embeddings and self.embedding_layer == "absolute_pos": - raise ValueError(f"{self.unpad_embeddings=} is incompatible with {self.embedding_layer=}") - - -PADDING = ["unpadded", "padded"] - - -def maybe_add_padding(config: FlexBertConfig, config_option: str) -> str: - if config.padding not in PADDING: - raise ValueError(f"Invalid padding type: {config.padding}, must be one of {PADDING}") - - if not any(config_option.startswith(pad + "_") for pad in PADDING): - config_option = f"{config.padding}_{config_option}" - - return config_option diff --git a/loss.py b/loss.py deleted file mode 100644 index 3f210a4c..00000000 --- a/loss.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2024 onwards Answer.AI, LightOn, and contributors -# License: Apache-2.0 - -import inspect -import torch -import torch.nn as nn -import torch.nn.functional as F -from .configuration_bert import FlexBertConfig - -try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss -except ImportError: - CrossEntropyLoss = None - -LOSS2CLS = { - "cross_entropy": nn.CrossEntropyLoss, - "binary_cross_entropy": nn.BCEWithLogitsLoss, - "mean_squared_error": nn.MSELoss, -} - -if CrossEntropyLoss is not None: - LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss - - -class MoELoadBalancingLoss(nn.Module): - """Computes Switch Transformer auxiliary loss for load balancing. - - Reference: https://arxiv.org/abs/2101.03961 (equations 4-6, page 7) - - This loss encourages balanced token allocation across experts to avoid - scenarios where some experts are overloaded while others are underutilized. 
- """ - - def __init__(self, num_experts: int, top_k: int = 2): - super().__init__() - self.num_experts = num_experts - self.top_k = top_k - - def forward( - self, - router_logits: torch.Tensor, - expert_indices: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - router_logits: Router logits [batch_size, seq_len, num_experts] - expert_indices: Top-k expert indices [batch_size, seq_len, top_k] - - Returns: - load_balance_loss: Scalar loss value - """ - # Compute expert probabilities - expert_probs = F.softmax(router_logits, dim=-1) # [B, C, n_exp] - - # Equation (5): compute ratio of tokens allocated to each expert - with torch.no_grad(): - one_hot_indices = F.one_hot(expert_indices, num_classes=self.num_experts) # [B, C, K, n_exp] - one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, C, n_exp] - tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1)) # [n_exp] - - # Equation (6): compute ratio of router probability allocated to each expert - prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1)) # [n_exp] - - # Equation (4): scaled dot product between prob / token allocation vectors - load_balance_loss = self.num_experts * torch.sum(prob_per_expert * tokens_per_expert) - - return load_balance_loss - - -class MoERouterZLoss(nn.Module): - """Computes router z-loss for MoE models. - - Reference: https://arxiv.org/abs/2202.08906 (equation 5, page 7) - - This loss constrains the size of router logits to avoid numerical instability - during training. Large logits can lead to round-off errors in the softmax computation, - even in float32 precision. - """ - - def forward(self, router_logits: torch.Tensor) -> torch.Tensor: - """ - Args: - router_logits: Router logits [batch_size, seq_len, num_experts] - - Returns: - router_z_loss: Scalar loss value - """ - # Numerically stable computation: logsumexp is equivalent to log(sum(exp(x))) - # This avoids overflow issues from directly exponentiating large logits - router_z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0 # [B, C] - - # Average over all tokens - router_z_loss = torch.mean(router_z_loss) - - return router_z_loss - - -def get_loss_fn(config: FlexBertConfig) -> nn.Module: - try: - loss_class = LOSS2CLS[config.loss_function] - signature = inspect.signature(loss_class) - loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters} - return loss_class(**loss_kwargs) - except KeyError: - raise ValueError(f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}.") diff --git a/main.py b/main.py old mode 100755 new mode 100644 diff --git a/mlp.py b/mlp.py deleted file mode 100644 index f344008d..00000000 --- a/mlp.py +++ /dev/null @@ -1,405 +0,0 @@ -# Copyright 2024 onwards Answer.AI, LightOn, and contributors -# License: Apache-2.0 - -# Copyright 2022 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2023 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2023, Tri Dao. 
- -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .configuration_bert import FlexBertConfig -from .activation import get_act_fn -from .normalization import get_norm_layer -from .initialization import ModuleType, init_weights -from .loss import MoELoadBalancingLoss, MoERouterZLoss - - -class BertResidualGLU(nn.Module): - """Applies the FFN at the end of each Mosaic BERT layer. - - Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` - and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but - introduces Gated Linear Units. - - Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a - standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with - `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed - with the `config.intermediate_size=3072`. - However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased - parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`. - """ - - def __init__( - self, - config, - ): - super().__init__() - self.config = config - self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False) - self.act = get_act_fn(config.hidden_act) - self.wo = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.layernorm = get_norm_layer(config) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Compute new hidden states from current hidden states. - - Args: - hidden_states (torch.Tensor): The (unpadded) hidden states from - the attention layer [nnz, dim]. - """ - residual_connection = hidden_states - # compute the activation - hidden_states = self.gated_layers(hidden_states) - gated = hidden_states[:, : self.config.intermediate_size] - non_gated = hidden_states[:, self.config.intermediate_size :] - hidden_states = self.act(gated) * non_gated - hidden_states = self.dropout(hidden_states) - # multiply by the second matrix - hidden_states = self.wo(hidden_states) - # add the residual connection and post-LN - hidden_states = self.layernorm(hidden_states + residual_connection) - return hidden_states - - -class FlexBertMLPBase(nn.Module): - """A FlexBERT MLP base class for type hints.""" - - def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): - super().__init__() - self.config = config - self.layer_id = layer_id - - def _init_weights(self, reset_params: bool = False): - raise NotImplementedError("This is a base class and should not be used directly.") - - def reset_parameters(self): - self._init_weights(reset_params=True) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("This is a base class and should not be used directly.") - - -class FlexBertMLP(FlexBertMLPBase): - """Applies the MLP at the end of each FlexBERT layer. - - Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` - and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
- """ - - def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): - super().__init__(config=config, layer_id=layer_id) - self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias) - self.act = get_act_fn(config.hidden_act) - self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity() - self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias) - - def _init_weights(self, reset_params: bool = False): - init_weights( - self.config, - self.Wi, - layer_dim=self.config.hidden_size, - layer_id=None, - type_of_module=ModuleType.in_module, - ) - init_weights( - self.config, - self.Wo, - layer_dim=self.config.intermediate_size, - layer_id=self.layer_id, - type_of_module=ModuleType.out_module, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Compute new hidden states from current hidden states. - - Args: - hidden_states (torch.Tensor): The (unpadded) hidden states from - the attention layer [nnz, dim]. - """ - return self.Wo(self.drop(self.act(self.Wi(hidden_states)))) - - -class FlexBertGLU(FlexBertMLPBase): - """Applies the GLU at the end of each FlexBERT layer. - - Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` - and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. - """ - - def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): - super().__init__(config=config, layer_id=layer_id) - self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_in_bias) - self.act = get_act_fn(config.hidden_act) - self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity() - self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias) - - def _init_weights(self, reset_params: bool = False): - init_weights( - self.config, - self.Wi, - layer_dim=self.config.hidden_size, - layer_id=None, - type_of_module=ModuleType.in_module, - ) - init_weights( - self.config, - self.Wo, - layer_dim=self.config.intermediate_size, - layer_id=self.layer_id, - type_of_module=ModuleType.out_module, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input, gate = self.Wi(hidden_states).chunk(2, dim=-1) - return self.Wo(self.drop(self.act(input) * gate)) - - -class FlexBertParallelGLU(FlexBertMLPBase): - """Applies the GLU at the end of each FlexBERT layer using intermediate_ff computed in parallel of the attention. - - Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` - and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
- """ - - def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): - super().__init__(config=config, layer_id=layer_id) - self.act = get_act_fn(config.hidden_act) - self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity() - self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias) - - def _init_weights(self, reset_params: bool = False): - init_weights( - self.config, - self.Wo, - layer_dim=self.config.intermediate_size, - layer_id=self.layer_id, - type_of_module=ModuleType.out_module, - ) - - def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor: - input, gate = intermediate_ff.chunk(2, dim=-1) - return self.Wo(self.drop(self.act(input) * gate)) - - -class Router(nn.Module): - """Top-K router for selecting experts.""" - - def __init__( - self, - d: int, - n_exp: int, - top_k: int = 2, - use_noisy_top_k: bool = True, - capacity_factor: float = 1.25, - ): - super().__init__() - self.d = d - self.n_exp = n_exp - self.top_k = top_k - self.use_noisy_top_k = use_noisy_top_k - self.capacity_factor = capacity_factor - - # Router weights to compute logits for each expert - self.gate = nn.Linear(d, n_exp, bias=False) - - # Noise parameters for noisy top-k gating - if use_noisy_top_k: - self.w_noise = nn.Linear(d, n_exp, bias=False) - - def forward(self, x: torch.Tensor): - """ - Args: - x: [batch_size, seq_len, d] - Returns: - exp_weight: Expert weights [batch_size * seq_len, top_k] - exp_mask: Expert mask [batch_size * seq_len, n_exp, exp_capacity] - exp_batches: Token assignments [n_exp, exp_capacity, d] - """ - B, C, d = x.size() - num_tokens = B * C - x_flat = x.view(num_tokens, d) - - # Compute router logits - logits = self.gate(x_flat) # [num_tokens, n_exp] - - # Add noise for exploration (optional) - if self.use_noisy_top_k and self.training: - noise_stddev = F.softplus(self.w_noise(x_flat)) - noise = torch.randn_like(logits) * noise_stddev - logits = logits + noise - - # Select top-k experts - top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1) - top_k_gates = F.softmax(top_k_logits, dim=-1) # [num_tokens, top_k] - - # Compute expert capacity - exp_capacity = int((num_tokens * self.top_k * self.capacity_factor) / self.n_exp) - - # Create expert assignment mask and batches - exp_mask = torch.zeros(num_tokens, self.n_exp, exp_capacity, device=x.device) - exp_batches = torch.zeros(self.n_exp, exp_capacity, d, device=x.device) - - # Count tokens assigned to each expert - expert_counts = torch.zeros(self.n_exp, dtype=torch.long, device=x.device) - - # Assign tokens to experts - for token_idx in range(num_tokens): - for k_idx in range(self.top_k): - expert_idx = top_k_indices[token_idx, k_idx] - if expert_counts[expert_idx] < exp_capacity: - pos = expert_counts[expert_idx] - exp_mask[token_idx, expert_idx, pos] = top_k_gates[token_idx, k_idx] - exp_batches[expert_idx, pos] = x_flat[token_idx] - expert_counts[expert_idx] += 1 - - return top_k_gates, exp_mask, exp_batches - - - - - -class FlexBertGLUMoE(FlexBertMLPBase): - """Mixture of Experts with GLU activation for FlexBERT.""" - - def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): - super().__init__(config=config, layer_id=layer_id) - - self.n_exp = getattr(config, 'moe_num_experts', 4) - self.top_k = getattr(config, 'moe_top_k', 2) - self.use_noisy_top_k = getattr(config, 'moe_use_noisy_top_k', True) - self.capacity_factor = getattr(config, 'moe_capacity_factor', 1.25) - self.compute_aux_loss = 
getattr(config, 'moe_compute_aux_loss', True) - self.load_balance_loss_weight = getattr(config, 'moe_load_balance_loss_weight', 0.01) - self.router_z_loss_weight = getattr(config, 'moe_router_z_loss_weight', 0.001) - - self.router = Router( - d=config.hidden_size, - n_exp=self.n_exp, - top_k=self.top_k, - use_noisy_top_k=self.use_noisy_top_k, - capacity_factor=self.capacity_factor, - ) - - # GLU experts (each projects to 2x intermediate size) - self.experts = nn.ModuleList([ - nn.Sequential( - nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_in_bias), - nn.Identity(), # Placeholder for chunking + activation - nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity(), - nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias), - ) - for _ in range(self.n_exp) - ]) - self.act = get_act_fn(config.hidden_act) - - # Initialize auxiliary loss modules - if self.compute_aux_loss: - self.load_balance_loss = MoELoadBalancingLoss(num_experts=self.n_exp, top_k=self.top_k) - self.router_z_loss = MoERouterZLoss() - - def _init_weights(self, reset_params: bool = False): - init_weights( - self.config, - self.router.gate, - layer_dim=self.config.hidden_size, - layer_id=self.layer_id, - type_of_module=ModuleType.in_module, - ) - - for expert in self.experts: - for i, module in enumerate(expert): - if isinstance(module, nn.Linear): - init_weights( - self.config, - module, - layer_dim=self.config.hidden_size if i == 0 else self.config.intermediate_size, - layer_id=self.layer_id, - type_of_module=ModuleType.in_module if i == 0 else ModuleType.out_module, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - original_shape = hidden_states.shape - if hidden_states.dim() == 2: - hidden_states = hidden_states.unsqueeze(0) - - B, C, d = hidden_states.size() - num_tokens = B * C - x_flat = hidden_states.view(num_tokens, d) - - # Compute router logits for auxiliary loss calculation - router_logits = self.router.gate(x_flat) # [num_tokens, n_exp] - - exp_weight, exp_mask, exp_batches = self.router(hidden_states) - - # Extract top-k indices from router for load balancing loss - _, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1) # [num_tokens, top_k] - - # Apply GLU experts - exp_out = torch.zeros_like(exp_batches) - for i, expert in enumerate(self.experts): - x = expert[0](exp_batches[i]) # Linear projection - input, gate = x.chunk(2, dim=-1) # Split for GLU - x = self.act(input) * gate # GLU activation - x = expert[2](x) # Dropout - exp_out[i] = expert[3](x) # Output projection - - exp_weight_flat = exp_mask.view(num_tokens, -1) - exp_out_flat = exp_out.view(-1, d) - output = torch.matmul(exp_weight_flat, exp_out_flat) - - # Compute auxiliary losses - self.aux_loss = None - if self.compute_aux_loss: - # Reshape for loss computation - router_logits_reshaped = router_logits.view(B, C, -1) - top_k_indices_reshaped = top_k_indices.view(B, C, -1) - - # Compute load balancing loss - lb_loss = self.load_balance_loss(router_logits_reshaped, top_k_indices_reshaped) - - # Compute router z-loss - z_loss = self.router_z_loss(router_logits_reshaped) - - # Combine auxiliary losses with weights - self.aux_loss = self.load_balance_loss_weight * lb_loss + self.router_z_loss_weight * z_loss - - return output.view(*original_shape) - - -# Update the MLP registry -MLP2CLS = { - "mlp": FlexBertMLP, - "glu": FlexBertGLU, - "parallel_glu": FlexBertParallelGLU, - "glu_moe": FlexBertGLUMoE, -} - - -def get_mlp_layer(config: 
FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase: - try: - mlp_layer = ( - config.initial_mlp_layer - if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None - else config.mlp_layer - ) - return MLP2CLS[mlp_layer](config, layer_id=layer_id) - except KeyError as e: - if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None: - raise ValueError( - f"Invalid MLP layer type: {config.initial_mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}" - ) - else: - raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}") - - diff --git a/model.py b/model.py deleted file mode 100644 index aeac1b01..00000000 --- a/model.py +++ /dev/null @@ -1,1722 +0,0 @@ -# Copyright 2024 onwards Answer.AI, LightOn, and contributors -# License: Apache-2.0 - -# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation) -# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT - -# Copyright 2022 Jonas Geiping -# License: MIT - -# Copyright 2022 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2023 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2023, Tri Dao. - -"""Implements Mosaic BERT, with an eye towards the Hugging Face API. - -Mosaic BERT improves performance over Hugging Face BERT through the following: - -1. ALiBi. This architectural change removes positional embeddings and instead encodes positional -information through attention biases based on query-key position distance. It improves the effectiveness -of training with shorter sequence lengths by enabling extrapolation to longer sequences. - -2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer -to improve overall expressiveness, providing better convergence properties. - -3. Flash Attention. The MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically -improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that -supports attention biases, which allows us to use Flash Attention with ALiBi. - -4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT -implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation -and improve speed. It does this without changing how the user interfaces with the model, thereby -preserving the simple API of standard implementations. - - -Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence -classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases. - -See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage -of the core Mosaic BERT classes. 
-""" - -import logging -import os -import sys -import warnings -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -# Add folder root to path to allow us to use relative imports regardless of what directory the script is run from -sys.path.append(os.path.dirname(os.path.realpath(__file__))) - -import torch -import torch.nn as nn -from einops import rearrange -from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present -from transformers.modeling_outputs import ( - MaskedLMOutput, - ModelOutput, - MultipleChoiceModelOutput, - SequenceClassifierOutput, -) -from transformers.models.bert.modeling_bert import BertPreTrainedModel - -from bert_padding import index_put_first_axis - -from src.bert_layers.activation import get_act_fn -from src.bert_layers.attention import ( - FlexBertPaddedAttention, - FlexBertPaddedParallelAttention, - FlexBertPaddedRopeAttention, - FlexBertPaddedRopeParallelAttention, - FlexBertUnpadAttention, - FlexBertUnpadParallelAttention, - FlexBertUnpadRopeAttention, - FlexBertUnpadRopeParallelAttention, -) -from src.bert_layers.configuration_bert import FlexBertConfig -from src.bert_layers.embeddings import ( - BertAlibiEmbeddings, - FlexBertAbsoluteEmbeddings, - FlexBertCompiledSansPositionEmbeddings, - FlexBertSansPositionEmbeddings, - get_embedding_layer, -) -from src.bert_layers.initialization import ( - ModuleType, - TileLinear, - TileMode, - init_weights, - tile_embedding, - tile_linear, - tile_norm, -) -from src.bert_layers.layers import ( - BertAlibiEncoder, - BertPooler, - BertPredictionHeadTransform, - FlexBertCompileUnpadPreNormLayer, - FlexBertPaddedEncoder, - FlexBertPaddedParallelPreNormLayer, - FlexBertPaddedPostNormLayer, - FlexBertPaddedPreNormLayer, - FlexBertUnpadEncoder, - FlexBertUnpadParallelPreNormLayer, - FlexBertUnpadPostNormLayer, - FlexBertUnpadPreNormLayer, - get_encoder_layer, -) -from src.bert_layers.loss import get_loss_fn -from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU -from src.bert_layers.normalization import get_norm_layer -from src.bert_layers.padding import pad_input, unpad_input - -logger = logging.getLogger(__name__) - - -def _count_parameters(model: nn.Module, trainable: bool = True) -> int: - if trainable: - return sum(p.numel() for p in model.parameters() if p.requires_grad) - else: - return sum(p.numel() for p in model.parameters()) - - -class BertModel(BertPreTrainedModel): - """Overall BERT model. - - Args: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 
- - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - model = BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__( - self, - config, - add_pooling_layer: bool = True, - ): - super(BertModel, self).__init__(config) - self.embeddings = BertAlibiEmbeddings(config) - self.encoder = BertAlibiEncoder(config) - self.pooler = BertPooler(config) if add_pooling_layer else None - self.post_init() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def forward( - self, - input_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_all_encoded_layers: Optional[bool] = False, - masked_tokens_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) - - subset_mask = [] - first_col_mask = [] - - if masked_tokens_mask is None: - subset_mask = None - else: - first_col_mask = torch.zeros_like(masked_tokens_mask) - first_col_mask[:, 0] = True - subset_mask = masked_tokens_mask | first_col_mask - - encoder_outputs = self.encoder( - embedding_output, - attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - subset_mask=subset_mask, - ) - - if masked_tokens_mask is None: - sequence_output = encoder_outputs[-1] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - else: - # TD [2022-03-01]: the indexing here is very tricky. 
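- # `subset_mask` flags the MLM-target tokens plus the first (CLS) token of each sequence.
- # After dropping padded positions via `attention_mask_bool`, `subset_idx` marks which
- # unpadded positions belong to that subset, so indexing the boolean masks below recovers
- # which rows of `encoder_outputs[-1]` are masked-token states and which are CLS states
- # used for pooling.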
- attention_mask_bool = attention_mask.bool() - subset_idx = subset_mask[attention_mask_bool] # type: ignore - sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]] - if self.pooler is not None: - pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]] - pooled_output = self.pooler(pool_input, pool=False) - else: - pooled_output = None - - if not output_all_encoded_layers: - encoder_outputs = sequence_output - - if self.pooler is not None: - return encoder_outputs, pooled_output - - return encoder_outputs, None - - -################### -# Bert Heads -################### -class BertLMPredictionHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super().__init__() - self.transform = BertPredictionHeadTransform(config) - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0)) - self.decoder.weight = bert_model_embedding_weights - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super().__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - - def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - def __init__(self, config): - super().__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -##################### -# Various Bert models -##################### - - -class BertForPreTraining(BertPreTrainedModel): - # TBD: Coming in Future Commit - pass - - -class BertLMHeadModel(BertPreTrainedModel): - # TBD: Coming in Future Commit - pass - - -class BertForMaskedLM(BertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - if config.is_decoder: - warnings.warn( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." 
- ) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def from_composer( - cls, - pretrained_checkpoint, - state_dict=None, - cache_dir=None, - from_tf=False, - config=None, - *inputs, - **kwargs, - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: - # labels should be a `torch.LongTensor` of shape - # `(batch_size, sequence_length)`. These are used for computing the - # masked language modeling loss. 
- # - # Indices should be in `[-100, 0, ..., config.vocab_size]` (see - # `input_ids` docstring) Tokens with indices set to `-100` are ignored - # (masked), the loss is only computed for the tokens with labels in `[0, - # ..., config.vocab_size]` - # - # Prediction scores are only computed for masked tokens and the (bs, - # seqlen) dimensions are flattened - if (input_ids is not None) == (inputs_embeds is not None): - raise ValueError("Must specify either input_ids or input_embeds!") - - if labels is None: - masked_tokens_mask = None - else: - masked_tokens_mask = labels > 0 - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - masked_tokens_mask=masked_tokens_mask, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - loss = None - if labels is not None: - # Compute loss - loss_fct = nn.CrossEntropyLoss() - masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() - loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx]) - - # Add MoE auxiliary losses if present - aux_loss = self._get_aux_loss() - if aux_loss is not None: - loss = loss + aux_loss - - assert input_ids is not None, "Coding error; please open an issue" - batch, seqlen = input_ids.shape[:2] - prediction_scores = rearrange( - index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen), - "(b s) d -> b s d", - b=batch, - ) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return MaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=None, - attentions=None, - ) - - def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs): - input_shape = input_ids.shape - effective_batch_size = input_shape[0] - - # add a dummy token - if self.config.pad_token_id is None: - raise ValueError("The PAD token should be defined for generation") - - attention_mask = torch.cat( - [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], - dim=-1, - ) - dummy_token = torch.full( - (effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device, - ) - input_ids = torch.cat([input_ids, dummy_token], dim=1) - - return {"input_ids": input_ids, "attention_mask": attention_mask} - - -class BertForNextSentencePrediction(BertPreTrainedModel): - # TBD: Push in future commit - pass - - -class BertForSequenceClassification(BertPreTrainedModel): - """Bert Model transformer with a sequence classification/regression head. - - This head is just a linear layer on top of the pooled output. Used for, - e.g., GLUE tasks. 
- """ - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def from_composer( - cls, - pretrained_checkpoint, - state_dict=None, - cache_dir=None, - from_tf=False, - config=None, - *inputs, - **kwargs, - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - # Labels for computing the sequence classification/regression loss. - # Indices should be in `[0, ..., config.num_labels - 1]`. - # If `config.num_labels == 1` a regression loss is computed - # (mean-square loss). If `config.num_labels > 1` a classification loss - # is computed (cross-entropy). 
- - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - # Compute loss - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = nn.MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - ) - - -class BertForMultipleChoice(BertPreTrainedModel): - """ - Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """ - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - - # In multiple choice tasks, all choices are submitted in a batch, and - # we compute a logit for each option independently. The logits are then - # normalized in the forward pass to get a probability distribution over - # the choices. 
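- # Each choice therefore yields a single scalar logit; the forward pass reshapes the
- # logits to [batch_size, num_choices] before applying the cross-entropy loss.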
- self.classifier = nn.Linear(config.hidden_size, 1) - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def from_composer( - cls, - pretrained_checkpoint, - state_dict=None, - cache_dir=None, - from_tf=False, - config=None, - *inputs, - **kwargs, - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., - num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See - `input_ids` above) - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - - input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None - attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None - position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None - else None - ) - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=None, - attentions=None, - ) - - -class BertForTokenClassification(BertPreTrainedModel): - # TBD: Push in future commit - pass - - -class BertForQuestionAnswering(BertPreTrainedModel): - """Bert Model with a span classification head. - - This is used for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden states' output to compute `span start logits` - and `span end logits`). 
- """ - - # TBD: Push in future commit - - -################### -# FlexBert Heads -################### - - -class FlexBertPredictionHead(nn.Module): - def __init__(self, config: FlexBertConfig): - super().__init__() - self.config = config - self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_pred_bias) - self.act = get_act_fn(config.head_pred_act) if config.head_pred_act else nn.Identity() - self.norm = ( - get_norm_layer(config, compiled_norm=config.compile_model) if config.head_pred_norm else nn.Identity() - ) - - def _init_weights(self, reset_params: bool = False): - if reset_params: - self.norm.reset_parameters() - init_weights(self.config, self.dense, layer_dim=self.config.hidden_size, type_of_module=ModuleType.in_module) - - def reset_parameters(self): - self._init_weights(reset_params=True) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.norm(self.act(self.dense(hidden_states))) - - -class FlexBertPoolingHead(nn.Module): - def __init__(self, config: FlexBertConfig): - super().__init__() - self.config = config - self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_class_bias) - self.act = get_act_fn(config.head_class_act) if config.head_class_act else nn.Identity() - self.norm = get_norm_layer(config) if config.head_class_norm else nn.Identity() - self.drop = torch.nn.Dropout(config.head_class_dropout) if config.head_class_dropout > 0 else nn.Identity() - self.pooling_type = config.pooling_type - - def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: - if pool: - if self.pooling_type == "cls": - output = hidden_states[:, 0] - elif self.pooling_type == "mean": - output = hidden_states.mean(dim=1) - elif self.pooling_type == "max": - output = hidden_states.max(dim=1)[0] - else: - output = hidden_states - - return self.drop(self.norm(self.act(self.dense(output)))) - - def _init_weights(self, reset_params: bool = False): - init_weights(self.config, self.dense, self.config.hidden_size, type_of_module=ModuleType.out_module) - if reset_params and hasattr(self.norm, "reset_parameters"): - self.norm.reset_parameters() - - def reset_parameters(self): - self._init_weights(reset_params=True) - - -################### -# FlexBert Models -################### - - -@dataclass -class MaskedLMOutput(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. 
- - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - indices: Optional[torch.LongTensor] = None - cu_seqlens: Optional[torch.LongTensor] = None - max_seqlen: Optional[int] = None - batch_size: Optional[int] = None - seq_len: Optional[int] = None - labels: Optional[torch.LongTensor] = None - - -@dataclass -class MaskedLMOutputZLoss(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - ce_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Cross entropy loss. - z_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Z loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - indices (`torch.LongTensor` of shape `(batch_size,)`): - Indices of the tokens to be masked. - """ - - loss: Optional[torch.FloatTensor] = None - ce_loss: Optional[torch.FloatTensor] = None - z_loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - indices: Optional[torch.LongTensor] = None - cu_seqlens: Optional[torch.LongTensor] = None - max_seqlen: Optional[int] = None - batch_size: Optional[int] = None - seq_len: Optional[int] = None - labels: Optional[torch.LongTensor] = None - - -class FlexBertPreTrainedModel(BertPreTrainedModel): - """ - An abstract class to handle custom weights initialization of modules - """ - - def _init_module_weights(self, module: nn.Module): - """ - Custom weight init of modules using src.bert_layers.initialization.init_weights - Currently only supports init of embedding modules - """ - assert isinstance(module, nn.Module) - if isinstance(module, nn.Embedding): - init_weights(self.config, module, type_of_module=ModuleType.emb) - else: - raise NotImplementedError("Custom weight init for the given module is not supported") - - -class FlexBertModel(FlexBertPreTrainedModel): - """Overall BERT model. 
- - Args: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 
- - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - model = BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__(self, config: FlexBertConfig): - super().__init__(config) - self.embeddings = get_embedding_layer(config) - self.encoder = get_encoder_layer(config) - if config.final_norm: - # if we use prenorm attention we need to add a final norm - self.final_norm = get_norm_layer(config) - else: - self.final_norm = None - self.unpad_embeddings = config.unpad_embeddings - - def post_init(self): - self._init_weights(reset_params=False) - self._backward_compatibility_gradient_checkpointing() - - def get_input_embeddings(self): - return self.embeddings.tok_embeddings - - def set_input_embeddings(self, value): - self.embeddings.tok_embeddings = value - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - **kwargs, - ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - - embedding_output = self.embeddings(input_ids, position_ids) - - encoder_outputs = self.encoder( - hidden_states=embedding_output, - attention_mask=attention_mask, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - if self.final_norm is not None: - encoder_outputs = self.final_norm(encoder_outputs) - return encoder_outputs - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - self._init_module_weights(module) - else: - assert isinstance(reset_params, bool) - self.embeddings._init_weights(reset_params=reset_params) - self.encoder._init_weights(reset_params=reset_params) - - if reset_params and self.config.final_norm: - self.final_norm.reset_parameters() - - def reset_parameters(self): - self._init_weights(reset_params=True) - - def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int: - """Returns the number of parameters in the model. - - Args: - count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. - trainable: only count trainable parameters. 
- """ - params = sum([_count_parameters(layer, trainable) for layer in self.encoder.layers]) - if count_embeddings: - params += _count_parameters(self.embeddings, trainable) - if hasattr(self.embeddings, "position_embeddings"): - params -= _count_parameters(self.embeddings.position_embeddings, trainable) - return params - - -class FlexBertForMaskedLM(FlexBertPreTrainedModel): - def __init__(self, config: FlexBertConfig): - super().__init__(config) - self.bert = FlexBertModel(config) - self.head = FlexBertPredictionHead(config) - - if config.tie_word_embeddings: - decoder_weights = self.bert.embeddings.tok_embeddings.weight - else: - decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight - self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias) - self.decoder.weight = decoder_weights - - self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config) - self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy" - self.return_z_loss = config.loss_kwargs.get("return_z_loss", False) - self.unpad_embeddings = config.unpad_embeddings - self.pad_logits = config.pad_logits - self.compile_model = config.compile_model - self.masked_prediction = config.masked_prediction - - # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _get_aux_loss(self) -> Optional[torch.Tensor]: - """Collect auxiliary losses from all MoE layers in the model.""" - aux_loss = None - - # Traverse all modules to find FlexBertGLUMoE layers - for module in self.modules(): - if hasattr(module, 'aux_loss') and module.aux_loss is not None: - if aux_loss is None: - aux_loss = module.aux_loss - else: - aux_loss = aux_loss + module.aux_loss - - return aux_loss - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - self._init_module_weights(module) - else: - assert isinstance(reset_params, bool) - self.bert._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - - # Output weights. 
- if not self.config.tie_word_embeddings: - init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out) - - @classmethod - def from_composer( - cls, - pretrained_checkpoint, - state_dict=None, - cache_dir=None, - from_tf=False, - config=None, - *inputs, - **kwargs, - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("FlexBERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def get_output_embeddings(self): - return self.decoder - - def set_output_embeddings(self, new_embeddings): - self.decoder = new_embeddings - - @torch.no_grad() - def unpad_inputs( - self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor - ): - return unpad_input(input_ids, attention_mask, position_ids, labels) - - @torch.no_grad() - def pad_inputs( - self, - inputs: torch.Tensor, - indices: torch.Tensor, - batch_size: int, - seqlen: int, - labels: Optional[torch.Tensor] = None, - ignore_index: int = -100, - ): - return pad_input( - inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index - ) - - @torch.compile(dynamic=True) - def compiled_head(self, output: torch.Tensor) -> torch.Tensor: - return self.decoder(self.head(output)) - - def forward( - self, - input_ids: Optional[torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, - indices: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - batch_size: Optional[int] = None, - seq_len: Optional[int] = None, - **kwargs, - ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: - # labels should be a `torch.LongTensor` of shape - # `(batch_size, sequence_length)`. These are used for computing the - # masked language modeling loss. 
- # - # Indices should be in `[-100, 0, ..., config.vocab_size]` (see - # `input_ids` docstring) Tokens with indices set to `-100` are ignored - # (masked), the loss is only computed for the tokens with labels in `[0, - # ..., config.vocab_size]` - # - # Prediction scores are only computed for masked tokens and the (bs, - # seqlen) dimensions are flattened - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None): - batch_size, seq_len = input_ids.shape[:2] - input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs( - input_ids, attention_mask, position_ids, labels - ) - - output = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - if self.masked_prediction and labels is not None: - # flatten labels and output first - labels = labels.view(-1) - output = output.view(labels.shape[0], -1) - - # then filter out the non-masked tokens - mask_tokens = labels != self.loss_fn.ignore_index - output = output[mask_tokens] - labels = labels[mask_tokens] - - if self.compile_model: - logits = self.compiled_head(output) - else: - logits = self.decoder(self.head(output)) - - loss = None - if labels is not None: - if not self.masked_prediction: - labels = labels.view(-1) - logits = logits.view(labels.shape[0], -1) - - if self.return_z_loss: - loss, z_loss = self.loss_fn(logits, labels) - if self.pad_logits: - return MaskedLMOutputZLoss( - loss=loss, - ce_loss=loss.detach().clone() - z_loss, - z_loss=z_loss, - logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0], - hidden_states=None, - attentions=None, - ) - else: - return MaskedLMOutputZLoss( - loss=loss, - ce_loss=loss.detach().clone() - z_loss, - z_loss=z_loss, - logits=logits, - hidden_states=None, - attentions=None, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - batch_size=batch_size, - seq_len=seq_len, - labels=labels, - ) - else: - loss = self.loss_fn(logits, labels) - - if self.pad_logits: - return MaskedLMOutput( - loss=loss, - logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0], - hidden_states=None, - attentions=None, - ) - else: - return MaskedLMOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - batch_size=batch_size, - seq_len=seq_len, - labels=labels, - ) - - def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs): - input_shape = input_ids.shape - effective_batch_size = input_shape[0] - - # add a dummy token - if self.config.pad_token_id is None: - raise ValueError("The PAD token should be defined for generation") - - attention_mask = torch.cat( - [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], - dim=-1, - ) - dummy_token = torch.full( - (effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device, - ) - input_ids = torch.cat([input_ids, dummy_token], dim=1) - - return {"input_ids": input_ids, "attention_mask": attention_mask} - - def get_number_parameters( - self, count_embeddings: bool = True, count_decoder: bool = False, trainable: bool = True - ) -> int: - """Returns the number of parameters in the model. - - Args: - count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. 
- count_decoder: count the parameters in the decoder layer if weights are not tied. - trainable: only count trainable parameters. - """ - params = self.bert.get_number_parameters(count_embeddings, trainable) - params += _count_parameters(self.head, trainable) - if count_decoder and not self.config.tie_word_embeddings: - params += _count_parameters(self.decoder, trainable) - return params - - -class FlexBertForSequenceClassification(FlexBertPreTrainedModel): - """Bert Model transformer with a sequence classification/regression head. - - This head is just a linear layer on top of the pooled output. Used for, - e.g., GLUE tasks. - """ - - def __init__(self, config: FlexBertConfig): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = FlexBertModel(config) - self.head = FlexBertPoolingHead(config) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _get_aux_loss(self) -> Optional[torch.Tensor]: - """Collect auxiliary losses from all MoE layers in the model.""" - aux_loss = None - - # Traverse all modules to find FlexBertGLUMoE layers - for module in self.modules(): - if hasattr(module, 'aux_loss') and module.aux_loss is not None: - if aux_loss is None: - aux_loss = module.aux_loss - else: - aux_loss = aux_loss + module.aux_loss - - return aux_loss - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - self._init_module_weights(module) - else: - assert isinstance(reset_params, bool) - self.bert._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out) - - @classmethod - def from_composer( - cls, - pretrained_checkpoint, - state_dict=None, - cache_dir=None, - from_tf=False, - config=None, - *inputs, - **kwargs, - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - # Labels for computing the sequence classification/regression loss. - # Indices should be in `[0, ..., config.num_labels - 1]`. - # If `config.num_labels == 1` a regression loss is computed - # (mean-square loss). If `config.num_labels > 1` a classification loss - # is computed (cross-entropy). 
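# A minimal sketch (values illustrative) of the loss selection that follows: the
# problem type is keyed on `num_labels` and the label dtype, matching the standard
# Hugging Face convention used in this forward pass.
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    if num_labels == 1:
        return "regression"                    # MSELoss on a squeezed logit
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"   # CrossEntropyLoss over num_labels
    return "multi_label_classification"        # BCEWithLogitsLoss on float targets

assert infer_problem_type(1, torch.tensor([0.7])) == "regression"
assert infer_problem_type(3, torch.tensor([2])) == "single_label_classification"
assert infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])) == "multi_label_classification"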
- - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - output = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - pooled_output = self.head(output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - # Compute loss - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = nn.MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - # Add MoE auxiliary losses if present - aux_loss = self._get_aux_loss() - if aux_loss is not None: - loss = loss + aux_loss - - if not return_dict: - output = (logits,) + output - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - ) - - def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int: - """Returns the number of parameters in the model. - - Args: - count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. - trainable: only count trainable parameters. - """ - params = self.bert.get_number_parameters(count_embeddings, trainable) - params += _count_parameters(self.head, trainable) - params += _count_parameters(self.classifier, trainable) - return params - - -class FlexBertForMultipleChoice(FlexBertPreTrainedModel): - """ - Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """ - - def __init__(self, config: FlexBertConfig): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = FlexBertModel(config) - self.head = FlexBertPoolingHead(config) - - # In multiple choice tasks, all choices are submitted in a batch, and - # we compute a logit for each option independently. The logits are then - # normalized in the forward pass to get a probability distribution over - # the choices. 
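# A minimal sketch (shapes illustrative) of the multiple-choice computation in the
# forward pass that follows: the choice dimension is folded into the batch, each
# (example, choice) pair gets a single scalar logit, and the logits are reshaped
# back to (batch, num_choices) so cross-entropy selects the correct choice.
import torch
import torch.nn as nn

batch, num_choices, seq_len, hidden = 2, 4, 8, 16
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))

flat_ids = input_ids.view(-1, seq_len)             # (batch * num_choices, seq_len)
pooled = torch.randn(batch * num_choices, hidden)  # stand-in for the pooled encoder output
classifier = nn.Linear(hidden, 1)                  # one logit per (example, choice)
logits = classifier(pooled).view(-1, num_choices)  # (batch, num_choices)

labels = torch.tensor([1, 3])                      # index of the correct choice per example
loss = nn.CrossEntropyLoss()(logits, labels)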
- self.classifier = nn.Linear(config.hidden_size, 1) - - # Initialize weights and apply final processing - self._init_weights(reset_params=False) - - def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): - assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" - if module: - self._init_module_weights(module) - else: - assert isinstance(reset_params, bool) - self.bert._init_weights(reset_params=reset_params) - self.head._init_weights(reset_params=reset_params) - init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out) - - @classmethod - def from_composer( - cls, - pretrained_checkpoint, - state_dict=None, - cache_dir=None, - from_tf=False, - config=None, - *inputs, - **kwargs, - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - # Labels for computing the sequence classification/regression loss. - # Indices should be in `[0, ..., config.num_labels - 1]`. - # If `config.num_labels == 1` a regression loss is computed - # (mean-square loss). If `config.num_labels > 1` a classification loss - # is computed (cross-entropy). - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[1] - - input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None - attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - - output = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - pooled_output = self.head(output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits,) + output - return ((loss,) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=None, - attentions=None, - ) - - def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int: - """Returns the number of parameters in the model. - - Args: - count_embeddings: count the parameters in the embeddings layer, excluding position embeddings. 
- trainable: only count trainable parameters. - """ - params = self.bert.get_number_parameters(count_embeddings, trainable) - params += _count_parameters(self.head, trainable) - params += _count_parameters(self.classifier, trainable) - return params - - -def init_model_from_pretrained( - pretrained_model: FlexBertModel, - new_model: FlexBertModel, - mode: Union[str, TileMode] = TileMode.tile_weights_from_middle, -): - """ - Initialize the new model from the pretrained model. - - This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`. - The new model must have the same or more layers and the same or larger dimensions than the pretrained model. - - Args: - pretrained_model (FlexBertModel): The smaller, pre-trained model - new_model (FlexBertModel): The larger model to be initialized - mode (Union[str, TileMode]): The Phi-style weight tiling mode to use - - This function assumes that the new_model has more layers and a larger hidden size - than the pretrained_model, but the same vocabulary size. - """ - - # Tile embeddings - assert isinstance( - new_model.embeddings, type(pretrained_model.embeddings) - ), f"Pretrained and new_model layers must be the same type, got {type(new_model.embeddings)} and {type(pretrained_model.embeddings)}" - assert isinstance( - new_model.embeddings, - (FlexBertAbsoluteEmbeddings, FlexBertSansPositionEmbeddings, FlexBertCompiledSansPositionEmbeddings), - ), f"Unsupported embedding layer type: {type(new_model.embeddings)}" - - tile_embedding(pretrained_model.embeddings.tok_embeddings, new_model.embeddings.tok_embeddings, mode=mode) - if isinstance(pretrained_model.embeddings, FlexBertAbsoluteEmbeddings): - tile_embedding(pretrained_model.embeddings.pos_embeddings, new_model.embeddings.pos_embeddings, mode=mode) - - if hasattr(pretrained_model.embeddings, "norm"): - tile_norm(pretrained_model.embeddings.norm, new_model.embeddings.norm, mode=mode) - - # Tile encoder layers - assert isinstance( - pretrained_model.encoder, (FlexBertUnpadEncoder, FlexBertPaddedEncoder) - ), f"Unsupported encoder layer type: {type(pretrained_model.encoder)}" - assert isinstance( - new_model.encoder, type(pretrained_model.encoder) - ), f"Pretrained and new_model encoder layers must be the same type, got {type(new_model.encoder)} and {type(pretrained_model.encoder)}" - - # Calculate the layer mapping - pretrained_layers = len(pretrained_model.encoder.layers) - new_layers = len(new_model.encoder.layers) - layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)] - - # Initialize layers - for new_model_idx, pretrained_idx in enumerate(layer_mapping): - new_model_layer = new_model.encoder.layers[new_model_idx] - pretrained_layer = pretrained_model.encoder.layers[pretrained_idx] - - # first tile the PreNorm/PostNorm layers - assert isinstance( - new_model_layer, type(pretrained_layer) - ), f"Pretrained and new_model prenorm/postnorm layers must be the same type, got {type(new_model_layer)} and {type(pretrained_layer)}" - assert isinstance( - new_model_layer, - ( - FlexBertUnpadPreNormLayer, - FlexBertCompileUnpadPreNormLayer, - FlexBertUnpadParallelPreNormLayer, - FlexBertUnpadPostNormLayer, - FlexBertPaddedPreNormLayer, - FlexBertPaddedParallelPreNormLayer, - FlexBertPaddedPostNormLayer, - ), - ), f"Unsupported prenorm/postnorm layer type: {type(new_model_layer)}" - - # First tile the normalization layers - if hasattr(pretrained_layer, "attn_norm"): - tile_norm(pretrained_layer.attn_norm, new_model_layer.attn_norm, mode=mode) - if 
hasattr(pretrained_layer, "norm"): - tile_norm(pretrained_layer.norm, new_model_layer.norm, mode=mode) - if hasattr(pretrained_layer, "mlp_norm"): - tile_norm(pretrained_layer.mlp_norm, new_model_layer.mlp_norm, mode=mode) - - # Then tile the attention & mlp layers - assert isinstance( - new_model_layer.attn, type(pretrained_layer.attn) - ), f"Pretrained and new_model attention layers must be the same type, got {type(new_model_layer.attn)} and {type(pretrained_layer.attn)}" - - # first try the parallel attention layers - if isinstance(pretrained_layer, (FlexBertUnpadParallelPreNormLayer, FlexBertPaddedParallelPreNormLayer)): - assert isinstance( - pretrained_layer.attn, - ( - FlexBertUnpadParallelAttention, - FlexBertPaddedParallelAttention, - FlexBertUnpadRopeParallelAttention, - FlexBertPaddedRopeParallelAttention, - ), - ), f"Parallel prenorm layer must have parallel attention layer: {type(pretrained_layer.attn)}" - if not isinstance(pretrained_layer.mlp, (FlexBertParallelGLU)): - raise ValueError(f"Parallel prenorm layer must have parallel MLP layer: {type(pretrained_layer.mlp)}") - tile_linear( - pretrained_layer.Wqkvff, - new_model_layer.Wqkvff, - linear_type=TileLinear.wqkvff, - mode=mode, - pretrained_attn_size=pretrained_layer.attn_size, - pretrained_mlp_size=pretrained_layer.mlp_size, - new_attn_size=new_model_layer.attn_size, - new_mlp_size=new_model_layer.mlp_size, - wqkvff_is_glu=True, - ) - - # then try the fused attention layers - elif isinstance( - pretrained_layer.attn, - ( - FlexBertUnpadAttention, - FlexBertPaddedAttention, - FlexBertUnpadRopeAttention, - FlexBertPaddedRopeAttention, - ), - ): - tile_linear(pretrained_layer.attn.Wqkv, new_model_layer.attn.Wqkv, linear_type=TileLinear.wqkv, mode=mode) - else: - raise ValueError(f"Unsupported attention layer type: {type(pretrained_layer.attn)}") - - # finally, tile the attention output layer - tile_linear(pretrained_layer.attn.Wo, new_model_layer.attn.Wo, linear_type=TileLinear.default, mode=mode) - - # tile the mlp layer if the model is not using parallel attention layers - if not isinstance(pretrained_layer.mlp, (FlexBertMLP, FlexBertGLU, FlexBertParallelGLU)): - raise ValueError(f"Unsupported MLP layer type: {type(pretrained_layer.mlp)}") - assert isinstance( - new_model_layer.mlp, type(pretrained_layer.mlp) - ), f"Pretrained and new_model mlp layers must be the same type, got {type(new_model_layer.mlp)} and {type(pretrained_layer.mlp)}" - - # already tiled the parallel glu layer if it exists, so only need to handle mlp & glu Wi - if isinstance(pretrained_layer.mlp, FlexBertGLU): - tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.glu, mode=mode) - elif isinstance(pretrained_layer.mlp, FlexBertMLP): - tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.default, mode=mode) - # tile the output for both ParallelGLU and MLP/GLU - tile_linear(pretrained_layer.mlp.Wo, new_model_layer.mlp.Wo, linear_type=TileLinear.default, mode=mode) - - -def init_mlm_model_from_pretrained( - config: FlexBertConfig, - pretrained_model: FlexBertForMaskedLM, - new_model: FlexBertForMaskedLM, - mode: Union[str, TileMode] = TileMode.tile_weights_from_middle, -): - """ - Initialize the new model from the pretrained model. - - This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`. - The new model must have the same or more layers and the same or larger dimensions than the pretrained model. 
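# A small worked example (numbers illustrative) of the depth-upscaling mapping that
# init_model_from_pretrained computes above: source layers are repeated so that an
# N-layer pretrained encoder initializes an M-layer one (M >= N). Note that Python's
# round() uses banker's rounding, so exact .5 ties snap to the nearest even index.
pretrained_layers, new_layers = 6, 9
layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
assert layer_mapping == [0, 1, 1, 2, 3, 3, 4, 5, 5]
# New layer 0 copies pretrained layer 0, new layers 1-2 both copy layer 1, and so on.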
- - Args: - config (FlexBertConfig): The configuration of the new_model - pretrained_model (FlexBertForMaskedLM): The smaller, pre-trained model - new_model (FlexBertForMaskedLM): The larger model to be initialized from the pretrained model - mode (Union[str, TileMode]): The Phi-style weight tiling mode to use - - This function assumes that the new_model has more layers and a larger hidden size - than the pretrained_model, but the same vocabulary size. - """ - init_model_from_pretrained(pretrained_model.bert, new_model.bert, mode=mode) - - # TODO: uncomment this when the repo is turned into a pip installable package - # if not isinstance(pretrained_model.head, FlexBertPredictionHead): - # raise ValueError(f"Pretrained model must have a prediction head: {type(pretrained_model.head)}") - # if not isinstance(new_model.head, FlexBertPredictionHead): - # raise ValueError(f"New model must have a prediction head: {type(new_model.head)}") - - # tile the prediction head - tile_linear(pretrained_model.head.dense, new_model.head.dense, linear_type=TileLinear.default, mode=mode) - tile_norm(pretrained_model.head.norm, new_model.head.norm, mode=mode) - - # setup weight tying - if config.tie_word_embeddings: - new_model.decoder.weight = new_model.bert.embeddings.tok_embeddings.weight - tile_linear( - pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True - ) - else: - tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode) diff --git a/src/bert_layers/configuration_bert.py b/src/bert_layers/configuration_bert.py index 6fbdeb53..b9f1e6ce 100644 --- a/src/bert_layers/configuration_bert.py +++ b/src/bert_layers/configuration_bert.py @@ -58,7 +58,7 @@ def __init__( loss_kwargs: dict = {}, mlp_dropout_prob: float = 0.0, mlp_in_bias: bool = False, - mlp_layer: str = "mlp", + mlp_layer: str = "glu_moe", mlp_out_bias: bool = False, norm_kwargs: dict = {}, normalization: str = "rmsnorm", @@ -97,6 +97,13 @@ def __init__( pad_logits: bool = False, compile_model: bool = False, masked_prediction: bool = False, + moe_num_experts: int = 8, + moe_top_k: int = 2, + moe_use_noisy_top_k: bool = True, + moe_capacity_factor: float = 1.25, + moe_compute_aux_loss: bool = True, + moe_load_balance_loss_weight: float = 0.01, + moe_router_z_loss_weight: float = 0.001, **kwargs, ): """ @@ -156,6 +163,13 @@ def __init__( pad_logits (bool): Pad logits after the calculating the loss. compile_model (bool): Compile the subset of the model which can be compiled. masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers + moe_num_experts (int): Number of experts for Mixture of Experts layers. + moe_top_k (int): Number of top experts to select for each token in MoE. + moe_use_noisy_top_k (bool): Use noisy top-k gating for exploration during training. + moe_capacity_factor (float): Capacity factor for expert assignment in MoE. + moe_compute_aux_loss (bool): Whether to compute and add auxiliary losses for MoE. + moe_load_balance_loss_weight (float): Weight for the load balancing auxiliary loss. + moe_router_z_loss_weight (float): Weight for the router z-loss auxiliary loss. **kwargs: Additional keyword arguments. 
""" super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) @@ -213,6 +227,13 @@ def __init__( self.pad_logits = pad_logits self.compile_model = compile_model self.masked_prediction = masked_prediction + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_use_noisy_top_k = moe_use_noisy_top_k + self.moe_capacity_factor = moe_capacity_factor + self.moe_compute_aux_loss = moe_compute_aux_loss + self.moe_load_balance_loss_weight = moe_load_balance_loss_weight + self.moe_router_z_loss_weight = moe_router_z_loss_weight if loss_kwargs.get("return_z_loss", False): if loss_function != "fa_cross_entropy": diff --git a/src/bert_layers/loss.py b/src/bert_layers/loss.py index 8b0007cf..3f210a4c 100644 --- a/src/bert_layers/loss.py +++ b/src/bert_layers/loss.py @@ -2,7 +2,9 @@ # License: Apache-2.0 import inspect +import torch import torch.nn as nn +import torch.nn.functional as F from .configuration_bert import FlexBertConfig try: @@ -20,6 +22,79 @@ LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss +class MoELoadBalancingLoss(nn.Module): + """Computes Switch Transformer auxiliary loss for load balancing. + + Reference: https://arxiv.org/abs/2101.03961 (equations 4-6, page 7) + + This loss encourages balanced token allocation across experts to avoid + scenarios where some experts are overloaded while others are underutilized. + """ + + def __init__(self, num_experts: int, top_k: int = 2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + + def forward( + self, + router_logits: torch.Tensor, + expert_indices: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + router_logits: Router logits [batch_size, seq_len, num_experts] + expert_indices: Top-k expert indices [batch_size, seq_len, top_k] + + Returns: + load_balance_loss: Scalar loss value + """ + # Compute expert probabilities + expert_probs = F.softmax(router_logits, dim=-1) # [B, C, n_exp] + + # Equation (5): compute ratio of tokens allocated to each expert + with torch.no_grad(): + one_hot_indices = F.one_hot(expert_indices, num_classes=self.num_experts) # [B, C, K, n_exp] + one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, C, n_exp] + tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1)) # [n_exp] + + # Equation (6): compute ratio of router probability allocated to each expert + prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1)) # [n_exp] + + # Equation (4): scaled dot product between prob / token allocation vectors + load_balance_loss = self.num_experts * torch.sum(prob_per_expert * tokens_per_expert) + + return load_balance_loss + + +class MoERouterZLoss(nn.Module): + """Computes router z-loss for MoE models. + + Reference: https://arxiv.org/abs/2202.08906 (equation 5, page 7) + + This loss constrains the size of router logits to avoid numerical instability + during training. Large logits can lead to round-off errors in the softmax computation, + even in float32 precision. 
+ """ + + def forward(self, router_logits: torch.Tensor) -> torch.Tensor: + """ + Args: + router_logits: Router logits [batch_size, seq_len, num_experts] + + Returns: + router_z_loss: Scalar loss value + """ + # Numerically stable computation: logsumexp is equivalent to log(sum(exp(x))) + # This avoids overflow issues from directly exponentiating large logits + router_z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0 # [B, C] + + # Average over all tokens + router_z_loss = torch.mean(router_z_loss) + + return router_z_loss + + def get_loss_fn(config: FlexBertConfig) -> nn.Module: try: loss_class = LOSS2CLS[config.loss_function] diff --git a/src/bert_layers/mlp.py b/src/bert_layers/mlp.py index 349d559b..f344008d 100644 --- a/src/bert_layers/mlp.py +++ b/src/bert_layers/mlp.py @@ -15,11 +15,13 @@ import torch import torch.nn as nn +import torch.nn.functional as F from .configuration_bert import FlexBertConfig from .activation import get_act_fn from .normalization import get_norm_layer from .initialization import ModuleType, init_weights +from .loss import MoELoadBalancingLoss, MoERouterZLoss class BertResidualGLU(nn.Module): @@ -190,10 +192,197 @@ def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor: return self.Wo(self.drop(self.act(input) * gate)) +class Router(nn.Module): + """Top-K router for selecting experts.""" + + def __init__( + self, + d: int, + n_exp: int, + top_k: int = 2, + use_noisy_top_k: bool = True, + capacity_factor: float = 1.25, + ): + super().__init__() + self.d = d + self.n_exp = n_exp + self.top_k = top_k + self.use_noisy_top_k = use_noisy_top_k + self.capacity_factor = capacity_factor + + # Router weights to compute logits for each expert + self.gate = nn.Linear(d, n_exp, bias=False) + + # Noise parameters for noisy top-k gating + if use_noisy_top_k: + self.w_noise = nn.Linear(d, n_exp, bias=False) + + def forward(self, x: torch.Tensor): + """ + Args: + x: [batch_size, seq_len, d] + Returns: + exp_weight: Expert weights [batch_size * seq_len, top_k] + exp_mask: Expert mask [batch_size * seq_len, n_exp, exp_capacity] + exp_batches: Token assignments [n_exp, exp_capacity, d] + """ + B, C, d = x.size() + num_tokens = B * C + x_flat = x.view(num_tokens, d) + + # Compute router logits + logits = self.gate(x_flat) # [num_tokens, n_exp] + + # Add noise for exploration (optional) + if self.use_noisy_top_k and self.training: + noise_stddev = F.softplus(self.w_noise(x_flat)) + noise = torch.randn_like(logits) * noise_stddev + logits = logits + noise + + # Select top-k experts + top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1) + top_k_gates = F.softmax(top_k_logits, dim=-1) # [num_tokens, top_k] + + # Compute expert capacity + exp_capacity = int((num_tokens * self.top_k * self.capacity_factor) / self.n_exp) + + # Create expert assignment mask and batches + exp_mask = torch.zeros(num_tokens, self.n_exp, exp_capacity, device=x.device) + exp_batches = torch.zeros(self.n_exp, exp_capacity, d, device=x.device) + + # Count tokens assigned to each expert + expert_counts = torch.zeros(self.n_exp, dtype=torch.long, device=x.device) + + # Assign tokens to experts + for token_idx in range(num_tokens): + for k_idx in range(self.top_k): + expert_idx = top_k_indices[token_idx, k_idx] + if expert_counts[expert_idx] < exp_capacity: + pos = expert_counts[expert_idx] + exp_mask[token_idx, expert_idx, pos] = top_k_gates[token_idx, k_idx] + exp_batches[expert_idx, pos] = x_flat[token_idx] + expert_counts[expert_idx] += 1 + + return top_k_gates, exp_mask, 
exp_batches + + + + + +class FlexBertGLUMoE(FlexBertMLPBase): + """Mixture of Experts with GLU activation for FlexBERT.""" + + def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): + super().__init__(config=config, layer_id=layer_id) + + self.n_exp = getattr(config, 'moe_num_experts', 4) + self.top_k = getattr(config, 'moe_top_k', 2) + self.use_noisy_top_k = getattr(config, 'moe_use_noisy_top_k', True) + self.capacity_factor = getattr(config, 'moe_capacity_factor', 1.25) + self.compute_aux_loss = getattr(config, 'moe_compute_aux_loss', True) + self.load_balance_loss_weight = getattr(config, 'moe_load_balance_loss_weight', 0.01) + self.router_z_loss_weight = getattr(config, 'moe_router_z_loss_weight', 0.001) + + self.router = Router( + d=config.hidden_size, + n_exp=self.n_exp, + top_k=self.top_k, + use_noisy_top_k=self.use_noisy_top_k, + capacity_factor=self.capacity_factor, + ) + + # GLU experts (each projects to 2x intermediate size) + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_in_bias), + nn.Identity(), # Placeholder for chunking + activation + nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity(), + nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias), + ) + for _ in range(self.n_exp) + ]) + self.act = get_act_fn(config.hidden_act) + + # Initialize auxiliary loss modules + if self.compute_aux_loss: + self.load_balance_loss = MoELoadBalancingLoss(num_experts=self.n_exp, top_k=self.top_k) + self.router_z_loss = MoERouterZLoss() + + def _init_weights(self, reset_params: bool = False): + init_weights( + self.config, + self.router.gate, + layer_dim=self.config.hidden_size, + layer_id=self.layer_id, + type_of_module=ModuleType.in_module, + ) + + for expert in self.experts: + for i, module in enumerate(expert): + if isinstance(module, nn.Linear): + init_weights( + self.config, + module, + layer_dim=self.config.hidden_size if i == 0 else self.config.intermediate_size, + layer_id=self.layer_id, + type_of_module=ModuleType.in_module if i == 0 else ModuleType.out_module, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + original_shape = hidden_states.shape + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + + B, C, d = hidden_states.size() + num_tokens = B * C + x_flat = hidden_states.view(num_tokens, d) + + # Compute router logits for auxiliary loss calculation + router_logits = self.router.gate(x_flat) # [num_tokens, n_exp] + + exp_weight, exp_mask, exp_batches = self.router(hidden_states) + + # Extract top-k indices from router for load balancing loss + _, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1) # [num_tokens, top_k] + + # Apply GLU experts + exp_out = torch.zeros_like(exp_batches) + for i, expert in enumerate(self.experts): + x = expert[0](exp_batches[i]) # Linear projection + input, gate = x.chunk(2, dim=-1) # Split for GLU + x = self.act(input) * gate # GLU activation + x = expert[2](x) # Dropout + exp_out[i] = expert[3](x) # Output projection + + exp_weight_flat = exp_mask.view(num_tokens, -1) + exp_out_flat = exp_out.view(-1, d) + output = torch.matmul(exp_weight_flat, exp_out_flat) + + # Compute auxiliary losses + self.aux_loss = None + if self.compute_aux_loss: + # Reshape for loss computation + router_logits_reshaped = router_logits.view(B, C, -1) + top_k_indices_reshaped = top_k_indices.view(B, C, -1) + + # Compute load balancing loss + lb_loss = 
self.load_balance_loss(router_logits_reshaped, top_k_indices_reshaped) + + # Compute router z-loss + z_loss = self.router_z_loss(router_logits_reshaped) + + # Combine auxiliary losses with weights + self.aux_loss = self.load_balance_loss_weight * lb_loss + self.router_z_loss_weight * z_loss + + return output.view(*original_shape) + + +# Update the MLP registry MLP2CLS = { "mlp": FlexBertMLP, "glu": FlexBertGLU, "parallel_glu": FlexBertParallelGLU, + "glu_moe": FlexBertGLUMoE, } @@ -212,3 +401,5 @@ def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> Fle ) else: raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}") + + diff --git a/src/bert_layers/model.py b/src/bert_layers/model.py index fd05e507..aeac1b01 100644 --- a/src/bert_layers/model.py +++ b/src/bert_layers/model.py @@ -407,6 +407,11 @@ def forward( loss_fct = nn.CrossEntropyLoss() masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx]) + + # Add MoE auxiliary losses if present + aux_loss = self._get_aux_loss() + if aux_loss is not None: + loss = loss + aux_loss assert input_ids is not None, "Coding error; please open an issue" batch, seqlen = input_ids.shape[:2] @@ -1021,6 +1026,20 @@ def __init__(self, config: FlexBertConfig): # Initialize weights and apply final processing self._init_weights(reset_params=False) + + def _get_aux_loss(self) -> Optional[torch.Tensor]: + """Collect auxiliary losses from all MoE layers in the model.""" + aux_loss = None + + # Traverse all modules to find FlexBertGLUMoE layers + for module in self.modules(): + if hasattr(module, 'aux_loss') and module.aux_loss is not None: + if aux_loss is None: + aux_loss = module.aux_loss + else: + aux_loss = aux_loss + module.aux_loss + + return aux_loss def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" @@ -1265,6 +1284,20 @@ def __init__(self, config: FlexBertConfig): # Initialize weights and apply final processing self._init_weights(reset_params=False) + def _get_aux_loss(self) -> Optional[torch.Tensor]: + """Collect auxiliary losses from all MoE layers in the model.""" + aux_loss = None + + # Traverse all modules to find FlexBertGLUMoE layers + for module in self.modules(): + if hasattr(module, 'aux_loss') and module.aux_loss is not None: + if aux_loss is None: + aux_loss = module.aux_loss + else: + aux_loss = aux_loss + module.aux_loss + + return aux_loss + def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None): assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified" if module: @@ -1352,6 +1385,11 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) + + # Add MoE auxiliary losses if present + aux_loss = self._get_aux_loss() + if aux_loss is not None: + loss = loss + aux_loss if not return_dict: output = (logits,) + output From 9f91b576f9d0ebf881306db6957f70eb256950ea Mon Sep 17 00:00:00 2001 From: heyo66 Date: Sun, 14 Dec 2025 17:04:08 +0100 Subject: [PATCH 3/4] ... 
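The model.py changes above fold the MoE auxiliary terms into the task loss: each FlexBertGLUMoE layer caches a weighted aux_loss during its forward pass, _get_aux_loss() sums those values over the module tree, and the sum is added to the MLM or classification loss. A minimal sketch of that aggregation follows; ToyMoELayer and get_aux_loss are illustrative stand-ins, not code from this patch.

import torch
import torch.nn as nn

class ToyMoELayer(nn.Module):
    """Stand-in for an MoE layer: stores a weighted auxiliary loss after forward."""
    def __init__(self):
        super().__init__()
        self.aux_loss = None

    def forward(self, x, lb_loss, z_loss, lb_w=0.01, z_w=0.001):
        self.aux_loss = lb_w * lb_loss + z_w * z_loss
        return x

def get_aux_loss(model: nn.Module):
    """Mirror of _get_aux_loss: sum aux_loss over every module that exposes one."""
    total = None
    for module in model.modules():
        if getattr(module, "aux_loss", None) is not None:
            total = module.aux_loss if total is None else total + module.aux_loss
    return total

model = nn.Sequential(ToyMoELayer(), ToyMoELayer())
x = torch.randn(4, 8)
for layer in model:
    x = layer(x, lb_loss=torch.tensor(2.1), z_loss=torch.tensor(1.4))

task_loss = torch.tensor(3.0)            # stand-in for the MLM / classification loss
aux = get_aux_loss(model)
total_loss = task_loss + aux if aux is not None else task_loss

Because each layer overwrites its aux_loss on every forward pass, the traversal always reflects the most recent batch.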
--- .../moebert-rope-base-c4-realnewslike.yaml | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 yamls/main/moebert-rope-base-c4-realnewslike.yaml diff --git a/yamls/main/moebert-rope-base-c4-realnewslike.yaml b/yamls/main/moebert-rope-base-c4-realnewslike.yaml new file mode 100644 index 00000000..0eb6a81b --- /dev/null +++ b/yamls/main/moebert-rope-base-c4-realnewslike.yaml @@ -0,0 +1,174 @@ +# MoEBERT with RoPE - Base configuration +# Based on FlexBERT-RoPE-Base with Mixture of Experts enabled +# Using C4 RealNewsLike dataset for testing + +# Data configuration - C4 RealNewsLike +data_local: ./c4_realnewslike +data_remote: + +max_seq_len: 512 +tokenizer_name: bert-base-uncased +mlm_probability: 0.3 + +# Run Name +run_name: moebert-rope-base-13799838 + +# Model +model: + name: flex_bert + recompute_metric_loss: false + pretrained_model_name: ${tokenizer_name} + tokenizer_name: ${tokenizer_name} + model_config: + # Base architecture (ModernBERT-base) + num_attention_heads: 12 + num_hidden_layers: 12 + attention_layer: rope + attention_probs_dropout_prob: 0.0 + attn_out_bias: false + attn_out_dropout_prob: 0.0 + attn_qkv_bias: false + bert_layer: prenorm + embed_dropout_prob: 0.0 + embed_norm: false + final_norm: true + embedding_layer: sans_pos + loss_function: fa_cross_entropy + loss_kwargs: + reduction: mean + mlp_dropout_prob: 0.0 + mlp_in_bias: false + mlp_layer: glu_moe + mlp_out_bias: false + normalization: rmsnorm + norm_kwargs: + eps: 1e-6 + padding: unpadded + sparse_prediction: false + + # RoPE configuration + rotary_emb_dim: null # will be set to headdim by default + rotary_emb_base: 10000.0 + rotary_emb_scale_base: null + rotary_emb_interleaved: false + + # General settings + hidden_act: gelu + init_method: full_megatron + init_std: 0.02 + init_cutoff_factor: 2.0 + init_small_embedding: false + deterministic_fa2: false + initial_attention_layer: null + initial_bert_layer: null + initial_mlp_layer: null + num_initial_layers: 0 + skip_first_prenorm: true + + # Sparse attention settings + sliding_window: 128 + global_attn_every_n_layers: 3 + unpad_embeddings: true + pad_logits: false + + # Mixture of Experts configuration + moe_num_experts: 4 + moe_top_k: 2 + moe_use_noisy_top_k: true + moe_capacity_factor: 1.25 + moe_compute_aux_loss: true + moe_load_balance_loss_weight: 0.01 + moe_router_z_loss_weight: 0.001 + +# Dataloaders +train_loader: + name: text + dataset: + local: ${data_local} + remote: ${data_remote} + split: train + tokenizer_name: ${tokenizer_name} + max_seq_len: ${max_seq_len} + shuffle: true + mlm_probability: ${mlm_probability} + drop_last: true + num_workers: 8 + +eval_loader: + name: text + dataset: + local: ${data_local} + remote: ${data_remote} + split: val + tokenizer_name: ${tokenizer_name} + max_seq_len: ${max_seq_len} + shuffle: false + mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison + drop_last: false + num_workers: 8 + +# Optimization +scheduler: + name: linear_decay_with_warmup + t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration + alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration + +optimizer: + name: decoupled_adamw + lr: 5.0e-4 # Peak learning rate + betas: + - 0.9 + - 0.98 + eps: 1.0e-06 + weight_decay: 1.0e-5 # Amount of weight decay regularization + filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases + +# Training parameters - reduced for testing +max_duration: 13799838sp # Smaller 
sample count for testing +eval_interval: 100ba +global_train_batch_size: 8 + +# System +seed: 17 +device_train_microbatch_size: 2 +precision: amp_bf16 + +global_eval_batch_size: 16 +device_eval_microbatch_size: 4 + +# Logging +progress_bar: false +log_to_console: true +console_log_interval: 1ba + +algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1.0 + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + +# W&B logging +loggers: + wandb: + project: moebert + entity: # Fill this in with your W&B team/entity name + name: ${run_name} + group: rope-base-moe-c4 + + +# Checkpoint to local filesystem or remote object store +save_interval: 100000sp +save_num_checkpoints_to_keep: 2 + +# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: checkpoint/{run_name} + +# (Optional) Load from local filesystem or remote object store to +# start from an existing model checkpoint; +# e.g. './ckpt/latest-rank{rank}.pt' (local), or +# 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) +# load_path: null From b764bbcff52d1701aba9256b2896cff8ace5f804 Mon Sep 17 00:00:00 2001 From: heyo66 Date: Sun, 14 Dec 2025 17:39:00 +0100 Subject: [PATCH 4/4] .... --- yamls/main/flex-bert-rope-base.yaml | 32 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/yamls/main/flex-bert-rope-base.yaml b/yamls/main/flex-bert-rope-base.yaml index d92a5e1d..b50130df 100644 --- a/yamls/main/flex-bert-rope-base.yaml +++ b/yamls/main/flex-bert-rope-base.yaml @@ -3,7 +3,7 @@ # Follow the instructions in the README to set up ./my-copy-c4 # Or point data paths to your remote C4 dataset -data_local: ./my-copy-c4 +data_local: ./c4_realnewslike data_remote: # If blank, files must be present in data_local max_seq_len: 512 @@ -112,18 +112,18 @@ optimizer: # algorithms: -max_duration: 286720000sp # Subsample the training data for ~275M samples +max_duration: 13799838sp # Subsample the training data for ~275M samples eval_interval: 2000ba -global_train_batch_size: 4096 +global_train_batch_size: 64 # System seed: 17 -device_train_microbatch_size: 128 +device_train_microbatch_size: 32 # device_train_microbatch_size: auto precision: amp_bf16 -global_eval_batch_size: 256 -device_eval_microbatch_size: 64 +global_eval_batch_size: 64 +device_eval_microbatch_size: 32 # Logging progress_bar: false @@ -139,18 +139,20 @@ callbacks: speed_monitor: window_size: 10 lr_monitor: {} +loggers: + wandb: + project: moebert + entity: # Fill this in with your W&B team/entity name + name: ${run_name} + group: rope-base-c4new -# (Optional) W&B logging -# loggers: -# wandb: -# project: # Fill this in -# entity: # Fill this in -# (Optional) Checkpoint to local filesystem or remote object store -# save_interval: 3500ba -# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK -# save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) +# Checkpoint to local filesystem or remote object store +save_interval: 100000sp +save_num_checkpoints_to_keep: 2 +# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: checkpoint/{run_name} # (Optional) Load from local filesystem or remote object store to # start from an existing model checkpoint; # e.g. './ckpt/latest-rank{rank}.pt' (local), or