
Module

class max.nn.Module

Bases: Layer, ABC

Base class for model components with weight management.

Provides functionality to create custom layers and construct networks with automatic weight tracking.

The following example uses the Module class to create custom layers and build a neural network:

from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dim, out_dim), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Buffer([5, 10]), ...}

Constructing a graph without Module can result in name collisions among the weights (in this example, there would be three weights all named weight). With Module, you can use state_dict() or load_state_dict() to initialize or set the weight values, and the weight names are finalized to be unique within the model.
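To make the name-finalization behavior concrete, the following is a standalone pure-Python sketch (it does not import max, and the Fake* classes are illustrative stand-ins, not the real implementation) of how attribute-path names such as up.weight are derived by walking the module tree:

```python
class FakeWeight:
    """Stand-in for max.graph.Weight; on its own it is just named "weight"."""

class FakeModule:
    """Stand-in for max.nn.Module that derives qualified names from attribute paths."""
    def state_dict_keys(self, prefix=""):
        keys = []
        for attr, value in vars(self).items():
            if isinstance(value, FakeWeight):
                keys.append(prefix + attr)  # e.g. "up." + "weight"
            elif isinstance(value, FakeModule):
                keys.extend(value.state_dict_keys(prefix + attr + "."))
        return keys

class FakeLinear(FakeModule):
    def __init__(self):
        self.weight = FakeWeight()

class FakeMLP(FakeModule):
    def __init__(self):
        self.up = FakeLinear()
        self.gate = FakeLinear()
        self.down = FakeLinear()

print(FakeMLP().state_dict_keys())
# ['up.weight', 'gate.weight', 'down.weight']
```

Because each key is the attribute path from the root, the three otherwise-identical "weight" objects get distinct names, which is the collision avoidance described above.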

build_subgraph()

build_subgraph(name, input_types, weight_prefix='')

Builds a subgraph encapsulating this layer’s computation.

Call this method once on a representative layer, then call the returned subgraph once per layer using call() with a unique prefix. This pattern lets the compiler process the layer definition once rather than once per repetition, which significantly reduces compile time for models with many identical layers.

Examples:

Build a subgraph from layer 0 and call it once per layer with layer-specific weights:

input_types = [hidden.type for hidden in h]

subgraph = self.layers[0].build_subgraph(
    "transformer_block",
    input_types=input_types,
    weight_prefix="layers.0.",
)

# Call it once per layer with the correct weight prefix.
for idx in range(len(self.layers)):
    outputs = ops.call(
        subgraph, *h, prefix=f"layers.{idx}."
    )
    h = [x.tensor for x in outputs]

Parameters:

  • name (str) – The name of the subgraph. Must be unique within the containing graph.
  • input_types (Sequence[Type[Any] | list[Type[Any]]]) – The input types for the subgraph. Pass a flat Type for a single tensor, or a list of Type objects for a group of tensors that should be passed together (for example, KV-cache blocks).
  • weight_prefix (str) – A prefix string to strip from weight names before registering them as placeholder weights. At call time, the caller supplies the same prefix via the prefix argument of call() to re-resolve each weight to the correct entry in the weights registry.

Returns:

A Graph instance representing the subgraph.

Return type:

Graph

Notes:

Weights with names that start with weight_prefix are marked as placeholders. Any call() invocation for this subgraph must supply a matching prefix.
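The prefix mechanics can be sketched in plain Python (no max imports; the function names and registry layout here are illustrative assumptions): at build time weight_prefix is stripped to form placeholder names, and each call() re-attaches the caller's prefix to resolve the real weights:

```python
def make_placeholders(weight_names, weight_prefix):
    """Strip weight_prefix from matching names to form placeholder names."""
    return [n[len(weight_prefix):] for n in weight_names if n.startswith(weight_prefix)]

def resolve(placeholders, prefix, registry):
    """At call time, re-attach the caller's prefix and look up real weights."""
    return {p: registry[prefix + p] for p in placeholders}

# A toy weights registry for four identical layers.
registry = {f"layers.{i}.mlp.weight": f"<buffer {i}>" for i in range(4)}

# Build the subgraph once, from layer 0.
placeholders = make_placeholders(list(registry), "layers.0.")
print(placeholders)  # ['mlp.weight']

# Call it for layer 3 with a layer-specific prefix.
print(resolve(placeholders, "layers.3.", registry))
# {'mlp.weight': '<buffer 3>'}
```

This is why the prefix passed to call() must match the shape of weight_prefix used at build time: the placeholder name plus the call-time prefix must reassemble an existing registry key.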

layer_weights

property layer_weights: dict[str, Weight]

Returns a mapping from weight name to Weight for this layer.

load_state_dict()

load_state_dict(state_dict, *, override_quantization_encoding=False, weight_alignment=None, strict=True)

Sets the values of all weights in this model.

The keys in state_dict must match the fully-qualified weight names used internally by the Module. Those names normally follow the attribute hierarchy (e.g. model.layers.0.self_attn.qkv_proj.weight), but a sublayer whose _omit_module_attr_name is True is omitted from its descendants' FQNs. The canonical example is StackedLinear in unfused mode, where self.qkv_proj = StackedLinear(names=["q_proj", "k_proj", "v_proj"], stacked=False) exposes weights at self_attn.q_proj.weight, self_attn.k_proj.weight, and self_attn.v_proj.weight rather than nesting them under a self_attn.qkv_proj. prefix. Use raw_state_dict() to inspect the exact keys this method expects for a given module.
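The FQN rule can be illustrated with a standalone sketch (pure Python, no max import; the dict-based tree and the "_omit" flag are stand-ins for the real module tree and _omit_module_attr_name):

```python
def fqn_keys(module, prefix=""):
    """Walk a nested dict-based module tree and build fully-qualified names.

    A node flagged with "_omit": True contributes no path segment of its own,
    mirroring _omit_module_attr_name in max.nn.Module (illustrative sketch).
    """
    keys = []
    for name, child in module.items():
        if name == "_omit":
            continue
        if isinstance(child, dict):
            omitted = child.get("_omit", False)
            sub_prefix = prefix if omitted else prefix + name + "."
            keys.extend(fqn_keys(child, sub_prefix))
        else:
            keys.append(prefix + name)
    return keys

attn = {
    "qkv_proj": {  # StackedLinear-like node in unfused mode
        "_omit": True,
        "q_proj": {"weight": None},
        "k_proj": {"weight": None},
        "v_proj": {"weight": None},
    }
}
print(fqn_keys(attn, "self_attn."))
# ['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight']
```

Note that qkv_proj never appears in the keys, exactly as described for the unfused StackedLinear case.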

Parameters:

  • state_dict (Mapping[str, DLPackArray | WeightData]) – A map from weight name to a numpy array or Buffer.
  • override_quantization_encoding (bool) – Whether to override the weight quantization based on the loaded value.
  • weight_alignment (int | None) – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state_dict must be aligned by the default dtype alignment.
  • strict (bool) – If True, raises an error if any weights required by the Module are missing from state_dict, or if any keys in state_dict were not used by the Module. If False, both missing and unexpected keys are tolerated and reported only via return values/logging by callers.

Raises:

ValueError – If strict is True and any required weight is missing from state_dict, or if state_dict contains keys not used by the Module.

Return type:

None

raw_state_dict()

raw_state_dict()

Returns all weight objects in the model. Unlike state_dict(), this returns Weight objects instead of the assigned values. Some parameters inside each Weight can be configured before a graph is built; do not change those attributes after the graph has been built.

Keys follow the same FQN convention as load_state_dict(): attribute paths through the module tree, with any sublayer that sets _omit_module_attr_name skipped in the prefix.

Returns:

Map from weight name to the Weight object.

Return type:

dict[str, Weight]

set_shared_weight()

set_shared_weight(name, weight)

Registers a Weight as shared on this layer.

Sets name as an attribute on this layer and marks the weight as shared so that raw_state_dict() and load_state_dict() skip it when iterating over owned weights.
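The skip-when-shared behavior can be sketched in standalone Python (no max import; this is a minimal stand-in for the real bookkeeping, not the actual implementation):

```python
class FakeModule:
    """Minimal stand-in for max.nn.Module's shared-weight bookkeeping."""
    def __init__(self):
        self._shared = set()

    def set_shared_weight(self, name, weight):
        setattr(self, name, weight)  # name becomes an attribute on this layer
        self._shared.add(name)       # ...but is marked as not owned here

    def owned_weight_names(self):
        # Iteration over owned weights skips anything registered as shared.
        return [n for n in vars(self) if n != "_shared" and n not in self._shared]

m = FakeModule()
m.weight = "owned-weight"
m.set_shared_weight("embedding", "tied-embedding-weight")
print(m.owned_weight_names())  # ['weight']
print(m.embedding)             # tied-embedding-weight
```

The shared weight remains usable through the attribute, but it is excluded from the layer's own state-dict iteration, so the module that actually owns it loads and reports it exactly once.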

Parameters:

  • name (str) – The attribute name under which the weight is registered.
  • weight (Weight) – The Weight to share.

Return type:

None

state_dict()

state_dict(auto_initialize=True)

Returns values of all weights in the model.

The values returned are the same as the values set in load_state_dict(). If load_state_dict() has not been called and none of the weights have values, then they are initialized to zero.

Keys follow the same FQN convention as load_state_dict(): attribute paths through the module tree, with any sublayer that sets _omit_module_attr_name (e.g. StackedLinear in unfused mode) skipped in the prefix.
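The auto_initialize behavior can be sketched as follows (standalone Python, no max import; collect_values and the flattened-list representation are illustrative, not the real implementation):

```python
def collect_values(weights, auto_initialize=True):
    """Sketch of state_dict()'s auto_initialize handling (not the real impl)."""
    out = {}
    for name, (size, value) in weights.items():
        if value is None:
            if not auto_initialize:
                raise ValueError(f"weight {name!r} has no loaded value")
            value = [0.0] * size  # zero-initialize the unloaded weight
        out[name] = value
    return out

weights = {"up.weight": (50, None)}  # a 5 x 10 weight, flattened, not yet loaded

# Default: unloaded weights come back zero-initialized.
print(len(collect_values(weights)["up.weight"]))  # 50

# With auto_initialize=False, an unloaded weight is an error instead.
try:
    collect_values(weights, auto_initialize=False)
except ValueError as exc:
    print(exc)  # weight 'up.weight' has no loaded value
```

Passing auto_initialize=False is the stricter mode to use when a silently zero-filled weight would mask a loading bug.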

Parameters:

auto_initialize (bool) – Determines whether to initialize weights to zero if the weight value has not been loaded. If this is False, a ValueError is raised if an uninitialized weight is found.

Returns:

Map from weight name to the weight value (can be numpy array or Buffer).

Return type:

dict[str, DLPackArray]

sublayers

property sublayers: dict[str, Module]

Returns a mapping from attribute name to child Module for this layer.