Python class

Model

class max.engine.Model

Bases: object

A loaded model that you can execute.

Do not instantiate this class directly. Instead, create it with InferenceSession.

capture()

capture(graph_keys, *inputs)

Capture execution into a device graph for a caller-provided key.

Capture is best-effort and model-dependent. If the model issues capture-unsafe operations (for example, host-device synchronization), graph capture may fail. Callers should choose capture-safe execution paths.

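Parameters:

  • graph_keys – Caller-provided graph key or per-device keys under which the captured graph is stored.
  • inputs – Input buffers to execute while capturing.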

Return type:

list[Buffer]

debug_verify_replay()

debug_verify_replay(graph_keys, *inputs)

Execute eagerly and verify the launch trace matches the captured graph.

This method validates that graph capture correctly represents eager execution by running the model and comparing kernel launch sequences against a previously captured device graph.

Parameters:

  • self (Model) – The model to debug/verify
  • graph_keys (int | Sequence[int]) – Caller-provided graph key or per-device keys identifying captured graphs.
  • inputs (Buffer) – Input buffers matching the captured input signature (same shapes and dtypes used during capture).

Raises:

  • TypeError – If graph_keys is neither an int nor a sequence of ints.
  • ValueError – If any key in graph_keys is out of uint64 range.
  • ValueError – If no input buffers are provided.
  • RuntimeError – If no graph has been captured for graph_keys.
  • RuntimeError – If the eager execution trace doesn’t match the captured graph.

Return type:

None

Example:

>>> model.capture([1, 1], input_tensor)
>>> model.debug_verify_replay([1, 1], input_tensor)  # Validates capture
>>> model.replay([1, 1], input_tensor)  # Safe to use optimized replay
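
The rules for graph_keys listed under Raises can be sketched in plain Python. The helper below is hypothetical and is not part of max.engine; it only illustrates the documented behavior (an int or a sequence of ints, each within uint64 range). The name normalize_graph_keys, the error messages, and the rejection of str and bytes are assumptions.

```python
from collections.abc import Sequence

UINT64_MAX = 2**64 - 1


def normalize_graph_keys(graph_keys):
    """Return graph_keys as a list of ints.

    Hypothetical helper that mirrors the documented TypeError and
    ValueError conditions; it is not the library's implementation.
    """
    if isinstance(graph_keys, int):
        keys = [graph_keys]
    elif isinstance(graph_keys, Sequence) and not isinstance(graph_keys, (str, bytes)):
        keys = list(graph_keys)
        if not all(isinstance(key, int) for key in keys):
            raise TypeError("graph_keys must be an int or a sequence of ints")
    else:
        raise TypeError("graph_keys must be an int or a sequence of ints")

    for key in keys:
        # A uint64 key must lie in [0, 2**64 - 1].
        if not 0 <= key <= UINT64_MAX:
            raise ValueError(f"graph key {key} is out of uint64 range")
    return keys


print(normalize_graph_keys(1))
print(normalize_graph_keys([1, 1]))
```

A single int is treated as a one-element key list here, so the same code path can handle both single-device and per-device keys.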

devices

property devices

Returns the device objects used in the Model.

execute()

execute(*args)

Parameters:

Return type:

list[Buffer]

input_devices

property input_devices

Devices of the model’s input tensors, as a list of Device objects.

input_metadata

property input_metadata

Metadata about the model’s input tensors, as a list of TensorSpec objects.

For example, you can print the input tensor names, shapes, and dtypes:

for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

kernel_summaries

property kernel_summaries

Kernel fusion summaries from the compiled model.

Returns a list of strings, one per mgp.generic.execute kernel in the compiled graph. Each string describes the fused kernel composition, e.g. "Epilogue(custom__kv_rope, custom__kv_cache_store)".
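
If you want to inspect these strings programmatically, a small parser can split each one into a fusion kind and its kernel names. This is a sketch that assumes every summary follows the Kind(name, name, ...) shape of the single example above; the function name parse_kernel_summary is hypothetical, and nested parentheses are not handled.

```python
import re

# Matches "Kind(arg, arg, ...)", the shape of the documented example.
_SUMMARY_RE = re.compile(r"^(\w+)\((.*)\)$")


def parse_kernel_summary(summary):
    """Split a summary string into (kind, [kernel names]).

    Returns (None, []) when the string does not match the assumed shape.
    """
    match = _SUMMARY_RE.match(summary.strip())
    if match is None:
        return None, []
    kind, body = match.groups()
    names = [part.strip() for part in body.split(",") if part.strip()]
    return kind, names


example = "Epilogue(custom__kv_rope, custom__kv_cache_store)"
print(parse_kernel_summary(example))
```

With a real model, you would pass each entry of model.kernel_summaries to the parser.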

output_devices

property output_devices

Devices of the model’s output tensors, as a list of Device objects.

output_metadata

property output_metadata

Metadata about the model’s output tensors, as a list of TensorSpec objects.

For example, you can print the output tensor names, shapes, and dtypes:

for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
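
The two metadata properties can also be combined into one summary. The sketch below relies only on the name, shape, and dtype attributes used in the examples above; summarize_io is a hypothetical helper, and fake_model is a stand-in built from SimpleNamespace so the snippet runs without a loaded model.

```python
from types import SimpleNamespace


def summarize_io(model):
    """Collect one formatted line per input and output tensor spec."""

    def fmt(spec):
        return f"{spec.name}: shape={spec.shape}, dtype={spec.dtype}"

    return {
        "inputs": [fmt(spec) for spec in model.input_metadata],
        "outputs": [fmt(spec) for spec in model.output_metadata],
    }


# Stand-in for a loaded Model; the values are made up for illustration.
fake_model = SimpleNamespace(
    input_metadata=[SimpleNamespace(name="input0", shape=[1, 3], dtype="float32")],
    output_metadata=[SimpleNamespace(name="output0", shape=[1, 2], dtype="float32")],
)

for section, lines in summarize_io(fake_model).items():
    print(section)
    for line in lines:
        print("  " + line)
```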

reload()

reload(self, weights_registry: dict, /) → None

Reload weights into this compiled model in place.

The compiled graph is reused; only the weight buffers are replaced.

Parameters:

  • weights_registry (dict) – Model weight names mapped to their new values.

replay()

replay(graph_keys, *inputs)

Replay the captured device graph for a caller-provided key.

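Parameters:

  • graph_keys – Caller-provided graph key or per-device keys identifying the captured graph to replay.
  • inputs – Input buffers matching the shapes and dtypes used during capture.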

Return type:

None
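
Because capture is best-effort, it can help to check a captured graph before relying on replay(). The sketch below wraps capture() and debug_verify_replay() and reports whether verification passed, leaving the choice between replay() and execute() to the caller. verify_capture and FakeModel are hypothetical; FakeModel only mimics the documented method names so the snippet runs without a real model, and catching RuntimeError follows the Raises list for debug_verify_replay().

```python
def verify_capture(model, graph_keys, *inputs):
    """Capture a device graph, then verify it against eager execution.

    Returns True when verification passes and False when
    debug_verify_replay() raises RuntimeError.
    """
    model.capture(graph_keys, *inputs)
    try:
        model.debug_verify_replay(graph_keys, *inputs)
    except RuntimeError:
        return False
    return True


class FakeModel:
    """Stand-in that records calls; not the real max.engine.Model."""

    def __init__(self, trace_matches=True):
        self.trace_matches = trace_matches
        self.calls = []

    def capture(self, graph_keys, *inputs):
        self.calls.append("capture")
        return []

    def debug_verify_replay(self, graph_keys, *inputs):
        self.calls.append("debug_verify_replay")
        if not self.trace_matches:
            raise RuntimeError("eager trace does not match captured graph")


print(verify_capture(FakeModel(trace_matches=True), [1, 1], object()))
print(verify_capture(FakeModel(trace_matches=False), [1, 1], object()))
```

When the result is False, a caller would typically stay on execute() for that input signature instead of calling replay().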

signature

property signature: Signature

The input signature of the model.