diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 3962c771..4a6d4592 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -47,7 +47,7 @@ class Arguments: # Compute arguments. memory_optimized_mlp: bool = False mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' + mlp_impl: str = 'grouped' # Initialization arguments. fp16: bool = True