class max.nn.Module
Base class for model components with weight management.
Provides functionality to create custom layers and construct networks with automatic weight tracking.
The following example uses the Module class to create custom layers and build a neural network:
from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dim, out_dim), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Buffer([5, 10]), ...}

Constructing a graph without Module can result in name collisions among the
weights (in this example, there would be three weights named "weight"). With
Module, you can use state_dict() or load_state_dict() to initialize or set the
weight values, and the weight names are finalized to be unique within the
model.
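The naming scheme that avoids these collisions can be illustrated with a plain-Python sketch: each weight's final name is the attribute path from the root module down to the weight. This is only an illustration of the idea; `ToyModule`, `ToyLinear`, and `ToyMLP` are hypothetical stand-ins, not MAX's implementation.

```python
# Minimal illustration of attribute-path weight naming (not MAX's actual code).
class ToyModule:
    def state_dict(self, prefix=""):
        out = {}
        for attr, value in vars(self).items():
            if isinstance(value, ToyModule):
                # Recurse, extending the prefix with the attribute name.
                out.update(value.state_dict(prefix + attr + "."))
            else:
                out[prefix + attr] = value  # leaf value stands in for a Weight
        return out

class ToyLinear(ToyModule):
    def __init__(self, shape):
        self.weight = shape  # stand-in for a real Weight object

class ToyMLP(ToyModule):
    def __init__(self):
        self.up = ToyLinear((5, 10))
        self.gate = ToyLinear((5, 10))
        self.down = ToyLinear((10, 5))

print(ToyMLP().state_dict())
# {'up.weight': (5, 10), 'gate.weight': (5, 10), 'down.weight': (10, 5)}
```

Because the attribute path is part of each key, the three `weight` attributes no longer collide.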
build_subgraph()
build_subgraph(name, input_types, weight_prefix='')
Builds a subgraph encapsulating this layer’s computation.
Call this method once on a representative layer, then call the returned
subgraph once per layer using call() with a unique
prefix. This pattern lets the compiler process the layer definition
once rather than once per repetition, which significantly reduces compile
time for models with many identical layers.
Examples:
Build a subgraph from layer 0 and call it once per layer with layer-specific weights:
input_types = [hidden.type for hidden in h]
subgraph = self.layers[0].build_subgraph(
    "transformer_block",
    input_types=input_types,
    weight_prefix="layers.0.",
)

# Call it once per layer with the correct weight prefix.
for idx in range(len(self.layers)):
    outputs = ops.call(subgraph, *h, prefix=f"layers.{idx}.")
    h = [x.tensor for x in outputs]
Parameters:

- name (str) – The name of the subgraph. Must be unique within the containing graph.
- input_types (Sequence[Type[Any] | list[Type[Any]]]) – The input types for the subgraph. Pass a flat Type for a single tensor, or a list of Type objects for a group of tensors that should be passed together (for example, KV-cache blocks).
- weight_prefix (str) – A prefix string to strip from weight names before registering them as placeholder weights. At call time, the caller supplies the same prefix via the prefix argument of call() to re-resolve each weight to the correct entry in the weights registry.
Returns:

A Graph instance representing the subgraph.

Return type:

Graph
Notes:
Weights with names that start with weight_prefix are marked as
placeholders. Any call() invocation for this
subgraph must supply a matching prefix.
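The prefix mechanics can be sketched in plain Python. This is an illustration of the idea only (the `make_placeholders` and `resolve` helpers are hypothetical, not MAX internals): the subgraph stores weights under stripped placeholder names, and each call re-qualifies them with the caller's prefix.

```python
# Illustration of weight_prefix stripping and call-time re-resolution
# (not MAX's actual code).
def make_placeholders(weights, weight_prefix):
    """Strip weight_prefix so the subgraph refers to weights by placeholder name."""
    return {name[len(weight_prefix):]: shape
            for name, shape in weights.items()
            if name.startswith(weight_prefix)}

def resolve(placeholders, prefix):
    """At call time, re-qualify each placeholder with the caller's prefix."""
    return {prefix + name: shape for name, shape in placeholders.items()}

weights = {f"layers.{i}.weight": (5, 10) for i in range(3)}
placeholders = make_placeholders(weights, "layers.0.")
print(placeholders)                        # {'weight': (5, 10)}
print(resolve(placeholders, "layers.2."))  # {'layers.2.weight': (5, 10)}
```

This is why the prefix passed at call time must match the placeholder scheme established by weight_prefix: otherwise the re-qualified names would not exist in the weights registry.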
layer_weights
Returns a mapping from weight name to Weight for this layer.
load_state_dict()
load_state_dict(state_dict, *, override_quantization_encoding=False, weight_alignment=None, strict=True)
Sets the values of all weights in this model.
The keys in state_dict must match the fully-qualified weight
names used internally by the Module. Those names normally
follow the attribute hierarchy (e.g.
model.layers.0.self_attn.qkv_proj.weight), but a sublayer
whose _omit_module_attr_name is True is omitted
from its descendants’ FQNs. The canonical example is
StackedLinear in unfused mode, where
self.qkv_proj = StackedLinear(names=["q_proj", "k_proj", "v_proj"], stacked=False)
exposes weights at self_attn.q_proj.weight /
self_attn.k_proj.weight / self_attn.v_proj.weight rather
than nested under a self_attn.qkv_proj prefix. Use
raw_state_dict() to inspect the exact keys this method
expects for a given module.
Parameters:

- state_dict (Mapping[str, DLPackArray | WeightData]) – A map from weight name to a numpy array or Buffer.
- override_quantization_encoding (bool) – Whether to override the weight quantization based on the loaded value.
- weight_alignment (int | None) – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state_dict must be aligned by the default dtype alignment.
- strict (bool) – If True, raises an error if any weights required by the Module are missing from state_dict, or if any keys in state_dict were not used by the Module. If False, both missing and unexpected keys are tolerated and reported only via return values/logging by callers.
-
Raises:

ValueError – If strict is True and any required weight is missing from state_dict, or if state_dict contains keys not used by the Module.

Return type:

None
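The strict check described above can be sketched as follows. This is an illustration of the key-matching rule only; `check_strict` is a hypothetical helper, not MAX's implementation.

```python
# Illustration of strict state_dict key checking (not MAX's actual code).
def check_strict(required, state_dict):
    missing = required - state_dict.keys()
    unexpected = state_dict.keys() - required
    if missing or unexpected:
        raise ValueError(
            f"missing keys: {sorted(missing)}, unexpected keys: {sorted(unexpected)}"
        )

required = {"up.weight", "gate.weight", "down.weight"}

# Exact match: passes silently.
check_strict(required, {"up.weight": 0, "gate.weight": 0, "down.weight": 0})

# Missing and unexpected keys: raises ValueError.
try:
    check_strict(required, {"up.weight": 0, "extra.weight": 0})
except ValueError as e:
    print(e)
```

With strict=False, both sets of mismatched keys would instead be tolerated and surfaced through return values or logging.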
raw_state_dict()
raw_state_dict()
Returns all weights objects in the model.
Unlike state_dict(), this returns Weight objects instead of
the assigned values. Some parameters inside the Weight can be
configured before a graph is built. Do not change these attributes after
building a graph.
Keys follow the same FQN convention as load_state_dict():
attribute paths through the module tree, with any sublayer that
sets _omit_module_attr_name skipped in the prefix.
set_shared_weight()
set_shared_weight(name, weight)
Registers a Weight as shared on this layer.
Sets name as an attribute on this layer and marks the weight as
shared so that raw_state_dict() and load_state_dict() skip
it when iterating over owned weights.
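Shared-weight registration (for example, tying an embedding to an output projection) can be sketched as follows. This is an illustration of the idea only; `ToyLayer` and its methods are hypothetical, not MAX internals.

```python
# Illustration of shared-weight registration (not MAX's actual code).
class ToyLayer:
    def __init__(self):
        self._shared = set()

    def set_shared_weight(self, name, weight):
        # Set the attribute and mark it as shared (not owned by this layer).
        setattr(self, name, weight)
        self._shared.add(name)

    def owned_weights(self):
        # Shared weights are skipped when iterating weights this layer owns,
        # so they are not duplicated in state dicts.
        return {k: v for k, v in vars(self).items()
                if k != "_shared" and k not in self._shared}

embedding = ToyLayer()
embedding.weight = [[0.0] * 4]  # owned by the embedding layer
lm_head = ToyLayer()
lm_head.set_shared_weight("weight", embedding.weight)

print(lm_head.owned_weights())             # {} (the tied weight is not duplicated)
print(lm_head.weight is embedding.weight)  # True
```

The shared weight remains accessible as a normal attribute on the layer; only weight iteration treats it specially.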
state_dict()
state_dict(auto_initialize=True)
Returns values of all weights in the model.
The values returned are the same as the values set in load_state_dict().
If load_state_dict() has not been called and a weight has no value, it is
initialized to zero (when auto_initialize is True).
Keys follow the same FQN convention as load_state_dict():
attribute paths through the module tree, with any sublayer that
sets _omit_module_attr_name (e.g.
StackedLinear in unfused mode) skipped in the
prefix.
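The auto-initialization behavior can be sketched as follows. This is an illustration only; the `toy_state_dict` helper and its arguments are hypothetical, not MAX's implementation, and a real Buffer is stood in for by a numpy array.

```python
import numpy as np

# Illustration of state_dict auto-initialization (not MAX's actual code).
def toy_state_dict(weight_shapes, loaded, auto_initialize=True):
    out = {}
    for name, shape in weight_shapes.items():
        if name in loaded:
            out[name] = loaded[name]          # value previously set
        elif auto_initialize:
            out[name] = np.zeros(shape, dtype=np.float32)  # default zero init
        else:
            raise ValueError(f"weight {name!r} has no value")
    return out

shapes = {"up.weight": (5, 10)}
sd = toy_state_dict(shapes, loaded={})  # nothing loaded: zeros are returned
print(sd["up.weight"].shape)  # (5, 10)
```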
sublayers