
# ModelArchitecture

API reference for `ModelArchitecture`.

```python
from modal_training_gym.common.models.base import ModelArchitecture
```

Transformer architecture parameters for a specific model.

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `num_layers` | `int` | `0` | Number of transformer layers. |
| `hidden_size` | `int` | `0` | Hidden dimension size. |
| `ffn_hidden_size` | `int` | `0` | Feed-forward network intermediate size. |
| `vocab_size` | `int` | `0` | Vocabulary size. |

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `num_attention_heads` | `int` | `0` | Number of attention heads. |
| `group_query_attention` | `bool` | `True` | Enable grouped-query attention (GQA). |
| `num_query_groups` | `int` | `0` | Number of KV head groups for GQA. |
| `kv_channels` | `int` | `0` | Per-head key/value channel dimension. |

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `normalization` | `str` | `"RMSNorm"` | Layer normalization type. |
| `norm_epsilon` | `float` | `1e-6` | Normalization epsilon. |
| `swiglu` | `bool` | `True` | Use SwiGLU activation in the FFN. |
| `disable_bias_linear` | `bool` | `True` | Disable bias in linear layers. |
| `qk_layernorm` | `bool` | `True` | Apply layer norm to query and key projections. |

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `use_rotary_position_embeddings` | `bool` | `True` | Use RoPE positional encoding. |
| `rotary_base` | `int` | `10000` | Base frequency for RoPE. |

Generate Megatron-LM CLI flags from this architecture spec.
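As a rough illustration of how such a spec could be turned into CLI flags, here is a minimal, self-contained sketch. The dataclass mirrors the fields documented above, but the `megatron_flags` method name and the exact flag-emission rules (booleans become bare flags, everything else becomes `--flag value` pairs) are assumptions for this example, not the library's actual implementation.

```python
from dataclasses import dataclass, fields


# Hypothetical stand-in for the real class in
# modal_training_gym/common/models/base.py; field names and defaults
# follow the reference tables above.
@dataclass
class ModelArchitecture:
    num_layers: int = 0
    hidden_size: int = 0
    ffn_hidden_size: int = 0
    vocab_size: int = 0
    num_attention_heads: int = 0
    group_query_attention: bool = True
    num_query_groups: int = 0
    kv_channels: int = 0
    normalization: str = "RMSNorm"
    norm_epsilon: float = 1e-6
    swiglu: bool = True
    disable_bias_linear: bool = True
    qk_layernorm: bool = True
    use_rotary_position_embeddings: bool = True
    rotary_base: int = 10000

    def megatron_flags(self) -> list[str]:
        """Map each field to a kebab-case CLI flag (assumed convention).

        True booleans become bare flags (e.g. --swiglu), False booleans
        are omitted, and all other fields become `--flag value` pairs.
        """
        args: list[str] = []
        for f in fields(self):
            flag = "--" + f.name.replace("_", "-")
            value = getattr(self, f.name)
            if isinstance(value, bool):
                if value:
                    args.append(flag)
            else:
                args.extend([flag, str(value)])
        return args


arch = ModelArchitecture(num_layers=32, hidden_size=4096,
                         ffn_hidden_size=11008, num_attention_heads=32)
flags = arch.megatron_flags()
# flags now contains e.g. "--num-layers", "32", "--swiglu", ...
```

The resulting list can be passed straight to a subprocess invocation of a Megatron-LM training script; the field-to-flag mapping is mechanical, which is why a plain dataclass-field walk suffices here.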

Source: modal_training_gym/common/models/base.py