Python module

max.nn.kernels

Helper functions for wrapping custom kv cache/attention related ops.

Any

class max.nn.kernels.Any(*args, **kwargs)

source

Bases: object

Special type indicating an unconstrained type.

  • Any is compatible with every type.
  • Any is assumed to have all methods.
  • All values are assumed to be instances of Any.

Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
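
A minimal illustration of these rules, using the standard `typing.Any` that this class re-exports (the variable and function names below are illustrative):

```python
from typing import Any

def describe(x: Any) -> str:
    # A static checker allows any operation on an Any-typed value;
    # at runtime, Any imposes no checks at all.
    return str(x)

value: Any = 42         # Any is compatible with every type...
value = "hello"         # ...so reassignment to a different type passes type checking.
print(describe(value))  # prints "hello"
```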

AttentionMaskVariant

class max.nn.kernels.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

source

Bases: str, Enum

Defines the string mask variant identifiers used in attention configuration.

CAUSAL_MASK

CAUSAL_MASK = 'causal'

source

CHUNKED_CAUSAL_MASK

CHUNKED_CAUSAL_MASK = 'chunked_causal'

source

NULL_MASK

NULL_MASK = 'null'

source

SLIDING_WINDOW_CAUSAL_MASK

SLIDING_WINDOW_CAUSAL_MASK = 'sliding_window_causal'

source

TENSOR_MASK

TENSOR_MASK = 'tensor_mask'

source
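
Because the variant enum inherits from `str`, each member compares equal to its string value, which is convenient when a kernel expects a plain string identifier. A stdlib sketch of the same pattern (a stand-in class, not the MAX enum itself):

```python
from enum import Enum

class MaskVariant(str, Enum):
    # Mirrors the str + Enum mixin pattern used by AttentionMaskVariant.
    CAUSAL_MASK = "causal"
    NULL_MASK = "null"

# Members are usable anywhere a plain string is expected.
assert MaskVariant.CAUSAL_MASK == "causal"
# Lookup by value recovers the member.
assert MaskVariant("null") is MaskVariant.NULL_MASK
```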

BufferValue

class max.nn.kernels.BufferValue(value)

source

Bases: Value[BufferType]

Represents a mutable semantic tensor within a Graph.

Initializes a BufferValue from another value.

Parameters:

value (Value[Any] | _Value[mo.BufferType] | HasBufferValue) – The value to wrap, either an MLIR value of buffer type or another BufferValue.

device

property device: DeviceRef

source

Returns the device of the BufferValue.

dtype

property dtype: DType

source

Returns the tensor data type.

from_mlir()

classmethod from_mlir(value)

source

Creates a BufferValue from an MLIR buffer value.

Parameters:

value (Value[BufferType]) – The MLIR buffer value to wrap.

Return type:

BufferValue

print()

print(label='debug_buffer')

source

Prints detailed information about the buffer.

Parameters:

label (str)

Return type:

None

rank

property rank: int

source

Returns the rank (number of dims) of the buffer.

shape

property shape: Shape

source

Returns the shape of the BufferValue.

type

property type: BufferType

source

Returns the type of the BufferValue as a BufferType.

DType

class max.nn.kernels.DType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

source

Bases: Enum

The tensor data type.

align

property align

source

Returns the alignment requirement of the data type in bytes.

The alignment specifies the memory boundary that values of this data type must be aligned to for optimal performance and correctness.

bfloat16

bfloat16 = 80

source

16-bit bfloat16 (Brain Float) format. 1 sign bit, 8 exponent bits, 7 mantissa bits.

bool

bool = 1

source

Boolean data type. Stores True or False values.

float16

float16 = 79

source

16-bit IEEE 754 half-precision floating-point. 1 sign bit, 5 exponent bits, 10 mantissa bits.

float32

float32 = 81

source

32-bit IEEE 754 single-precision floating-point. 1 sign bit, 8 exponent bits, 23 mantissa bits.

float4_e2m1fn

float4_e2m1fn = 64

source

4-bit floating-point with 2 exponent bits and 1 mantissa bit, finite values only.

float64

float64 = 82

source

64-bit IEEE 754 double-precision floating-point. 1 sign bit, 11 exponent bits, 52 mantissa bits.

float8_e4m3fn

float8_e4m3fn = 75

source

8-bit floating-point with 4 exponent bits and 3 mantissa bits, finite values only.

float8_e4m3fnuz

float8_e4m3fnuz = 76

source

8-bit floating-point with 4 exponent bits and 3 mantissa bits, finite values only, no negative zero.

float8_e5m2

float8_e5m2 = 77

source

8-bit floating-point with 5 exponent bits and 2 mantissa bits.

float8_e5m2fnuz

float8_e5m2fnuz = 78

source

8-bit floating-point with 5 exponent bits and 2 mantissa bits, finite values only, no negative zero.

float8_e8m0fnu

float8_e8m0fnu = 73

source

8-bit floating-point with 8 exponent bits and 0 mantissa bits, finite values only.

from_numpy()

from_numpy(dtype)

source

Converts a NumPy dtype to the corresponding DType.

Parameters:

dtype (np.dtype) – The NumPy dtype to convert.

Returns:

The corresponding DType enum value.

Return type:

DType

Raises:

ValueError – If the input dtype is not supported.
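
The conversion behaves like a lookup table from NumPy dtypes to DType members, raising for anything unmapped. A hedged sketch of that shape (the mapping entries are illustrative, not the library's internal table):

```python
# Illustrative subset of a NumPy-name -> DType-name mapping.
_NUMPY_TO_DTYPE = {"float32": "float32", "int64": "int64", "bool": "bool"}

def from_numpy_name(np_name: str) -> str:
    # Mirrors DType.from_numpy: unsupported dtypes raise ValueError.
    try:
        return _NUMPY_TO_DTYPE[np_name]
    except KeyError:
        raise ValueError(f"unsupported NumPy dtype: {np_name}") from None

assert from_numpy_name("float32") == "float32"
```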

int16

int16 = 137

source

16-bit signed integer, range -32,768 to 32,767.

int32

int32 = 139

source

32-bit signed integer, range -2,147,483,648 to 2,147,483,647.

int64

int64 = 141

source

64-bit signed integer, range -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807.

int8

int8 = 135

source

8-bit signed integer, range -128 to 127.

is_float()

is_float(self) → bool

source

Checks if the data type is a floating-point type.

is_float8()

is_float8(self) → bool

source

Checks if the data type is an 8-bit floating-point type.

is_half()

is_half(self) → bool

source

Checks if the data type is a half-precision floating-point type.

is_integral()

is_integral(self) → bool

source

Checks if the data type is an integer type.

is_signed_integral()

is_signed_integral(self) → bool

source

Checks if the data type is a signed integer type.

is_unsigned_integral()

is_unsigned_integral(self) → bool

source

Checks if the data type is an unsigned integer type.

size_in_bits

property size_in_bits

source

Returns the size of the data type in bits.

This indicates how many bits are required to store a single value of this data type in memory.

size_in_bytes

property size_in_bytes

source

Returns the size of the data type in bytes.

This indicates how many bytes are required to store a single value of this data type in memory.
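
The two size properties are related by the usual 8-bits-per-byte conversion. A quick sketch of the arithmetic, with bit widths taken from the enum documentation above rather than queried from the library:

```python
# Bit widths documented for a few dtypes above.
size_in_bits = {"float32": 32, "bfloat16": 16, "int8": 8}

def size_in_bytes(name: str) -> int:
    # 8 bits per byte; sub-byte types such as float4_e2m1fn would need
    # special handling and are omitted from this sketch.
    return size_in_bits[name] // 8

assert size_in_bytes("float32") == 4
assert size_in_bytes("bfloat16") == 2
```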

to_numpy()

to_numpy()

source

Converts this DType to the corresponding NumPy dtype.

Returns:

The corresponding NumPy dtype object.

Return type:

np.dtype

Raises:

ValueError – If the dtype is not supported.

Parameters:

self (DType)

uint16

uint16 = 136

source

16-bit unsigned integer, range 0 to 65,535.

uint32

uint32 = 138

source

32-bit unsigned integer, range 0 to 4,294,967,295.

uint64

uint64 = 140

source

64-bit unsigned integer, range 0 to 18,446,744,073,709,551,615.

uint8

uint8 = 134

source

8-bit unsigned integer, range 0 to 255.

DeviceKind

class max.nn.kernels.DeviceKind(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

source

Bases: str, Enum

A device type representation.

CPU

CPU = 'cpu'

source

GPU

GPU = 'gpu'

source

from_string()

static from_string(txt)

source

Parses a device kind from its string representation.

Parameters:

txt (str)

Return type:

DeviceKind

DeviceRef

class max.nn.kernels.DeviceRef(device_type, id=0)

source

Bases: object

A symbolic device representation.

A DeviceRef consists of a DeviceKind and an id. It is a direct representation of the device attribute in MLIR.

The following example demonstrates how to create and use device references:

from max.graph import DeviceRef
# Create a GPU device reference (default id=0)
gpu_device = DeviceRef.GPU()
print(gpu_device)  # Outputs: gpu:0
# Create a CPU device with specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device)  # Outputs: cpu:1

CPU()

static CPU(id=0)

source

Creates a CPU device reference.

Parameters:

id (int)

Return type:

DeviceRef

GPU()

static GPU(id=0)

source

Creates a GPU device reference.

Parameters:

id (int)

Return type:

DeviceRef

device_type

device_type: DeviceKind

source

from_device()

static from_device(device)

source

Converts a Device or DeviceRef to a DeviceRef.

Parameters:

device (Device | DeviceRef)

Return type:

DeviceRef

from_mlir()

static from_mlir(attr)

source

Returns a device reference from an MLIR attribute.

Parameters:

attr (DeviceRefAttr)

Return type:

DeviceRef

id

id: int

source

is_cpu()

is_cpu()

source

Returns True if the device is a CPU device.

Return type:

bool

is_gpu()

is_gpu()

source

Returns True if the device is a GPU device.

Return type:

bool

to_device()

to_device()

source

Converts a device reference to a concrete driver Device.

Return type:

Device

to_mlir()

to_mlir()

source

Returns an MLIR attribute representing the device.

Return type:

DeviceRefAttr

Dim

class max.nn.kernels.Dim(value)

source

Bases: object

A tensor dimension.

Dims describe the shape of tensors in a Graph. In most cases, you don’t need to construct a Dim directly. Instead, you pass dimension values directly to TensorType or BufferType constructors:

from max.dtype import DType
from max.graph import Dim, TensorType, DeviceRef

# Create a TensorType with a symbolic "batch" dimension and a static dimension of size 10
tensor_type = TensorType(DType.int64, ("batch", 10), device=DeviceRef.CPU())

A tensor dimension can be one of three types:

  • Static: A known size. See StaticDim.
  • Symbolic: An unknown size identified by name. See SymbolicDim.
  • Algebraic: An expression derived from symbolic dimensions. See AlgebraicDim.

Static dimensions let the graph compiler resolve shapes at compile time. This enables more aggressive optimizations than symbolic or algebraic dimensions allow. That said, when tensors share a named symbolic dimension, the compiler can leverage the implied shape equality to optimize some operations.

Converts valid input values to Dim.

Parameters:

value (DimLike)

from_mlir()

static from_mlir(attr)

source

Constructs a dimension from an mlir.Attribute.

Parameters:

attr (TypedAttr) – The MLIR Attribute to parse into a dimension.

Returns:

The dimension represented by the MLIR Attr value.

Return type:

Dim

parameters

property parameters: Iterable[SymbolicDim]

source

Lists the symbolic dimension names on which this dim depends.

to_mlir()

to_mlir()

source

Creates an mlir.Attribute representing this dimension.

This is used internally when constructing tensor MLIR types.

Returns:

An mlir.Attribute in the context representing the dimension.

Return type:

TypedAttr

InputScaleSpec

class max.nn.kernels.InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None, block_size=None)

source

Bases: object

Specifies how input activations are scaled for scaled quantization.

Parameters:

activation_scale_ub

activation_scale_ub: float | None = None

source

An optional upper bound for dynamic activation scaling.

block_size

block_size: tuple[int, int] | None = None

source

The tuple[int, int] of the block size for block-wise scaling.

dtype

dtype: DType

source

The DType of the input scale factor(s).

granularity

granularity: ScaleGranularity

source

The ScaleGranularity of the input scale factor application.

is_block

property is_block: bool

source

Whether the input scale granularity is block-wise.

is_colwise

property is_colwise: bool

source

Whether the input scale granularity is column-wise.

is_rowwise

property is_rowwise: bool

source

Whether the input scale granularity is row-wise.

is_tensor

property is_tensor: bool

source

Whether the input scale granularity is per-tensor.

origin

origin: ScaleOrigin

source

The ScaleOrigin (static or dynamic) of the input scale factor.

KVCacheParams

class max.nn.kernels.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, devices, enable_prefix_caching=False, kv_connector=None, kv_connector_config=None, host_kvcache_swap_space_gb=None, page_size=128, is_mla=False, num_q_heads=None, data_parallel_degree=1, n_kv_heads_per_device=0, num_q_heads_per_device=None, kvcache_quant_config=None, num_eagle_speculative_tokens=0)

source

Bases: KVCacheParamInterface

Configuration parameters for key-value cache management in transformer models.

This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, and memory management.

Parameters:

  • dtype (DType)
  • n_kv_heads (int)
  • head_dim (int)
  • num_layers (int)
  • devices (Sequence[DeviceRef])
  • enable_prefix_caching (bool)
  • kv_connector (KVConnectorType | None)
  • kv_connector_config (Any)
  • host_kvcache_swap_space_gb (float | None)
  • page_size (int)
  • is_mla (bool)
  • num_q_heads (int | None)
  • data_parallel_degree (int)
  • n_kv_heads_per_device (int)
  • num_q_heads_per_device (int | None)
  • kvcache_quant_config (KVCacheQuantizationConfig | None)
  • num_eagle_speculative_tokens (int)

allocate_buffers()

allocate_buffers(total_num_pages)

source

Allocates the buffers for the KV cache.

Parameters:

total_num_pages (int)

Return type:

list[KVCacheBuffer]

bytes_per_block

property bytes_per_block: int

source

Returns the number of bytes per cache block.

When TP>1, each block is sharded across the devices in the tensor parallel group. This method returns the total memory needed to store a block across these devices. Includes memory needed for scales if quantization is enabled.

Returns:

The number of bytes per cache block.
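
As a rough illustration of the kind of arithmetic involved (a sketch only, not the exact formula MAX uses; real blocks may add scale storage for quantization and per-device sharding adjustments):

```python
def approx_bytes_per_block(
    dtype_bytes: int, n_kv_heads: int, head_dim: int,
    page_size: int, num_layers: int,
) -> int:
    # Key and value tensors each store page_size tokens per block,
    # across every layer and KV head; hence the factor of 2.
    return 2 * dtype_bytes * n_kv_heads * head_dim * page_size * num_layers

# e.g. bfloat16 cache (2 bytes), 8 KV heads, head_dim 128, 128-token pages, 32 layers
print(approx_bytes_per_block(2, 8, 128, 128, 32))  # 16777216 bytes = 16 MiB
```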

copy_as_dp_1()

copy_as_dp_1(replica_idx=0)

source

Creates a copy of the KVCacheParams with data parallelism disabled.

This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.

Returns:

A new KVCacheParams instance with data_parallel_degree set to 1.

Raises:

ValueError – If n_devices is not evenly divisible by data_parallel_degree.

Parameters:

replica_idx (int)

Return type:

KVCacheParams

data_parallel_degree

data_parallel_degree: int = 1

source

Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).

devices

devices: Sequence[DeviceRef]

source

Devices to use for the KV cache.

dtype

dtype: DType

source

Data type for storing key and value tensors in the cache.

dtype_shorthand

property dtype_shorthand: str

source

Returns a shorthand textual representation of the data type.

Returns:

“bf16” for bfloat16 dtype, “f32” otherwise.

enable_prefix_caching

enable_prefix_caching: bool = False

source

Whether to enable prefix caching for efficient reuse of common prompt prefixes.

get_symbolic_inputs()

get_symbolic_inputs(prefix='')

source

Computes the symbolic inputs for the KV cache.

This method returns a list of KVCacheInputs for each replica. This is used when constructing the model graph.

Returns:

The symbolic inputs for the KV cache.

Parameters:

prefix (str)

Return type:

KVCacheInputs[TensorType, BufferType]

head_dim

head_dim: int

source

Dimensionality of each attention head.

host_kvcache_swap_space_gb

host_kvcache_swap_space_gb: float | None = None

source

Amount of host memory (in GB) to reserve for KV cache swapping. Required when local or tiered connector is used.

is_fp8_kv_dtype

property is_fp8_kv_dtype: bool

source

Whether the KV cache stores FP8 data, for dispatch resolution.

Unlike quantized_kv_cache (which also requires valid scale config), this checks only the storage dtype—matching the compile-time detection in the MLA decode kernel.

TODO(SERVOPT-1094): Once SnapMLA uses a valid scale_dtype, this can be replaced by quantized_kv_cache.

is_mla

is_mla: bool = False

source

Whether the model uses Multi-Latent Attention (MLA) architecture.

kv_connector

kv_connector: KVConnectorType | None = None

source

Type of KV cache connector to use (null, local, tiered, lmcache).

kv_connector_config

kv_connector_config: Any = None

source

Connector-specific configuration (KVConnectorConfig from the pipelines layer).

kvcache_quant_config

kvcache_quant_config: KVCacheQuantizationConfig | None = None

source

KVCache quantization config. Currently only FP8 quantization supported.

n_devices

property n_devices: int

source

Returns the number of devices.

Returns:

The number of devices.

n_kv_heads

n_kv_heads: int

source

Total number of key-value attention heads across all devices.

n_kv_heads_per_device

n_kv_heads_per_device: int = 0

source

Number of KV heads allocated to each device. Computed automatically in __post_init__.

num_eagle_speculative_tokens

num_eagle_speculative_tokens: int = 0

source

Number of draft tokens to generate for EAGLE speculative decoding.

num_layers

num_layers: int

source

Number of layers in the model.

num_q_heads

num_q_heads: int | None = None

source

Number of query attention heads. Required when is_mla is True so that the attention dispatch resolver can call the MLA-specific kernel.

num_q_heads_per_device

num_q_heads_per_device: int | None = None

source

Number of query heads per device. Computed automatically in __post_init__ from num_q_heads and the parallelism configuration.

page_size

page_size: int = 128

source

Number of tokens per page (block).

This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration.

Current constraints: the page size must be a multiple of 128 and at least 128.
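
The stated constraint is easy to check up front; a hypothetical validator (not part of the MAX API):

```python
def validate_page_size(page_size: int) -> None:
    # Current constraint: at least 128 and a multiple of 128.
    if page_size < 128 or page_size % 128 != 0:
        raise ValueError(
            f"page_size must be a multiple of 128 and >= 128 (got {page_size})"
        )

validate_page_size(256)  # OK
```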

quantized_kv_cache

property quantized_kv_cache: bool

source

Returns whether FP8 KV cache quantization is enabled.

Returns:

True when the cache dtype is float8_e4m3fn or float8_e4m3fnuz and a valid quantization scale dtype is configured; False otherwise.

shape_per_block

property shape_per_block: list[int]

source

Returns the shape of each cache block.

Returns:

The shape of the cache block.

shape_per_scale_block

property shape_per_scale_block: list[int]

source

Returns the shape of each scale block used for KVCache quantization.

Returns:

The shape of the KVCache quantization scales block.

tensor_parallel_degree

property tensor_parallel_degree: int

source

Returns the tensor parallel degree.

Returns:

The tensor parallel degree.

MHAMaskVariant

class max.nn.kernels.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

source

Bases: str, Enum

Defines the integer mask variant codes used by multihead attention kernels.

CAUSAL_MASK

CAUSAL_MASK = '0'

source

CHUNKED_CAUSAL_MASK

CHUNKED_CAUSAL_MASK = '3'

source

NULL_MASK

NULL_MASK = '2'

source

SLIDING_WINDOW_CAUSAL_MASK

SLIDING_WINDOW_CAUSAL_MASK = '4'

source

MutableSequence

class max.nn.kernels.MutableSequence

source

Bases: Sequence

All the operations on a read-write sequence.

Concrete subclasses must provide __new__ or __init__, __getitem__, __setitem__, __delitem__, __len__, and insert().
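
A minimal concrete subclass, written against the stdlib ABC of the same name, shows how the required methods plus insert() yield the full mixin API (append, extend, pop, and friends) for free:

```python
from collections.abc import MutableSequence

class TrackedList(MutableSequence):
    """A list wrapper that satisfies the MutableSequence contract."""

    def __init__(self, items=()):
        self._items = list(items)

    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        self._items[index] = value

    def __delitem__(self, index):
        del self._items[index]

    def __len__(self):
        return len(self._items)

    def insert(self, index, value):
        self._items.insert(index, value)

seq = TrackedList([1, 2])
seq.append(3)       # provided by the ABC mixin
seq.extend([4, 5])  # likewise
assert list(seq) == [1, 2, 3, 4, 5]
```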

append()

append(value)

source

S.append(value) – append value to the end of the sequence

clear()

clear() → None -- remove all items from S

source

extend()

extend(values)

source

S.extend(iterable) – extend sequence by appending elements from the iterable

insert()

abstract insert(index, value)

source

S.insert(index, value) – insert value before index

pop()

pop() → item -- remove and return item at index (default last).

source

Raise IndexError if list is empty or index is out of range.

remove()

remove(value)

source

S.remove(value) – remove first occurrence of value. Raise ValueError if the value is not present.

reverse()

reverse()

source

S.reverse() – reverse IN PLACE

QuantConfig

class max.nn.kernels.QuantConfig(input_scale, weight_scale, mlp_quantized_layers, attn_quantized_layers, format, embedding_output_dtype=None, bias_dtype=None, can_use_fused_mlp=False, scales_pre_interleaved=False)

source

Bases: object

Configures scaled quantization settings for a layer or model section.

For example, to configure NVFP4 block-scaled quantization for all layers in a 19-layer model:

from max.dtype import DType
from max.nn import QuantConfig, QuantFormat
from max.nn.quant_config import (
    InputScaleSpec,
    ScaleGranularity,
    ScaleOrigin,
    WeightScaleSpec,
)

all_layers = set(range(19))

input_spec = InputScaleSpec(
    granularity=ScaleGranularity.BLOCK,
    origin=ScaleOrigin.STATIC,
    dtype=DType.float32,
    block_size=(1, 16),
)
weight_spec = WeightScaleSpec(
    granularity=ScaleGranularity.BLOCK,
    dtype=DType.float8_e4m3fn,
    block_size=(1, 8),
)
config = QuantConfig(
    input_scale=input_spec,
    weight_scale=weight_spec,
    mlp_quantized_layers=all_layers,
    attn_quantized_layers=all_layers,
    format=QuantFormat.NVFP4,
)

attn_quantized_layers

attn_quantized_layers: set[int]

source

Set of layer indices with quantized attention projections.

Attention projections are quantized on an all-or-nothing basis per layer: either all of q_proj, k_proj, v_proj, and o_proj are quantized, or all four remain in bfloat16.

bias_dtype

bias_dtype: DType | None = None

source

The DType of bias weights.

can_use_fused_mlp

can_use_fused_mlp: bool = False

source

Whether the quantization scales can be used with fused MLP operations.

embedding_output_dtype

embedding_output_dtype: DType | None = None

source

The DType of the output from the embedding layer.

format

format: QuantFormat

source

The QuantFormat identifying the quantization format.

input_scale

input_scale: InputScaleSpec

source

InputScaleSpec for input activation scaling.

is_dynamic

property is_dynamic: bool

source

True if this input scale is dynamic.

is_fp4

property is_fp4: bool

source

True if this config represents any FP4 variant (NVFP4 or MXFP4).

is_mxfp4

property is_mxfp4: bool

source

Returns True if this config represents MXFP4 quantization.

is_nvfp4

property is_nvfp4: bool

source

True if this config represents modelopt NVFP4.

is_static

property is_static: bool

source

True if this input scale is static.

mlp_quantized_layers

mlp_quantized_layers: set[int]

source

Set of layer indices with quantized MLPs.

MLPs are quantized on an all-or-nothing basis per layer: either all of gate_proj, down_proj, and up_proj are quantized, or all three remain in bfloat16.

quantized_scales_type()

quantized_scales_type(quantized_shape, device_ref)

source

The TensorType of the scales tensor after dynamic quantization.

Return type:

TensorType

scales_granularity_mnk

property scales_granularity_mnk: tuple[int, int, int]

source

The weight and input scale granularities on the M, N, and K axes.

scales_pre_interleaved

scales_pre_interleaved: bool = False

source

Whether weight scales in the checkpoint are already stored in the 5D TCGEN-interleaved layout expected by the FP4 matmul kernel (NVFP4 only). Note that scales in the 5D TCGEN-interleaved layout are typically flattened to 2D [M, K//16] in the checkpoint.

weight_scale

weight_scale: WeightScaleSpec

source

WeightScaleSpec for weight scaling.

QuantizationConfig

class max.nn.kernels.QuantizationConfig(quant_method, bits, group_size, desc_act=False, sym=False)

source

Bases: object

Configuration for specifying quantization parameters that affect inference.

These parameters control how tensor values are quantized, including the method, bit precision, grouping, and other characteristics that affect the trade-off between model size, inference speed, and accuracy.

bits

bits: int

source

The number of bits used to represent each quantized weight element.

desc_act

desc_act: bool = False

source

Whether to use activation ordering (descending activation order). Defaults to False.

group_size

group_size: int

source

The number of weight elements that share a single set of quantization parameters.

quant_method

quant_method: str

source

The quantization method name (for example, gptq or awq).

sym

sym: bool = False

source

Whether to use symmetric quantization. Defaults to False.

QuantizationEncoding

class max.nn.kernels.QuantizationEncoding(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

source

Bases: Enum

Quantization encodings supported by MAX Graph.

Quantization reduces the precision of neural network weights to decrease memory usage and potentially improve inference speed. Each encoding represents a different compression method with specific trade-offs between model size, accuracy, and computational efficiency. These encodings are commonly used with pre-quantized model checkpoints (especially GGUF format) or can be applied during weight allocation.

The following example shows how to create a quantized weight using the Q4_K encoding:

from max.dtype import DType
from max.graph import DeviceRef, Weight
from max.graph.quantization import QuantizationEncoding

# Create a quantized weight using Q4_K encoding
encoding = QuantizationEncoding.Q4_K
quantized_weight = Weight(
    name="linear.weight",
    dtype=DType.uint8,
    shape=[4096, 4096],
    device=DeviceRef.GPU(0),
    quantization_encoding=encoding
)

MAX supports several quantization formats optimized for different use cases.

Q4_0

Q4_0 = 'Q4_0'

source

Basic 4-bit quantization with 32 elements per block.

Q4_K

Q4_K = 'Q4_K'

source

4-bit K-quantization with 256 elements per block.

Q5_K

Q5_K = 'Q5_K'

source

5-bit K-quantization with 256 elements per block.

Q6_K

Q6_K = 'Q6_K'

source

6-bit K-quantization with 256 elements per block.

GPTQ

GPTQ = 'GPTQ'

source

Group-wise Post-Training Quantization for large language models.

block_parameters

property block_parameters: BlockParameters

source

Gets the block parameters for this quantization encoding.

Returns:

The parameters describing how elements are organized and encoded in blocks for this quantization encoding.

Return type:

BlockParameters

block_size

property block_size: int

source

Number of bytes in encoded representation of block.

All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of bytes resulting after encoding a single block.

Returns:

Size in bytes of each encoded quantization block.

Return type:

int

elements_per_block

property elements_per_block: int

source

Number of elements per block.

All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of elements gathered into a block.

Returns:

Number of original tensor elements in each quantized block.

Return type:

int
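
For instance, in the public GGUF layout, a Q4_0 block packs 32 four-bit elements plus one float16 scale, so block_size / elements_per_block gives the effective storage cost per element. The arithmetic below is a sketch based on that published layout, not a query of the library:

```python
# Q4_0: 32 elements at 4 bits each, plus a 2-byte float16 scale per block.
elements_per_block = 32
block_size = elements_per_block * 4 // 8 + 2  # 16 + 2 = 18 bytes

bits_per_element = block_size * 8 / elements_per_block
print(block_size, bits_per_element)  # 18 4.5
```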

is_gguf

property is_gguf: bool

source

Checks if this quantization encoding is compatible with GGUF format.

GGUF is a format for storing large language models and compatible quantized weights.

Returns:

True if this encoding is compatible with GGUF, False otherwise.

Return type:

bool

name

property name: str

source

Gets the lowercase name of the quantization encoding.

Returns:

Lowercase string representation of the quantization encoding.

Return type:

str

StaticDim

class max.nn.kernels.StaticDim(value)

source

Bases: Dim

A static tensor dimension with a fixed size.

Because a static dimension’s size is fixed, related computation can be optimized at compile time. This is key to good model performance.

The following example creates static dimensions implicitly by passing integer values to TensorType:

from max.graph import TensorType
from max.dtype import DType
tensor = TensorType(DType.int64, (4, 5))
# This creates a tensor with 2 static dimensions: 4 and 5 respectively

Converts valid input values to Dim.

Parameters:

dim (int)

dim

dim: int

source

The size of the static dimension.

from_mlir()

static from_mlir(attr)

source

Constructs a StaticDim from a builtin.IntegerAttr.

Parameters:

attr (TypedAttr) – The builtin.IntegerAttr to parse into a StaticDim.

Returns:

The StaticDim represented by the builtin.IntegerAttr.

Return type:

StaticDim

parameters

property parameters: Iterable[SymbolicDim]

source

Lists the symbolic dimension names on which this dim depends.

to_mlir()

to_mlir()

source

Creates an mlir.Attribute representing this dimension.

This is used internally when constructing tensor MLIR types.

Returns:

An mlir.Attribute in the context representing the dimension.

Return type:

IntegerAttr

TensorType

class max.nn.kernels.TensorType(dtype, shape, device, _layout=None)

source

Bases: _TensorTypeBase[TensorType]

A symbolic tensor type.

Use TensorType to declare the expected dtype, shape, and target device of tensor values that flow through a graph during model execution. Unlike an eager tensor, a TensorType holds no data. It is a purely symbolic description of a value’s type at a specific point in the computation. The graph compiler uses this information for shape inference and optimization during graph construction.

The following example shows how to create a tensor type and access its properties:

from max.graph import TensorType, DeviceRef
from max.dtype import DType
# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
print(tensor_type.dtype)  # Outputs: DType.float32
print(tensor_type.shape)  # Outputs: [2, 3]

A shape’s dimensions can be static (integers), symbolic (strings), or algebraic (expressions over symbolic dimensions). In each case the rank is known at graph construction time.

Pass TensorType instances to load() or Module.compile() (experimental) to define the input types of a graph or model.

Parameters:

  • dtype (DType) – The data type of the tensor elements.
  • shape (Shape) – The shape of the tensor, expressed as a Shape.
  • device (DeviceRef) – The device the tensor is located on. Use DeviceRef.CPU() or DeviceRef.GPU() to create a device reference.
  • _layout (FilterLayout | None)

as_buffer()

as_buffer()

source

Returns the analogous buffer type.

Return type:

BufferType

from_mlir()

classmethod from_mlir(type)

source

Constructs a tensor type from an MLIR type.

Parameters:

type (TensorType) – The MLIR Type to parse into a tensor type.

Returns:

The tensor type represented by the MLIR Type value.

Return type:

TensorType

to_mlir()

to_mlir()

source

Converts to an mlir.Type instance.

Returns:

An mlir.Type in the specified context.

Return type:

TensorType

TensorValue

class max.nn.kernels.TensorValue(value)

source

Bases: Value[TensorType]

Represents a value semantic tensor within a Graph.

It provides various methods and properties to manipulate and query tensor attributes such as shape, data type (dtype), device placement (device), and more.

The following example demonstrates how to create and manipulate tensor values in a graph:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a sample matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("tensor_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor properties
    print(f"Shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32

    # Perform operations on the tensor
    transposed = tensor.T
    doubled = tensor * 2

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Transposed shape: {transposed.shape}")  # Output: [2, 2]

Initializes a TensorValue from a tensor-like value.

Parameters:

value (TensorValueLike) – The value to wrap. Can be an MLIR tensor value, another TensorValue, a Dim, or a Shape.

T

property T: TensorValue

source

Returns the transposed tensor.

T is the shorthand notation for transposing. For more information, see transpose().

Returns:

A new TensorValue with swapped dimensions.

argmax()

argmax(axis=-1)

source

Reduces the tensor using an argmax operation along axis.

When the result is ambiguous, i.e. there are multiple maxima, one index is selected arbitrarily.

from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("argmax_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Argmax along axis 1 (last dimension of each row)
    indices = x.argmax(axis=1)

    print(f"Input shape: {x.shape}")       # [2, 3]
    print(f"Argmax shape: {indices.shape}")  # [2, 1]

Parameters:

axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).

Returns:

A TensorValue of dtype DType.int64 with the same rank as the input, and the same shape except along axis, which will have size 1.

Return type:

TensorValue

broadcast_to()

broadcast_to(shape)

source

Broadcasts the tensor to a new shape.

The following example demonstrates how to broadcast a tensor to a larger shape:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("broadcast_to_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Broadcast tensor to a 3x2x2 tensor (add a new dimension of size 3)
    broadcasted_tensor = tensor.broadcast_to((3, 2, 2))

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Broadcasted shape: {broadcasted_tensor.shape}")  # Output: [3, 2, 2]

Parameters:

shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – An iterable of integers or symbolic dimensions.

Returns:

A new TensorValue with the broadcasted shape.

Return type:

TensorValue

cast()

cast(dtype)

source

Casts a symbolic tensor to a different data type.

The following example demonstrates how to cast a tensor from one data type to another:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("cast_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Cast tensor to integer type
    casted_tensor = tensor.cast(DType.int32)

    print(f"Original dtype: {tensor.dtype}")  # Output: DType.float32
    print(f"Casted dtype: {casted_tensor.dtype}")  # Output: DType.int32

Parameters:

dtype (DType) – The target data type (for example, DType.int32, DType.float64).

Returns:

A new TensorValue with the casted data type.

Return type:

TensorValue

device

property device: DeviceRef

source

Returns the device of the TensorValue.

dtype

property dtype: DType

source

Returns the tensor data type.

The following example demonstrates how to access the data type of a tensor:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor data type
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32

flatten()

flatten(start_dim=0, end_dim=-1)

source

Flattens the specified dims of a symbolic tensor.

The number and order of the elements in the tensor are unchanged. All dimensions from start_dim to end_dim (inclusive) are merged into a single output dim.

The following example demonstrates how to flatten a multi-dimensional tensor:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("flatten_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Flatten the tensor to a 1D array
    flattened_tensor = tensor.flatten()

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Flattened shape: {flattened_tensor.shape}")  # Output: [4]

Parameters:

  • start_dim (int) – The starting dimension to flatten. Defaults to 0.
  • end_dim (int) – The ending dimension to flatten. Defaults to -1.

Returns:

A new TensorValue with the flattened dimensions.

Return type:

TensorValue

from_mlir()

classmethod from_mlir(value)

source

Creates a TensorValue from an MLIR tensor value.

Parameters:

value (Value[TensorType]) – The MLIR tensor value to wrap.

Return type:

TensorValue

max()

max(axis=-1)

source

Reduces the tensor using a max operation along axis.

from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("max_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Max along axis 1 (last dimension of each row)
    m = x.max(axis=1)

    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Max shape: {m.shape}")    # [2, 1]

Parameters:

axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).

Returns:

A TensorValue with the same rank as the input and the same shape except along axis, which will have size 1.

Return type:

TensorValue

mean()

mean(axis=-1)

source

Reduces the tensor using a mean operation along axis.

from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("mean_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Mean along axis 1 (last dimension of each row)
    mu = x.mean(axis=1)

    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Mean shape: {mu.shape}")  # [2, 1]

Parameters:

axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).

Returns:

A TensorValue with the same rank as the input and the same shape except along axis, which will have size 1.

Return type:

TensorValue

min()

min(axis=-1)

source

Reduces the tensor using a min operation along axis.

from max.dtype import DType

from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("min_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Min along axis 1 (last dimension of each row)
    mn = x.min(axis=1)

    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Min shape: {mn.shape}")   # [2, 1]

Parameters:

axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).

Returns:

A TensorValue with the same rank as the input and the same shape except along axis, which will have size 1.

Return type:

TensorValue

permute()

permute(dims)

source

Permutes the tensor’s dimensions based on provided indices.

Parameters:

dims (list[int]) – A list of integers specifying the new order of dimensions.

Returns:

A new TensorValue with permuted dimensions.

Return type:

TensorValue
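For a quick shape check, the NumPy analogue of permute (np.transpose with an explicit axis order) reorders dimensions the same way:

```python
import numpy as np

# NumPy analogue of TensorValue.permute: reorder dimensions by index.
x = np.zeros((2, 3, 4), dtype=np.float32)
y = np.transpose(x, (2, 0, 1))  # corresponds to permute(dims=[2, 0, 1])
print(y.shape)  # (4, 2, 3)
```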

print()

print(label='debug_tensor')

source

Prints detailed information about the tensor.

Parameters:

label (str) – A string label for the printed output. Defaults to debug_tensor.

Return type:

None

rank

property rank: int

source

Returns the rank (number of dims) of the tensor.

The following example demonstrates how to access the rank of a tensor:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix (2-dimensional array)
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("rank_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor rank (number of dimensions)
    print(f"Rank: {tensor.rank}")  # Output: 2

rebind()

rebind(shape, message='')

source

Rebinds the tensor to a new shape with error handling.

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The new shape as an iterable of integers or symbolic dimensions.
  • message (str) – (optional) A message for logging or debugging.

Returns:

A new TensorValue with the updated shape.

Return type:

TensorValue

reshape()

reshape(shape)

source

Creates a new tensor with the same data but reshaped.

The following example demonstrates how to reshape a tensor to change its dimensions:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("reshape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Reshape tensor to a 1x4 matrix
    reshaped_tensor = tensor.reshape((1, 4))

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Reshaped shape: {reshaped_tensor.shape}")  # Output: [1, 4]

Parameters:

shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The new shape as an iterable of integers or symbolic dimensions.

Returns:

A new TensorValue with the reshaped dimensions.

Return type:

TensorValue

shape

property shape: Shape

source

Returns the shape of the TensorValue.

The following example demonstrates how to access the shape of a tensor:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor shape
    print(f"Shape: {tensor.shape}")  # Shape: [Dim(2), Dim(2)]

stdev()

stdev(axis=-1)

source

Reduces the tensor using a standard deviation operation along axis.

The standard deviation is computed as the square root of the population variance along the specified axis.

from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("stdev_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Standard deviation along axis 1 (last dimension of each row)
    sd = x.stdev(axis=1)

    print(f"Input shape: {x.shape}")    # [2, 3]
    print(f"Stdev shape: {sd.shape}")  # [2, 1]

Parameters:

axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).

Returns:

A TensorValue with the same rank as the input and the same shape except along axis, which will have size 1.

Return type:

TensorValue

to()

to(device)

source

Inserts a graph-level transfer to device into the compiled graph.

This is a graph execution-time operation: it records a transfer node during graph tracing that moves this symbolic tensor to device when the compiled graph runs. It is equivalent to calling transfer_to() and is typically used inside forward() to route activation tensors between devices.

This is distinct from Module.to(), a pre-compilation host-side operation that moves stored weight tensors before the graph is built. If you want to place a module’s weights and computation on a device, use Module.to(device) before calling compile().

The following example demonstrates how to move a tensor from one device to another:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

with Graph("to_device_example") as graph:
    # Create a tensor on the default device
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Move the tensor to a GPU device
    gpu_tensor = tensor.to(DeviceRef.GPU())

    print(f"Original device: {tensor.device}")  # Output depends on default device
    print(f"New device: {gpu_tensor.device}")  # Output: gpu:0

Parameters:

device (DeviceRef) – A DeviceRef object specifying the target device.

Returns:

A new TensorValue on the specified device.

Return type:

TensorValue

transpose()

transpose(dim_1, dim_2)

source

Swaps two dimensions of the tensor.

The following example demonstrates how to transpose a tensor by swapping its dimensions:

import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef

# Create a 2x3 matrix
matrix = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)

with Graph("transpose_demo") as graph:
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Transpose the tensor (swap dimensions 0 and 1)
    transposed_tensor = tensor.transpose(dim_1=0, dim_2=1)

    print(f"Original shape: {tensor.shape}")  # Output: [2, 3]
    print(f"Transposed shape: {transposed_tensor.shape}")  # Output: [3, 2]

Parameters:

  • dim_1 (int) – The first dimension to swap.
  • dim_2 (int) – The second dimension to swap.

Returns:

A new TensorValue with swapped dimensions.

Return type:

TensorValue

type

property type: TensorType

source

Returns the type of the TensorValue as a TensorType.

var()

var(axis=-1)

source

Reduces the tensor using a variance operation along axis.

The variance is computed as the mean of squared deviations from the mean (population variance, i.e., without Bessel’s correction) along the specified axis.

from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef

# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("var_demo", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor

    # Variance along axis 1 (last dimension of each row)
    vr = x.var(axis=1)

    print(f"Input shape: {x.shape}")  # [2, 3]
    print(f"Var shape: {vr.shape}")  # [2, 1]

Parameters:

axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).

Returns:

A TensorValue with the same rank as the input and the same shape except along axis, which will have size 1.

Return type:

TensorValue

Type

class max.nn.kernels.Type

source

Bases: Generic[MlirType]

The type of any value in a MAX graph.

Every value in the graph has a type, and that type is represented by a Type. This type may be inspected to get finer-grained types and learn more about an individual Value.

The following example shows how to work with types in a graph:

from max.graph import Graph, TensorType
from max.dtype import DType
with Graph() as g:
    # Create a tensor constant with a specific type
    tensor_type = TensorType(DType.float32, [2, 3])
    # The type can be inspected to get information about the value
    print(f"Tensor element type: {tensor_type.dtype}")  # Outputs: DType.float32
    print(f"Tensor shape: {tensor_type.shape}")  # Outputs: [2, 3]

from_mlir()

static from_mlir(t)

source

Constructs a type from an MLIR type.

Parameters:

t (MlirType) – The MLIR Type object to parse into a type.

Returns:

The type represented by the MLIR Type value.

Return type:

Type[Any]

to_mlir()

to_mlir()

source

Converts to an mlir.Type instance.

Returns:

An mlir.Type in the specified Context.

Return type:

MlirType

Value

class max.nn.kernels.Value

source

Bases: Generic[MlirType]

Represents a symbolic value within a Graph.

A Value can represent the output of a node, the arguments of a Graph (as seen from within its body), and more generally any symbolic value available within the Graph. Other nodes receive Value values as inputs to form a computation graph.

A Value may also refer to an existing input or output of a node, and you can change them, such as by swapping in a new Value.

Conceptually, think of a Value as an edge in the dataflow graph, with the other end being the user of that value.

The following example shows how to work with Values in a graph to create a simple computation:

from max.graph import Graph, ops, Value, DeviceRef
from max.dtype import DType
import numpy as np

# Create a graph context
with Graph("value_example") as graph:
    # Create input values
    a = ops.constant(np.array([1, 2, 3]), dtype=DType.float32, device=DeviceRef.CPU())
    b = ops.constant(np.array([4, 5, 6]), dtype=DType.float32, device=DeviceRef.CPU())

    # Use values to perform operations
    c = a + b  # c is a Value representing the addition

    # Demonstrate that the result is a Value
    print(f"Type of c: {type(c)}")
    print(f"Is c a Value? {isinstance(c, Value)}")

Similar to a regular variable, a Value has a data type.

Value is abstract and shouldn’t be constructed directly.

buffer

property buffer: BufferValue

source

Returns the Value as a BufferValue.

Raises an exception if the Value is not a BufferValue.

from_mlir()

classmethod from_mlir(value)

source

Creates a Value from an MLIR value.

Parameters:

value (Value[MlirType]) – The MLIR value to wrap.

Return type:

Value[Any]

opaque

property opaque: _OpaqueValue

source

Returns the Value as an _OpaqueValue.

Raises an exception if the Value is not an _OpaqueValue.

tensor

property tensor: TensorValue

source

Returns the Value as a TensorValue.

Raises an exception if the Value is not a TensorValue.

to_mlir()

to_mlir()

source

Converts the Value to an MLIR value.

Return type:

Value[MlirType]

type

property type: Type[MlirType]

source

Returns the type of the Value as a Type.

WeightScaleSpec

class max.nn.kernels.WeightScaleSpec(granularity, dtype, block_size=None)

source

Bases: object

Specifies how weights are scaled for scaled quantization.

Parameters:

block_size

block_size: tuple[int, int] | None = None

source

The block size for block-wise scaling, as a tuple[int, int].

dtype

dtype: DType

source

The DType of the weight scale factor(s).

granularity

granularity: ScaleGranularity

source

The ScaleGranularity of the weight scale factor application.

is_block

property is_block: bool

source

Whether the weight scale granularity is block-wise.

is_colwise

property is_colwise: bool

source

Whether the weight scale granularity is column-wise.

is_rowwise

property is_rowwise: bool

source

Whether the weight scale granularity is row-wise.

is_tensor

property is_tensor: bool

source

Whether the weight scale granularity is per-tensor.
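To illustrate the block-wise granularity above, here is a minimal NumPy sketch: each block_size tile of the weight matrix shares one scale factor. The helper name and the row-major tiling are assumptions for illustration, not the actual kernel implementation:

```python
import numpy as np

# Hypothetical sketch: block-wise scaling assigns one scale factor per
# block_size tile of the quantized weight matrix.
def dequantize_blockwise(w_q, scales, block_size):
    br, bc = block_size
    w = w_q.astype(np.float32)
    for i in range(scales.shape[0]):
        for j in range(scales.shape[1]):
            w[i * br:(i + 1) * br, j * bc:(j + 1) * bc] *= scales[i, j]
    return w

w_q = np.ones((4, 4), dtype=np.int8)
scales = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=np.float32)
print(dequantize_blockwise(w_q, scales, (2, 2))[0, 0])  # 2.0
```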

accelerator_architecture_name()

max.nn.kernels.accelerator_architecture_name()

source

Returns the architecture name of the accelerator device.

Return type:

str

apply_penalties_to_logits()

max.nn.kernels.apply_penalties_to_logits(logits_buffer, frequency_data, frequency_offsets, *, frequency_penalty=0.0, presence_penalty=0.0, repetition_penalty=1.0)

source

Applies penalties to the logits.

Parameters:

  • logits_buffer (BufferValue) – The buffer to apply penalties to.
  • frequency_data (TensorValue) – 2d tensor of shape [unique_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token.
  • frequency_offsets (TensorValue) – 1d tensor of shape [batch_size + 1], indicating start of each sequence’s data.
  • frequency_penalty (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text: tokens will receive a penalty proportional to the count of appearances.
  • presence_penalty (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once by applying a constant penalty.
  • repetition_penalty (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in prompt and generated text at least once by dividing the logits by the repetition penalty.

Return type:

None
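A dense NumPy sketch of the penalty math described above. The real kernel consumes ragged frequency_data across a batch; this single-sequence dense version, and applying the repetition penalty as a plain division, are simplifying assumptions:

```python
import numpy as np

# Hypothetical dense sketch: one row of logits, one count per vocab token.
def apply_penalties(logits, token_counts, frequency_penalty=0.0,
                    presence_penalty=0.0, repetition_penalty=1.0):
    logits = logits.astype(np.float32)
    seen = token_counts > 0
    logits -= frequency_penalty * token_counts  # proportional to count
    logits -= presence_penalty * seen           # flat penalty once seen
    logits[seen] /= repetition_penalty          # divide repeated tokens
    return logits

logits = np.array([1.0, 2.0], dtype=np.float32)
counts = np.array([0.0, 3.0], dtype=np.float32)
print(apply_penalties(logits, counts, frequency_penalty=0.5).tolist())  # [1.0, 0.5]
```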

assert_same_device()

max.nn.kernels.assert_same_device(*values, **named_values)

source

Raises ValueError if any of the given values are not on the same device.

Parameters:

Return type:

None

batched_dynamic_scaled_fp8_matmul()

max.nn.kernels.batched_dynamic_scaled_fp8_matmul(a, b, a_scales, b_scales, input_scale_spec, weight_scale_spec, out_type=bfloat16)

source

Performs a batched blockwise scaled matmul of two tensors with scaling factors.

Parameters:

  • a (TensorValue) – The first tensor to multiply (3D tensor).
  • b (TensorValue) – The second tensor to multiply, must be transposed (3D tensor).
  • a_scales (TensorValue) – The scaling factors for the first tensor (3D tensor).
  • b_scales (TensorValue) – The scaling factors for the second tensor (3D tensor).
  • input_scale_spec (InputScaleSpec)
  • weight_scale_spec (WeightScaleSpec)
  • out_type (DType)

Returns:

The result of the matmul operation.

Return type:

TensorValue

block_scales_interleave()

max.nn.kernels.block_scales_interleave(scales, sf_vector_size=16)

source

Interleaves rank-2 FP4 block scales into the rank-5 TCGEN layout.

Parameters:

  • scales (TensorValue) – Rank-2 block scales in [M, K // sf_vector_size] layout. Supported dtypes are float8_e4m3fn for NVFP4 and float8_e8m0fnu for MXFP4.
  • sf_vector_size (int) – Scale-factor vector size: 16 for NVFP4 or 32 for MXFP4.

Returns:

The interleaved scales tensor in [ceildiv(M, 128), ceildiv(K // sf_vector_size, 4), 32, 4, 4] layout.

Return type:

TensorValue
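The output shape follows directly from the input dimensions; a pure-Python sketch of the shape arithmetic (not the interleaving kernel itself):

```python
def ceildiv(n: int, d: int) -> int:
    # Ceiling of n / d using integer arithmetic.
    return -(-n // d)

# Interleaved TCGEN layout shape for [M, K // sf_vector_size] block scales.
def interleaved_shape(m: int, k: int, sf_vector_size: int = 16):
    return (ceildiv(m, 128), ceildiv(k // sf_vector_size, 4), 32, 4, 4)

print(interleaved_shape(256, 1024))  # (2, 16, 32, 4, 4)
```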

ceildiv()

max.nn.kernels.ceildiv(n, d)

source

Ceiling division.

Parameters:

  • n (Dim) – The numerator.
  • d (Dim) – The denominator.

Returns:

The ceiling of dividing n by d.

Return type:

Dim
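The graph op works on symbolic Dim values; on plain integers the same operation is:

```python
def ceildiv(n: int, d: int) -> int:
    # Ceiling division without floating point: equivalent to math.ceil(n / d).
    return -(-n // d)

print(ceildiv(7, 2))  # 4
print(ceildiv(8, 2))  # 4
```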

compute_mha_decode_num_partitions()

max.nn.kernels.compute_mha_decode_num_partitions(batch_size, max_cache_valid_length, n_kv_heads, device)

source

Computes the MHA decode partition count inside a graph.

Wraps the mo.mha.decode.get_num_partitions kernel as a graph op so that the partition heuristic can be evaluated dynamically during graph execution rather than only at graph-build time.

Parameters:

  • batch_size (TensorValue) – Scalar int64 tensor with the current batch size.
  • max_cache_valid_length (TensorValue) – Scalar int64 tensor with the maximum valid cache length across all requests.
  • n_kv_heads (int) – Number of key-value attention heads per device (compile-time constant).
  • device (DeviceRef) – The DeviceRef whose hardware info determines the partition heuristic.

Returns:

A CPU TensorValue of shape [1] and dtype int64 containing the computed partition count.

Return type:

TensorValue

compute_mla_dispatch_args_scalar()

max.nn.kernels.compute_mla_dispatch_args_scalar(batch_size, max_cache_valid_length, q_max_seq_len, num_heads, device, is_fp8_kv=False)

source

Computes scalar dispatch arguments for the MLA decode kernel.

Produces a CPU tensor of shape [3] containing pre-computed integer arguments used by the capturable MLA decode kernel variant to enable CUDA graph capture.

Parameters:

  • batch_size (TensorValue) – Scalar tensor indicating the current batch size.
  • max_cache_valid_length (TensorValue) – Scalar tensor with the maximum valid cache sequence length across all requests in the batch.
  • q_max_seq_len (TensorValue) – Scalar tensor with the maximum query sequence length in the current batch.
  • num_heads (int) – Number of query attention heads.
  • device (DeviceRef) – The DeviceRef on which to run the op.
  • is_fp8_kv (bool)

Returns:

A CPU TensorValue of shape [3] and dtype int64 containing the dispatch scalar arguments.

Return type:

TensorValue

convert_weights_to_fp8_fnuz_if_needed()

max.nn.kernels.convert_weights_to_fp8_fnuz_if_needed(weight, weight_scale)

source

Converts weights and scales to FP8 FNUZ format if needed for AMD GPUs.

This utility function checks whether FP8 FNUZ conversion is needed, currently only on AMD MI300 GPUs, and performs the conversion if required. This centralizes the conversion logic that was previously duplicated across multiple files.

Parameters:

  • weight (TensorValue) – The weight tensor to potentially convert.
  • weight_scale (TensorValue) – The weight scale factor.

Returns:

Tuple of (weight, weight_scale) - converted if needed, original otherwise.

Return type:

tuple[TensorValue, TensorValue]

cross_attention_ragged()

max.nn.kernels.cross_attention_ragged(kv_params, input, input_row_offsets, kv_collection, layer_idx, mask_variant, kv_input_row_offsets, q_max_seq_len, scale, local_window_size=-1)

source

Computes cross attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.

Unlike self attention, kv_input_row_offsets represents the KV sequence length.

Parameters:

Return type:

TensorValue

dynamic_block_scaled_matmul_fp4()

max.nn.kernels.dynamic_block_scaled_matmul_fp4(a, b, a_scales, b_scales, tensor_sf, sf_vector_size=16, out_type=bfloat16)

source

Performs a matmul of two FP4 tensors with 1D-block scaled scaling factors.

Parameters:

  • a (TensorValue) – The first tensor to multiply.
  • b (TensorValue) – The second tensor to multiply, must be transposed.
  • a_scales (TensorValue) – The scaling factors for the first tensor.
  • b_scales (TensorValue) – The scaling factors for the second tensor.
  • tensor_sf (TensorValue | float) – Buffer-wise scaling factor equal to weight_scale_2 * input_scale (non-inverted).
  • sf_vector_size (int)
  • out_type (DType)

Returns:

The result of the matmul operation.

Return type:

TensorValue

dynamic_block_scaled_matmul_mxfp4()

max.nn.kernels.dynamic_block_scaled_matmul_mxfp4(a, b, a_scales, b_scales, out_type=bfloat16)

source

Performs a matmul of two FP4 tensors with 1D-block scaled scaling factors.

Parameters:

  • a (TensorValue) – The first tensor to multiply.
  • b (TensorValue) – The second tensor to multiply, must be transposed.
  • a_scales (TensorValue) – The scaling factors for the first tensor.
  • b_scales (TensorValue) – The scaling factors for the second tensor.
  • out_type (DType)

Returns:

The result of the matmul operation.

Return type:

TensorValue

dynamic_scaled_matmul()

max.nn.kernels.dynamic_scaled_matmul(a, b, a_scales, b_scales, input_scale_spec, weight_scale_spec, out_type=bfloat16)

source

Performs a matmul of two tensors with scaling factors. Currently only supports channel-wise scaling for weights and per-token scaling for inputs.

Parameters:

Returns:

The result of the matmul operation.

Return type:

TensorValue

eagle_prefill_shift_tokens()

max.nn.kernels.eagle_prefill_shift_tokens(tokens, offsets, shift_next_tokens)

source

Shifts ragged tokens left by 1 per request, appending bonus tokens.

Parameters:

  • tokens (TensorValue) – Flat ragged token sequence of shape [total_seq_len], dtype int64.
  • offsets (TensorValue) – Row offsets of shape [batch_size + 1], dtype uint32.
  • shift_next_tokens (TensorValue) – One token per request of shape [batch_size], dtype int64, to append after shifting.

Returns:

Shifted (or copied) tokens with the same shape as tokens.

Return type:

TensorValue
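The per-request shift can be sketched in NumPy (a simplified stand-in for the kernel; the function name is hypothetical):

```python
import numpy as np

# Shift each request's tokens left by one and append its bonus token.
def shift_tokens(tokens, offsets, shift_next_tokens):
    out = tokens.copy()
    for b in range(len(offsets) - 1):
        s, e = int(offsets[b]), int(offsets[b + 1])
        out[s:e - 1] = tokens[s + 1:e]     # shift left by one
        out[e - 1] = shift_next_tokens[b]  # append the bonus token
    return out

tokens = np.array([1, 2, 3, 4, 5], dtype=np.int64)
offsets = np.array([0, 3, 5], dtype=np.uint32)
bonus = np.array([9, 8], dtype=np.int64)
print(shift_tokens(tokens, offsets, bonus).tolist())  # [2, 3, 9, 5, 8]
```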

flare_mla_decode_ragged()

max.nn.kernels.flare_mla_decode_ragged(kv_params, input, input_row_offsets, kv_collection, layer_idx, mask_variant, scale, scalar_args, *, qk_rope_dim=64)

source

Computes flash (self) attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.

Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.

Parameters:

Return type:

TensorValue

flare_mla_decode_ragged_scaled()

max.nn.kernels.flare_mla_decode_ragged_scaled(kv_params, input, input_row_offsets, kv_collection, kv_scales, q_scales, layer_idx, mask_variant, scale, scalar_args, qk_rope_dim=64, per_token_scale_rope_aware=False, quantization_granularity=640)

source

MLA decode with explicit per-token KV and Q scale tensors.

Like flare_mla_decode_ragged but accepts explicit scale tensors so the per-token-scale rope-aware kernel receives real (non-identity) scales.

Parameters:

  • kv_params (KVCacheParams) – KV cache parameters.
  • input (TensorValue) – Query tensor [total_tokens, num_heads, head_dim].
  • input_row_offsets (TensorValue) – Ragged row offsets [batch_size + 1].
  • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection.
  • kv_scales (BufferValue) – Per-token KV scales buffer [num_blocks, 1, 1, page_size, 1, 1] float32.
  • q_scales (TensorValue) – Per-token Q scales tensor [total_tokens] float32.
  • layer_idx (TensorValue) – Layer index (uint32, on CPU).
  • mask_variant (MHAMaskVariant) – Attention mask variant.
  • scale (float) – Softmax scale (typically 1/sqrt(d_qk)).
  • qk_rope_dim (int) – Rope head dimension (default 64).
  • per_token_scale_rope_aware (bool) – Use FP8+BF16 interleaved layout.
  • quantization_granularity (int) – Granularity for KV scale quantization. Should equal the KV cache head_dim (640 for rope-aware).
  • scalar_args (TensorValue)

Returns:

Output tensor [total_tokens, num_heads, output_dim].

Return type:

TensorValue

flare_mla_decompress_k_cache()

max.nn.kernels.flare_mla_decompress_k_cache(kv_params, buffer_row_offsets_1d, cache_offsets_1d, buffer_length, weight, kv_collection, layer_idx, buffer_size)

source

This kernel decompresses the key cache by up-projecting latent representations into the KV space using a weight matrix.

The process involves:

  1. Copying buffer_length latent vectors from the key cache into a contiguous buffer (k_latent)
  2. Computing k = k_latent @ weight.T to obtain the decompressed keys

Returns:

A tensor of shape [buffer_size, weight.shape[0]] containing the decompressed keys. Note that only the first buffer_length tokens are valid.

Parameters:

Return type:

TensorValue
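The two steps above can be sketched in NumPy (gather latents into a contiguous buffer, then up-project with k = k_latent @ weight.T; the function name and the zero-padding of invalid rows are assumptions):

```python
import numpy as np

# Step 1: latents holds buffer_length latent vectors copied from the cache.
# Step 2: up-project with k = k_latent @ weight.T.
def decompress_k(latents, weight, buffer_size):
    buffer_length = latents.shape[0]
    k = np.zeros((buffer_size, weight.shape[0]), dtype=latents.dtype)
    k[:buffer_length] = latents @ weight.T  # only these rows are valid
    return k

latents = np.ones((3, 8), dtype=np.float32)  # 3 latent vectors, latent dim 8
weight = np.ones((16, 8), dtype=np.float32)  # up-projection to KV dim 16
print(decompress_k(latents, weight, buffer_size=4).shape)  # (4, 16)
```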

flare_mla_prefill_plan()

max.nn.kernels.flare_mla_prefill_plan(kv_params, input_row_offsets, kv_collection, layer_idx, buffer_size, max_chunks=16)

source

This kernel plans how to process a batch of sequences with varying lengths using a fixed-size buffer.

Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer_size.

For each chunk (iteration), it calculates:

  1. Buffer offsets for each sequence in each chunk
  2. Cache offsets for each sequence in each chunk
  3. Total buffer lengths for each processing iteration

Parameters:

Return type:

tuple[TensorValue, TensorValue, TensorValue]
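The chunking itself can be sketched in plain Python (a simplified stand-in for the planning kernel; the function name is hypothetical):

```python
# Divide total_tokens into fixed-size buffer chunks, as described above.
def plan_chunks(total_tokens: int, buffer_size: int) -> list[tuple[int, int]]:
    chunks = []  # (start, length) per processing iteration
    start = 0
    while start < total_tokens:
        length = min(buffer_size, total_tokens - start)
        chunks.append((start, length))
        start += length
    return chunks

print(plan_chunks(10, 4))  # [(0, 4), (4, 4), (8, 2)]
```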

    flare_mla_prefill_ragged()

    max.nn.kernels.flare_mla_prefill_ragged(kv_params, input, k, v, input_row_offsets, buffer_row_offsets, cache_offsets, kv_collection, layer_idx, mask_variant, scale, qk_rope_dim=64)

    source

    Performs MLA prefill. MLA prefill must decompress the KV tensors, since the KV cache stores their latent representations. The KV tensors are decompressed into a fixed-size buffer to avoid out-of-memory errors; if the total cache length exceeds the buffer size, the attention calculation is processed in chunks.

    This MLA prefill kernel returns the output tensor and the softmax info tensor for this iteration. These tensors are consumed by the next iteration of the MLA prefill kernel to continue the attention calculation.

    Parameters:

    Returns:

    The output tensor for this iteration

    Return type:

    TensorValue

    flash_attention_gpu()

    max.nn.kernels.flash_attention_gpu(q, k, v, mask_variant, scale, local_window_size=-1, valid_length=None)

    source

    Computes flash attention using GPU-optimized kernel.

    Parameters:

    • q (TensorValue) – Query tensor of shape [batch, seq_len, num_heads, head_dim]
    • k (TensorValue) – Key tensor of shape [batch, seq_len, num_heads, head_dim]
    • v (TensorValue) – Value tensor of shape [batch, seq_len, num_heads, head_dim]
    • mask_variant (MHAMaskVariant) – The mask variant to use for attention
    • scale (float) – Scaling factor for attention scores
    • local_window_size (int) – Local window size for sliding window attention
    • valid_length (TensorValue | None) – Optional tensor of shape [batch] with dtype uint32. When provided, uses the padded kernel variant that respects the valid sequence lengths for each batch element.

    Returns:

    Output tensor of shape [batch, seq_len, num_heads, head_dim]

    Return type:

    TensorValue
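    The math that the fused kernel computes can be written out naively for a single head. This is an illustrative reference for the scaled dot-product attention semantics only (the fused kernel tiles the computation and never materializes the full score matrix):

```python
import math

def attention_reference(q, k, v, scale, causal=True):
    """Naive single-head attention: q, k, v are lists of row vectors
    (seq_len x head_dim). Reference for the semantics, not the kernel."""
    seq_len = len(q)
    out = []
    for i in range(seq_len):
        # Scores against every key position, restricted to the causal window.
        limit = i + 1 if causal else seq_len
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                  for j in range(limit)]
        # Numerically stable softmax over the visible positions.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        weights = [e / denom for e in exps]
        # Weighted sum of value rows.
        out.append([sum(w * v[j][d] for j, w in enumerate(weights))
                    for d in range(len(v[0]))])
    return out
```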

    flash_attention_padded_kv_cache()

    max.nn.kernels.flash_attention_padded_kv_cache(kv_params, q, kv_collection, layer_idx, valid_lengths, mask_variant, scale, local_window_size=-1)

    source

    Computes flash attention with padded inputs and paged KV cache.

    Parameters:

    • kv_params (KVCacheParams) – KV cache parameters
    • q (TensorValue) – Query tensor of shape [batch, seq_len, num_heads, head_dim]
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection
    • layer_idx (TensorValue) – Layer index for cache lookup
    • valid_lengths (TensorValue) – Buffer of shape [batch] with dtype uint32 indicating actual (non-padded) sequence lengths for each batch element
    • mask_variant (MHAMaskVariant) – The mask variant to use for attention
    • scale (float) – Scaling factor for attention scores
    • local_window_size (int) – Local window size for sliding window attention

    Returns:

    Output tensor of shape [batch, seq_len, num_heads, head_dim]

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    TensorValue

    flash_attention_ragged()

    max.nn.kernels.flash_attention_ragged(kv_params, input, input_row_offsets, kv_collection, layer_idx, mask_variant, scale, local_window_size=-1, sink_weights=None)

    source

    Computes flash (self) attention provided the !mo.opaque KV Cache.

    Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.

    Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.

    Parameters:

    • kv_params (KVCacheParams) – KVCacheParams object containing key-value cache parameters.
    • input (TensorValue) – TensorValue representing the input tensor with shape [total_seq_len, hidden_dim].
    • input_row_offsets (TensorValue) – TensorValue indicating the start and end of each batch in the input tensor with shape [batch_size + 1].
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – PagedCacheValues object for managing key-value cache.
    • layer_idx (TensorValue) – TensorValue representing the layer index, expected to have dtype uint32.
    • mask_variant (MHAMaskVariant) – MHAMaskVariant specifying the type of attention mask to use.
    • scale (float) – float value used to scale the attention scores.
    • local_window_size (int) – int specifying the size of the local attention window, default is -1 for no local window.
    • sink_weights (TensorValue | None) – Optional tensor of shape [num_heads] containing learnable sink weights for each attention head.

    Return type:

    TensorValue

    flash_attention_ragged_gpu()

    max.nn.kernels.flash_attention_ragged_gpu(q, k, v, input_row_offsets, max_seq_len, mask_variant, scale, local_window_size=-1)

    source

    Computes flash attention for ragged inputs using GPU-optimized kernel without a KV cache.

    Parameters:

    • q (TensorValue) – Query tensor of shape [total_seq_len, num_heads, head_dim] (ragged)
    • k (TensorValue) – Key tensor of shape [total_seq_len, num_heads, head_dim] (ragged)
    • v (TensorValue) – Value tensor of shape [total_seq_len, num_heads, head_dim] (ragged)
    • input_row_offsets (TensorValue) – Buffer of shape [batch_size + 1] with dtype uint32. Indicates where each sequence starts and ends in the ragged tensors. The values should be a prefix sum (cumulative sum) of sequence lengths.
    • mask_variant (MHAMaskVariant) – The mask variant to use for attention
    • scale (float) – Scaling factor for attention scores
    • local_window_size (int) – Local window size for sliding window attention
    • max_seq_len (TensorValue)

    Returns:

    Output tensor of shape [total_seq_len, num_heads, head_dim]

    Return type:

    TensorValue
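    The row-offsets convention shared by the ragged kernels is a prefix sum of sequence lengths. A minimal sketch of building such offsets on the host (the helper name is hypothetical):

```python
def row_offsets(seq_lens):
    """Build the [batch_size + 1] prefix-sum offsets that ragged kernels
    expect for input_row_offsets: sequence i occupies rows
    offsets[i]:offsets[i + 1] of the concatenated tensor."""
    offsets = [0]
    for n in seq_lens:
        offsets.append(offsets[-1] + n)
    return offsets
```

    For sequence lengths [2, 4, 1] this yields [0, 2, 6, 7], so the second sequence occupies rows 2 through 5.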

    fused_qk_padded_rope()

    max.nn.kernels.fused_qk_padded_rope(kv_params, input, kv_collection, freqs_cis, layer_idx, valid_lengths, interleaved=True)

    source

    Computes fused query-key RoPE with padded inputs and paged KV cache.

    This function applies Rotary Positional Embeddings (RoPE) to both Q and K tensors, where K is stored in the paged KV cache. This is the padded equivalent of fused_qk_ragged_rope.

    Parameters:

    • kv_params (KVCacheParams) – KV cache parameters.
    • input (TensorValue) – Query tensor of shape [batch, seq_len, n_heads, head_dim].
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection.
    • freqs_cis (TensorValue) – Frequency tensor of shape (max_seq_len * 2, head_dim).
    • layer_idx (TensorValue) – Layer index for KV cache (must be uint32 on CPU).
    • valid_lengths (TensorValue) – Buffer of shape [batch] containing the valid length for each sequence (must be uint32). RoPE is only applied to positions within these lengths.
    • interleaved (bool) – Whether to use interleaved RoPE pattern.

    Returns:

    Query tensor with RoPE applied, same shape as input.

    Return type:

    TensorValue

    fused_qk_ragged_rope()

    max.nn.kernels.fused_qk_ragged_rope(kv_params, input, input_row_offsets, kv_collection, freqs_cis, layer_idx, interleaved=True, position_ids=None, mrope_section=None)

    source

    Computes fused query-key attention with rotary positional encodings and ragged inputs.

    Parameters:

    • kv_params (KVCacheParams) – KV cache parameters
    • input (TensorValue) – [batch_size * seq_len, n_heads, head_dim]
    • input_row_offsets (TensorValue) – Ragged tensor offsets indicating where each batch starts and ends
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – KV cache collection
    • freqs_cis (TensorValue) – tensor of shape (max_seq_len * 2, head_dim)
    • layer_idx (TensorValue) – Layer index for KV cache
    • interleaved (bool) – Whether to use interleaved RoPE pattern
    • position_ids (TensorValue | None) – Optional ragged 2D array of position IDs. If None, defaults to cache_length + token_idx for each token. When num_sections > 1, mrope_section must be provided to indicate each section of the head_dim to apply RoPE to. Shape: [num_sections, total_seq_len]
    • mrope_section (list[int] | None) – Optional list of integers indicating the section of the head_dim to apply RoPE to. Must be used in conjunction with position_ids.

    Return type:

    TensorValue

    input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input. If input is not of the same dtype as freqs_cis, it will be cast to the dtype of freqs_cis for the computation, and cast back to the original dtype after the computation is finished.

    When position_ids and mrope_section are provided, it replaces the default position calculation (cache_length + token_idx) with explicit position values. This is useful for 3D RoPE in models like Qwen2.5-VL that need custom position encoding.
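    The interleaved rotation itself can be sketched for one head vector. This is an illustrative reference for the rotation applied per position, assuming the common base-10000 frequency schedule (the fused kernel reads precomputed frequencies from freqs_cis instead):

```python
import math

def rope_interleaved(x, pos, theta=10000.0):
    """Apply interleaved rotary position embedding to one head vector:
    adjacent pairs (x[2i], x[2i+1]) are rotated by an angle that depends
    on the position and the pair index. Sketch only, not the kernel."""
    head_dim = len(x)
    out = []
    for i in range(head_dim // 2):
        freq = theta ** (-2.0 * i / head_dim)
        angle = pos * freq
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        # 2D rotation of the pair by `angle`.
        out.extend([a * c - b * s, a * s + b * c])
    return out
```

    Position 0 leaves the vector unchanged, and the rotation preserves the vector's norm at every position.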

    fused_qkv_padded_matmul()

    max.nn.kernels.fused_qkv_padded_matmul(kv_params, input, wqkv, kv_collection, layer_idx, valid_lengths, n_heads)

    source

    Computes fused query, key, and value projections with padded input.

    This is for non-ragged (padded batch) inputs where sequences may have different actual lengths but are padded to a uniform shape.

    Parameters:

    • kv_params (KVCacheParams) – KV cache parameters.
    • input (TensorValue) – Input tensor with shape [batch_size, seq_len, hidden_dim].
    • wqkv (TensorValue) – Weight tensor for Q, K, V projections.
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection.
    • layer_idx (TensorValue) – Layer index for cache lookup (must be uint32).
    • valid_lengths (TensorValue) – Buffer of shape [batch] containing the valid length for each sequence (must be uint32). K and V are only written to cache for positions within these lengths.
    • n_heads (int) – Number of attention heads.

    Returns:

    Query projections tensor. K and V projections are written to cache.

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    TensorValue

    fused_qkv_ragged_matmul()

    max.nn.kernels.fused_qkv_ragged_matmul(kv_params, input, input_row_offsets, wqkv, kv_collection, layer_idx, n_heads, bias=None, _output_dim=None)

    source

    Computes fused query, key, and value projections with ragged input.

    Parameters:

    • kv_params (KVCacheParams) – KVCacheParams object containing key-value cache parameters.
    • input (TensorValue) – TensorValue representing the input tensor with shape [total_seq_len, hidden_dim].
    • input_row_offsets (TensorValue) – TensorValue indicating the start and end of each request in the input tensor with shape [batch_size + 1].
    • wqkv (TensorValue) – The concatenated Q, K and V projection weights.
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – PagedCacheValues object for managing key-value cache.
    • layer_idx (TensorValue) – TensorValue representing the layer index, expected to have dtype uint32.
    • n_heads (int) – Number of Query attention heads.
    • bias (TensorValue | None) – Optional bias vector concatenated as [q, k, v].
    • _output_dim (int | None) – Optional output dimension. If not provided, the output dimension will be [n_heads * head_dim].

    Returns:

    Query projection tensor.

    Return type:

    TensorValue

    fused_qkv_ragged_matmul_quantized()

    max.nn.kernels.fused_qkv_ragged_matmul_quantized(kv_params, input, input_row_offsets, wqkv, kv_collection, layer_idx, n_heads, quantization_config, perm_idx=None, bias=None)

    source

    Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.

    input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Parameters:

    Return type:

    TensorValue

    grouped_dynamic_scaled_fp8_matmul()

    max.nn.kernels.grouped_dynamic_scaled_fp8_matmul(hidden_states, weight, a_scales, b_scales, expert_start_indices, expert_ids, expert_usage_stats_host, input_scale_spec, weight_scale_spec, out_type=bfloat16, tokens_padded_per_expert=False)

    source

    Grouped blockwise scaled matmul used in MoE layer.

    Performs a grouped blockwise scaled matmul of two tensors with scaling factors. hidden_states and expert_start_indices are used together to implement the ragged tensor.

    Parameters:

    • hidden_states (TensorValue) – The first tensor to multiply. (2D tensor)
    • weight (TensorValue) – The second tensor to multiply, must be transposed. (3D tensor)
    • a_scales (TensorValue) – The scaling factors for the first tensor. (2D tensor)
    • b_scales (TensorValue) – The scaling factors for the second tensor. (3D tensor)
    • expert_start_indices (TensorValue) – indicates where each group starts and ends in hidden_states.
    • expert_ids (TensorValue) – The id of the expert for each group in hidden_states.
    • expert_usage_stats_host (TensorValue) – The maximum number of tokens assigned to any expert, and the number of active experts.
    • input_scale_spec (InputScaleSpec) – The scaling granularity for the input tensor.
    • weight_scale_spec (WeightScaleSpec) – The scaling granularity for the weight tensor.
    • tokens_padded_per_expert (bool) – If True, it is guaranteed that the number of tokens for each local expert is padded so that a_scales is aligned to 16 bytes. This is needed by the optimized grouped matmul kernel.
    • out_type (DType)

    Returns:

    The result of the matmul operation.

    Return type:

    TensorValue

    grouped_dynamic_scaled_mxfp4_matmul()

    max.nn.kernels.grouped_dynamic_scaled_mxfp4_matmul(hidden_states, weight, a_scales, b_scales, expert_start_indices, expert_ids, expert_usage_stats_host, out_type=bfloat16, estimated_total_m=None)

    source

    Performs grouped MXFP4 matmul for MoE layers.

    Performs a grouped matmul with MXFP4 (4-bit) quantized inputs and weights. The inputs are packed as uint8 (2 MXFP4 values per byte) with float8_e8m0fnu scaling factors. MXFP4 uses fixed 1D block scaling with 32 elements per scale factor along the K dimension.

    hidden_states and expert_start_indices together implement the ragged tensor representation for variable-length expert inputs.

    Parameters:

    • hidden_states (TensorValue) – The input activations with shape [total_tokens, K/2] where K is the unpacked hidden dimension. Dtype must be uint8 (packed MXFP4).
    • weight (TensorValue) – The expert weights with shape [num_experts, N, K/2]. Dtype must be uint8 (packed MXFP4).
    • a_scales (TensorValue) – Scaling factors for inputs with shape [num_scale_rows, K/32]. Dtype must be float8_e8m0fnu.
    • b_scales (TensorValue) – Scaling factors for weights with shape [num_experts, N, K/32]. Dtype must be float8_e8m0fnu.
    • expert_start_indices (TensorValue) – Indices indicating where each expert’s tokens start in hidden_states.
    • expert_ids (TensorValue) – The expert ID for each group.
    • expert_usage_stats_host (TensorValue) – A tensor containing [max_tokens_per_expert, num_active_experts].
    • out_type (DType) – Output dtype. Defaults to bfloat16.
    • estimated_total_m (TensorValue | None) – The estimated total number of tokens.

    Returns:

    The matmul result with shape [total_tokens, N] and dtype out_type.

    Return type:

    TensorValue
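    The packed-uint8 layout can be illustrated with plain nibble packing. This sketch only shows how two 4-bit codes share one byte; the actual value encoding of each nibble is defined by the MXFP4/NVFP4 formats, and the low-nibble-first ordering here is an assumption for illustration:

```python
def pack_nibbles(values):
    """Pack pairs of 4-bit codes into bytes, two codes per uint8
    (first code in the low nibble -- an assumed ordering)."""
    assert len(values) % 2 == 0, "K must be even to pack two codes per byte"
    return bytes((values[2 * i] & 0xF) | ((values[2 * i + 1] & 0xF) << 4)
                 for i in range(len(values) // 2))

def unpack_nibbles(packed):
    """Inverse of pack_nibbles: recover the 4-bit codes from each byte."""
    out = []
    for byte in packed:
        out.extend([byte & 0xF, byte >> 4])
    return out
```

    This is why a row of K logical elements occupies K/2 bytes in the hidden_states and weight tensors above.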

    grouped_matmul_block_scaled()

    max.nn.kernels.grouped_matmul_block_scaled(hidden_states, weight, a_scales, b_scales, expert_start_indices, a_scale_offsets, expert_ids, expert_scales, expert_usage_stats_host, out_type=bfloat16, estimated_total_m=None)

    source

    Performs grouped NVFP4 matmul for MoE layers.

    Performs a grouped matmul with NVFP4 (4-bit) quantized inputs and weights. The inputs are packed as uint8 (2 NVFP4 values per byte) with float8_e4m3fn scaling factors. NVFP4 uses fixed 1D block scaling with 16 elements per scale factor along the K dimension.

    hidden_states and expert_start_indices together implement the ragged tensor representation for variable-length expert inputs.

    Parameters:

    • hidden_states (TensorValue) – The input activations with shape [total_tokens, K/2] where K is the unpacked hidden dimension. Dtype must be uint8 (packed NVFP4).
    • weight (TensorValue) – The expert weights with shape [num_experts, N, K/2]. Dtype must be uint8 (packed NVFP4).
    • a_scales (TensorValue) – Scaling factors for inputs with shape [num_scale_rows, K_groups, 32, 4, 4]. Dtype must be float8_e4m3fn.
    • b_scales (TensorValue) – Scaling factors for weights with shape [num_experts, N_groups, K_groups, 32, 4, 4]. Dtype must be float8_e4m3fn.
    • expert_start_indices (TensorValue) – Indices indicating where each expert’s tokens start in hidden_states.
    • a_scale_offsets (TensorValue) – The offsets of the input scale tiles for each expert.
    • expert_ids (TensorValue) – The expert ID for each group.
    • expert_scales (TensorValue) – Per-expert scaling factors with shape [num_experts]. Dtype must be float32. Multiplied with the matmul output in the epilogue.
    • expert_usage_stats_host (TensorValue) – A tensor containing [max_tokens_per_expert, num_active_experts].
    • out_type (DType) – Output dtype. Defaults to bfloat16.
    • estimated_total_m (TensorValue | None) – The estimated total number of tokens.

    Returns:

    The matmul result with shape [total_tokens, N] and dtype out_type.

    Return type:

    TensorValue

    grouped_matmul_ragged()

    max.nn.kernels.grouped_matmul_ragged(hidden_states, weight, expert_start_indices, expert_ids, expert_usage_stats_host)

    source

    Grouped matmul used in MoE layer.

    hidden_states and expert_start_indices are used together to implement the ragged tensor. expert_start_indices indicates where each group starts and ends in hidden_states.

    expert_ids is the ID of the expert for each group in hidden_states.

    expert_usage_stats_host is the maximum number of tokens assigned to any expert, and the number of active experts.

    Parameters:

    Return type:

    TensorValue
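    The grouped-matmul semantics can be written out directly: each contiguous row group of hidden_states is multiplied by the weight matrix of its assigned expert. This is an illustrative reference in plain Python, not the fused kernel:

```python
def grouped_matmul_reference(hidden_states, weights,
                             expert_start_indices, expert_ids):
    """Reference for the grouped MoE matmul. hidden_states is a list of
    row vectors; weights[e] is expert e's [N, K] matrix, applied as
    x @ w.T to each row in that expert's group."""
    out = [None] * len(hidden_states)
    for g in range(len(expert_start_indices) - 1):
        start, end = expert_start_indices[g], expert_start_indices[g + 1]
        w = weights[expert_ids[g]]  # this group's expert weight, [N, K]
        for r in range(start, end):
            x = hidden_states[r]
            out[r] = [sum(x[k] * w_row[k] for k in range(len(x)))
                      for w_row in w]
    return out
```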

    grouped_quantize_dynamic_block_scaled_fp4()

    max.nn.kernels.grouped_quantize_dynamic_block_scaled_fp4(input, row_offsets, scales_offsets, expert_ids, sf_tensor, sf_vector_size=16, scales_type=float8_e4m3fn, out_type=uint8)

    source

    Grouped dynamic FP4 quantization for MoE experts.

    Quantizes a concatenated token tensor where different row ranges belong to different experts, each with its own tensor-wise scale factor.

    Parameters:

    • input (TensorValue) – The concatenated input tensor. Shape: [total_tokens, K], dtype bfloat16.
    • row_offsets (TensorValue) – Cumulative token offsets per expert. Shape: [num_experts + 1], dtype uint32.
    • scales_offsets (TensorValue) – Per-expert scale tile offset corrections. Shape: [num_experts], dtype uint32.
    • expert_ids (TensorValue) – Expert ID mapping (typically identity). Shape: [num_experts], dtype int32.
    • sf_tensor (TensorValue) – Per-expert tensor-wise scale factors. Shape: [num_experts], dtype float32.
    • sf_vector_size (int) – The block size for the scaling factors.
    • scales_type (DType) – Scale factor dtype. float8_e4m3fn for NVFP4.
    • out_type (DType) – Output dtype. uint8 for packed FP4.

    Returns:

    The quantized tensor [total_tokens, K // 2] and scales in rank-5 interleaved layout [total_m_tiles, K_tiles, 32, 4, 4].

    Return type:

    tuple[TensorValue, TensorValue]

    kv_cache_copy_pages_d2h()

    max.nn.kernels.kv_cache_copy_pages_d2h(device_kv_collection, device_page_ids, host_kv_blocks, host_page_ids, layer_idx, device_ref)

    source

    Copy KV cache pages from GPU to CPU for a single layer.

    Performs async GPU->CPU copy of specified pages for layer-wise KV cache offloading.

    Parameters:

    • device_kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Source KV cache on GPU.
    • device_page_ids (TensorValue) – Source page IDs to read from GPU.
    • host_kv_blocks (BufferValue) – Destination KV cache blocks on CPU.
    • host_page_ids (TensorValue) – Destination page IDs to write to CPU. Must have same length as device_page_ids.
    • layer_idx (int) – Which layer to copy.
    • device_ref (DeviceRef) – Device for the GPU context.

    Return type:

    None

    kv_cache_ragged_2m_iadd()

    max.nn.kernels.kv_cache_ragged_2m_iadd(kv_params, a, kv_collection, input_row_offsets, lora_end_idx, batch_seq_len, layer_idx)

    source

    In-place add to paged KV cache with interleaved K/V layout.

    Performs an in-place addition of new key-value projections to paged KV cache. The input tensor a uses a “2M” layout where keys and values are interleaved: rows [0, m) contain keys and rows [m, 2m) contain values, where m is the number of tokens.

    Parameters:

    • kv_params (KVCacheParams) – KV cache configuration parameters.
    • a (TensorValue) – Input tensor with interleaved K/V data, shape (2*m, hidden_size) where m is the number of tokens. Rows [0, m) are keys, rows [m, 2m) are values.
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection containing cache blocks, cache lengths, lookup tables, and max lengths tensors.
    • input_row_offsets (TensorValue) – Ragged tensor offsets indicating where each batch starts and ends
    • lora_end_idx (TensorValue) – End index of LoRA token portion. Marks the boundary between LoRA sequences and base model sequences in the batch.
    • batch_seq_len (TensorValue) – Total sequence length in the batch. Used for indexing into the value portion of a.
    • layer_idx (TensorValue) – The transformer layer index to update in the KV cache.

    Raises:

    • ValueError – If a does not have rank 2.
    • ValueError – If input_row_offsets does not have rank 1.

    Return type:

    None
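    The "2M" layout can be summarized with a tiny host-side sketch (the helper is hypothetical and only illustrates the row convention the kernel expects):

```python
def split_2m(a, m):
    """Split an interleaved "2M" tensor into its key and value halves:
    rows [0, m) hold keys and rows [m, 2m) hold values, where m is the
    number of tokens. Illustrative helper, not part of the API."""
    assert len(a) == 2 * m, "expected exactly 2*m rows"
    keys, values = a[:m], a[m:]
    return keys, values
```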

    kv_cache_ragged_radd()

    max.nn.kernels.kv_cache_ragged_radd(kv_params, a, kv_collection, input_row_offsets, batch_offset, layer_idx)

    source

    This function adds a tensor to a slice of the KVCache, sliced on the batch dimension.

    This expects that the requests to be sliced out are contiguous at the front of the tensor, and that the addition only targets the last requests in the batch.

    Parameters:

    Return type:

    None

    kv_cache_store_paged_padded()

    max.nn.kernels.kv_cache_store_paged_padded(kv_collection, x_cache, valid_lengths, layer_idx, *, key_or_value)

    source

    Stores key or value tensor into the paged KV cache (padded inputs).

    Parameters:

    Return type:

    None

    kv_cache_store_paged_ragged()

    max.nn.kernels.kv_cache_store_paged_ragged(kv_collection, x_cache, input_row_offsets, layer_idx, *, key_or_value)

    source

    Stores key or value tensor into the paged KV cache (ragged inputs).

    Parameters:

    Return type:

    None

    learnable_2d_interp_pos_emb()

    max.nn.kernels.learnable_2d_interp_pos_emb(x, weight, grid_thws, time_weight)

    source

    Applies learnable 2D interpolated position embedding (Kimi K2.5).

    For each video described by grid_thws, bicubic-interpolates weight from (H, W) to (h, w), optionally adds temporal sincos embedding when t > 1, and adds the result element-wise to x.

    Parameters:

    • x (TensorValue) – Patch embeddings of shape (L, dim).
    • weight (TensorValue) – Learnable 2D grid of shape (H, W, dim).
    • grid_thws (TensorValue) – Per-video (t, h, w) of shape (N, 3), dtype int64.
    • time_weight (TensorValue) – 1D sincos temporal embedding of shape (num_frames, dim), dtype float32.

    Returns:

    Tensor of shape (L, dim) with position embeddings added.

    Raises:

    ValueError – On invalid input shapes or dtypes.

    Return type:

    TensorValue

    lmcache_offload()

    max.nn.kernels.lmcache_offload(output, paged_cache, slot_mapping, start_token, end_token, page_size, num_kv_heads, head_dim, kv_dim, device_ref)

    source

    Offload KV cache data from paged format to external contiguous format.

    Used by LMCache connector to copy KV data from MAX’s paged cache layout to LMCache’s contiguous KV_2LTD format for external storage.

    Parameters:

    • output (BufferValue) – Output buffer [kv_dim, num_layers, num_tokens, hidden_dim] where hidden_dim = num_kv_heads * head_dim.
    • paged_cache (TensorValue) – Input tensor [total_num_blocks, kv_dim, num_layers, page_size, num_kv_heads, head_dim].
    • slot_mapping (TensorValue) – Token to slot mapping [total_tokens].
    • start_token (TensorValue) – Starting token index scalar [1].
    • end_token (TensorValue) – Ending token index (exclusive) scalar [1].
    • page_size (int) – Number of tokens per page in the paged cache.
    • num_kv_heads (int) – Number of KV attention heads.
    • head_dim (int) – Dimension of each attention head.
    • kv_dim (int) – KV dimension (2 for standard K/V, 1 for MLA).
    • device_ref (DeviceRef) – Device reference for the operation.

    Return type:

    None

    lmcache_onload()

    max.nn.kernels.lmcache_onload(paged_cache, input_tensor, slot_mapping, start_token, end_token, page_size, num_kv_heads, head_dim, kv_dim, device_ref)

    source

    Onload KV cache data from external contiguous format to paged format.

    Used by LMCache connector to copy KV data from LMCache’s contiguous KV_2LTD format into MAX’s paged cache layout.

    Parameters:

    • paged_cache (BufferValue) – Output buffer [total_num_blocks, kv_dim, num_layers, page_size, num_kv_heads, head_dim].
    • input_tensor (TensorValue) – Input tensor [kv_dim, num_layers, num_tokens, hidden_dim] where hidden_dim = num_kv_heads * head_dim.
    • slot_mapping (TensorValue) – Token to slot mapping [total_tokens].
    • start_token (TensorValue) – Starting token index scalar [1].
    • end_token (TensorValue) – Ending token index (exclusive) scalar [1].
    • page_size (int) – Number of tokens per page in the paged cache.
    • num_kv_heads (int) – Number of KV attention heads.
    • head_dim (int) – Dimension of each attention head.
    • kv_dim (int) – KV dimension (2 for standard K/V, 1 for MLA).
    • device_ref (DeviceRef) – Device reference for the operation.

    Return type:

    None

    masked_flash_attention_gpu()

    max.nn.kernels.masked_flash_attention_gpu(q, k, v, mask, scale)

    source

    Computes flash attention using a materialized additive mask.

    Parameters:

    • q (TensorValue) – Query tensor of shape [batch, q_seq_len, num_heads, head_dim]
    • k (TensorValue) – Key tensor of shape [batch, kv_seq_len, num_heads, head_dim]
    • v (TensorValue) – Value tensor of shape [batch, kv_seq_len, num_heads, head_dim]
    • mask (TensorValue) – Additive mask tensor. Rank 3 of shape [batch, q_seq_len, kv_seq_len] broadcasts across attention heads. Rank 4 of shape [batch, num_heads, q_seq_len, kv_seq_len] applies a per-head bias.
    • scale (float) – Scaling factor for attention scores.

    Returns:

    Output tensor of shape [batch, q_seq_len, num_heads, head_dim]

    Return type:

    TensorValue

    matmul_k_cache_ragged()

    max.nn.kernels.matmul_k_cache_ragged(kv_params, hidden_states, input_row_offsets, weight, kv_collection, layer_idx)

    source

    Computes key projections with ragged input.

    hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in hidden_states.

    Parameters:

    Return type:

    None

    matmul_k_cache_ragged_scaled_float8()

    max.nn.kernels.matmul_k_cache_ragged_scaled_float8(kv_params, hidden_states, input_row_offsets, weight, input_scale, weight_scale, kv_collection, scales_granularity_mnk, layer_idx)

    source

    Computes key projections with ragged input with FP8 block scaling.

    Parameters:

    • kv_params (KVCacheParams) – KVCacheParams object containing key-value cache parameters.
    • hidden_states (TensorValue) – TensorValue representing the input tensor with shape [M=total_seq_len, K=hidden_dim].
    • input_row_offsets (TensorValue) – TensorValue indicating the start and end of each batch in the input tensor with shape [batch_size + 1].
    • weight (TensorValue) – TensorValue representing the weight tensor with shape [N=num_heads, K=hidden_dim].
    • input_scale (TensorValue) – TensorValue representing the input scale tensor with shape [ceildiv(K / BLOCK_SIZE_K), ceildiv(M / BLOCK_SIZE_M)].
    • weight_scale (TensorValue) – TensorValue representing the weight scale tensor with shape [ceildiv(N / BLOCK_SIZE_N), ceildiv(K / BLOCK_SIZE_K)].
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – PagedCacheValues object for managing key-value cache.
    • scales_granularity_mnk (tuple[int, int, int]) – tuple[int, int, int] representing the scaling (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K).
    • layer_idx (TensorValue) – TensorValue representing the layer index, expected to have dtype uint32.

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    None

    matmul_kv_cache_ragged()

    max.nn.kernels.matmul_kv_cache_ragged(kv_params, hidden_states, input_row_offsets, weight, kv_collection, layer_idx)

    source

    Computes key and value projections with ragged input.

    hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in hidden_states.

    Parameters:

    Return type:

    None

    matmul_static_scaled_float8()

    max.nn.kernels.matmul_static_scaled_float8(input, weight, input_scale, weight_scale)

    source

    Performs a static-scaled float8 matrix multiplication.

    Computes input @ weight.T where both tensors are float8, dequantized using the provided per-tensor CPU scalar scales before accumulation. The output is always bfloat16.

    Parameters:

    • input (TensorValue) – Input tensor of rank 2 and dtype float8_e4m3fn or float8_e4m3fnuz.
    • weight (TensorValue) – Weight tensor of rank 2 and matching float8 dtype, laid out so that the K dimension matches input.shape[1].
    • input_scale (TensorValue) – Scalar scale factor for input (shape [] or [1]), must reside on CPU.
    • weight_scale (TensorValue) – Scalar scale factor for weight (shape [] or [1]), must reside on CPU.

    Returns:

    A TensorValue of shape [input.shape[0], weight.shape[0]] and dtype bfloat16.

    Raises:

    ValueError – If scale shapes are not scalar, input or weight are not rank 2, K dimensions do not match, or scales are not on CPU.

    Return type:

    TensorValue
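    The static-scaling semantics can be expressed directly: both operands are dequantized by their per-tensor scales before the accumulation. This is an illustrative reference using plain floats in place of real float8 values; the kernel itself fuses the scaling into the matmul:

```python
def static_scaled_matmul_reference(input_rows, weight_rows,
                                   input_scale, weight_scale):
    """Reference for input @ weight.T with per-tensor dequantization
    scales. weight_rows is stored as [N, K], matching the kernel's
    weight layout. Sketch only, not the fused kernel."""
    out = []
    for x in input_rows:
        row = []
        for w in weight_rows:
            # Dequantize each operand element by its tensor-wide scale.
            acc = sum((xi * input_scale) * (wi * weight_scale)
                      for xi, wi in zip(x, w))
            row.append(acc)
        out.append(row)
    return out
```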

    merge_ragged_tensors()

    max.nn.kernels.merge_ragged_tensors(a, a_row_offsets, b, b_row_offsets)

    source

    Merges two ragged tensors into a single ragged tensor.

    Both ragged tensors must have the same batch size (same number of row offsets). This function interleaves the rows from each tensor based on their row offsets.

    Parameters:

    • a (TensorValue) – The first ragged tensor of shape [total_a_rows, …].
    • a_row_offsets (TensorValue) – The row offsets of the first ragged tensor, indicating where each batch starts and ends in a.
    • b (TensorValue) – The second ragged tensor of shape [total_b_rows, …].
    • b_row_offsets (TensorValue) – The row offsets of the second ragged tensor, indicating where each batch starts and ends in b.

    Returns:

    • The merged ragged tensor with shape [total_a_rows + total_b_rows, …].
    • The merged row offsets with the same shape as input row offsets.

    Return type:

    tuple[TensorValue, TensorValue]

    a = [1, 2, 3, 4, 5, 6]
    a_row_offsets = [0, 2, 6]
    b = [7, 8, 9, 10]
    b_row_offsets = [0, 3, 4]
    
    merged_tensor, merged_row_offsets = merge_ragged_tensors(
        a, a_row_offsets, b, b_row_offsets)
    
    merged_tensor = [1, 2, 7, 8, 9, 3, 4, 5, 6, 10]
    merged_row_offsets = [0, 5, 10]
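    The interleaving above can be reproduced with a short host-side reference (a plain-Python sketch of the semantics, not the kernel):

```python
def merge_ragged_reference(a, a_row_offsets, b, b_row_offsets):
    """For each batch element, append a's row range followed by b's row
    range, and record the running offsets of the merged result."""
    merged, offsets = [], [0]
    for i in range(len(a_row_offsets) - 1):
        merged.extend(a[a_row_offsets[i]:a_row_offsets[i + 1]])
        merged.extend(b[b_row_offsets[i]:b_row_offsets[i + 1]])
        offsets.append(len(merged))
    return merged, offsets
```

    Applied to the example above, it reproduces the merged tensor and offsets shown.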

    mla_decode_graph()

    max.nn.kernels.mla_decode_graph(q, kv, input_row_offsets, freqs_cis, kv_norm_gamma, w_uk, w_uv, kv_params, kv_collection, layer_idx, mask_variant, scale, epsilon, v_head_dim, scalar_args, *, w_uk_scale=None, w_uv_scale=None, quant_config=None)

    source

    This is a manually fused kernel that performs the following operations:

    • Apply RoPE to the query and the key cache (in-place).
    • Apply RMSNorm to the non-rope portion of the key cache (in-place).
    • Project q_nope to kv_latent_dim through a fp8 batched matmul: q_nope_proj = q_nope_t @ w_uk
    • Concatenate q_nope_proj and q_rope: q_full = concat(q_nope_proj, q_rope, axis=2)
    • Perform MLA decode
    • Project raw_output to v_head_dim through another fp8 batched matmul: output = raw_output_t @ w_uv

    Parameters:

    • q (TensorValue) – Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
    • kv (TensorValue) – KV latent tensor from the first projection. Shape: [num_tokens, cache_head_dim] where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
    • input_row_offsets (TensorValue) – Indicates where each request starts and ends in input. This is a 1D tensor of shape [num_batches + 1].
    • freqs_cis (TensorValue) – Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
    • kv_norm_gamma (TensorValue) – RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
    • w_uk (TensorValue) – Weight matrix for projecting q_nope to kv_latent_dim. Shape: [num_heads, kv_latent_dim, qk_nope_head_dim].
    • w_uv (TensorValue) – Weight matrix for projecting MLA decode output to v_head_dim. Shape: [num_heads, v_head_dim, kv_latent_dim].
    • kv_params (KVCacheParams) – KVCacheParams
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV Cache object.
    • layer_idx (TensorValue) – Layer index.
    • mask_variant (MHAMaskVariant) – The attention mask variant controlling masking behavior.
    • scale (float) – Scale for the attention calculation.
    • epsilon (float) – Small constant for numerical stability in RMSNorm.
    • v_head_dim (int) – Dimension of the V heads.
    • scalar_args (TensorValue) – Pre-computed dispatch scalar args (GPU buffer) for CUDA graph capture.
    • w_uk_scale (TensorValue | None) – Optional FP8 scale tensor for w_uk.
    • w_uv_scale (TensorValue | None) – Optional FP8 scale tensor for w_uv.
    • quant_config (QuantConfig | None) – Optional quantization config. When set, scales are required.

    Returns:

    Tensor of shape [total_seq_len, num_heads, v_head_dim].

    Return type:

    TensorValue

    mla_fp8_index_top_k()

    max.nn.kernels.mla_fp8_index_top_k(q, q_s, input_row_offsets, k_collection, layer_idx, top_k, quantization_granularity, mask_variant=MHAMaskVariant.CAUSAL_MASK)

    source

    Computes top-k indices for MLA FP8 indexed attention scores.

    This function computes FP8 matmul between queries and cached keys (with scales), applies masking, and returns the indices of the top-k highest-scoring keys per token. Scores are aggregated (summed) across all attention heads.

    Parameters:

    • q (TensorValue) – Query tensor of shape [total_seq_len, num_heads, head_dim] in FP8.
    • q_s (TensorValue) – Query scales tensor of shape [total_seq_len, num_heads] in float32.
    • input_row_offsets (TensorValue) – Input row offsets tensor of shape [batch_size + 1].
    • k_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection. Must be FP8 quantized with scales.
    • layer_idx (TensorValue) – Layer index for cache lookup.
    • top_k (int) – Requested number of top indices per token.
    • quantization_granularity (int) – Quantization granularity for the K cache.
    • mask_variant (MHAMaskVariant) – The mask variant to use (NULL or CAUSAL_MASK).

    Returns:

    Output tensor of shape [total_seq_len, effective_k] containing top-k key indices per token, where effective_k = min(top_k, max_num_keys). Invalid positions are filled with -1.

    Return type:

    TensorValue

    mla_prefill_decode_graph()

    max.nn.kernels.mla_prefill_decode_graph(q, kv, input_row_offsets, freqs_cis, kv_norm_gamma, buffer_row_offsets, cache_offsets, buffer_length, w_k, w_uk, w_uv, kv_params, kv_collection, layer_idx, mask_variant, scale, epsilon, v_head_dim, scalar_args, *, w_k_scale=None, w_uk_scale=None, w_uv_scale=None, quant_config=None)

    source

    Fused MLA prefill/decode kernel for FP8.

    Switches between prefill and decode based on the maximum sequence length in the batch. See mla_prefill_graph and mla_decode_graph for the dedicated paths.

    Parameters:

    • q (TensorValue) – Combined query tensor with nope+rope parts.
    • kv (TensorValue) – KV latent tensor for current sequence.
    • input_row_offsets (TensorValue) – Row offsets for the batch.
    • freqs_cis (TensorValue) – RoPE frequencies tensor.
    • kv_norm_gamma (TensorValue) – RMSNorm gamma for KV cache.
    • buffer_row_offsets (TensorValue) – One-shot prefill buffer row offsets.
    • cache_offsets (TensorValue) – One-shot prefill cache offsets.
    • buffer_length (TensorValue) – One-shot prefill buffer length tensor.
    • w_k (TensorValue) – Prefill K up-projection weights.
    • w_uk (TensorValue) – Decode query-projection weights.
    • w_uv (TensorValue) – Decode output-projection / prefill V-projection weights.
    • kv_params (KVCacheParams) – KV cache parameters.
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache values.
    • layer_idx (TensorValue) – Layer index (uint32).
    • mask_variant (MHAMaskVariant) – Attention mask variant.
    • scale (float) – Attention scale.
    • epsilon (float) – RMSNorm epsilon.
    • v_head_dim (int) – Value head dimension for output tensor shape.
    • scalar_args (TensorValue) – Pre-computed dispatch scalar args (GPU buffer) for CUDA graph capture.
    • w_k_scale (TensorValue | None) – Optional FP8 scale tensor for w_k.
    • w_uk_scale (TensorValue | None) – Optional FP8 scale tensor for w_uk.
    • w_uv_scale (TensorValue | None) – Optional FP8 scale tensor for w_uv.
    • quant_config (QuantConfig | None) – Optional quantization config. When set, scales are required.

    Returns:

    Tensor of shape [total_seq_len, num_heads, v_head_dim].

    Return type:

    TensorValue

    mla_prefill_graph()

    max.nn.kernels.mla_prefill_graph(q, kv, input_row_offsets, freqs_cis, kv_norm_gamma, buffer_row_offsets, cache_offsets, buffer_length, w_k, w_uv, kv_params, kv_collection, layer_idx, mask_variant, scale, epsilon, v_head_dim, *, w_k_scale=None, w_uv_scale=None, quant_config=None)

    source

    This is a manually fused kernel that performs the following operations:

    • Apply RoPE to the query and the key cache (in-place).
    • Apply RMSNorm to the non-rope portion of the key cache (in-place).
    • Copy the KV latent values from PagedKVCache to a contiguous buffer.
    • Quantize the KV latent values to fp8.
    • Up-project the latent KV values to full K and V through two matmuls.
    • Perform MLA prefill.

    Parameters:

    • q (TensorValue) – Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
    • kv (TensorValue) – KV latent tensor from the first projection. Shape: [num_tokens, cache_head_dim] where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
    • input_row_offsets (TensorValue) – Indicates where each request starts and ends in input. This is a 1D tensor of shape [num_batches + 1].
    • freqs_cis (TensorValue) – Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
    • kv_norm_gamma (TensorValue) – RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
    • buffer_row_offsets (TensorValue) – Indicates where each request’s KV latent values should be stored in the contiguous buffer. This is a 1D tensor of shape [num_batches + 1].
    • cache_offsets (TensorValue) – Indicates the starting token position in the KV cache from which to copy KV latent values for each request. This is a 1D tensor of shape [num_batches + 1].
    • buffer_length (TensorValue) – The total number of tokens in the KV cache. Scalar.
    • w_k (TensorValue) – Weight matrix for up-projecting latent KV values to full K. Shape: [num_heads * qk_nope_head_dim, kv_latent_dim].
    • w_uv (TensorValue) – Weight tensor for up-projecting latent KV values to full V. Shape: [num_heads, v_head_dim, kv_latent_dim].
    • kv_params (KVCacheParams) – KVCacheParams
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV Cache object.
    • layer_idx (TensorValue) – Layer index.
    • mask_variant (MHAMaskVariant) – The attention mask variant controlling masking behavior.
    • scale (float) – Scale for the attention calculation.
    • epsilon (float) – Small constant for numerical stability in RMSNorm.
    • v_head_dim (int) – Dimension of the V heads.
    • w_k_scale (TensorValue | None) – Optional FP8 scale tensor for w_k.
    • w_uv_scale (TensorValue | None) – Optional FP8 scale tensor for w_uv.
    • quant_config (QuantConfig | None) – Optional quantization config. When set, scales are required.

    Returns:

    Tensor of shape [total_seq_len, num_heads, v_head_dim].

    Return type:

    TensorValue

    moe_create_indices()

    max.nn.kernels.moe_create_indices(topk_ids, num_local_experts, *, needs_scales_offset=False, scales_alignment=128)

    source

    Creates indices for the MoE layer.

    Parameters:

    • topk_ids (TensorValue) – The expert assignments for each token from the router.
    • num_local_experts (int) – The number of experts on this device.
    • needs_scales_offset (bool)
    • scales_alignment (int)

    Returns:

    • token_expert_order: The reordered token indices, grouped by assigned expert.
    • expert_start_indices: The starting index for each expert’s token group in the reordered sequence.
    • restore_token_order: The indices to restore original token ordering after expert computation.
    • expert_ids: The IDs of the active experts selected for the tokens in the batch.
    • expert_usage_stats: The maximum number of tokens assigned to any expert, and the number of active experts.

    Return type:

    A tuple of five tensors
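    One way to picture the five outputs is the following NumPy sketch (illustrative only; moe_create_indices_ref is a hypothetical name, and the kernel's exact ordering and stats layout may differ):

```python
import numpy as np

def moe_create_indices_ref(topk_ids, num_local_experts):
    """Sketch of the MoE index bookkeeping. topk_ids: [num_tokens, top_k]."""
    flat = topk_ids.reshape(-1)
    # Stable sort groups token slots by their assigned expert.
    token_expert_order = np.argsort(flat, kind="stable")
    sorted_ids = flat[token_expert_order]
    # Start index of each expert's contiguous token group.
    counts = np.bincount(sorted_ids, minlength=num_local_experts)
    expert_start_indices = np.concatenate([[0], np.cumsum(counts)])
    # Inverse permutation restores the original token ordering.
    restore_token_order = np.argsort(token_expert_order, kind="stable")
    expert_ids = np.unique(sorted_ids)
    expert_usage_stats = (counts.max(), len(expert_ids))
    return (token_expert_order, expert_start_indices,
            restore_token_order, expert_ids, expert_usage_stats)
```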

    moe_router_group_limited()

    max.nn.kernels.moe_router_group_limited(expert_scores, expert_bias, n_routed_experts, n_experts_per_tok, n_groups, topk_group, norm_weights, routed_scaling_factor)

    source

    Group limited MoE router. When n_groups > 1, selects up to topk_group expert groups, then picks n_experts_per_tok experts within those groups (DeepSeek-V3 style). When n_groups == 1, there is only one group, so group selection is skipped and routing uses the dedicated GPU single-group path (mo.moe.single.group.router, implemented as single_group_router in Mojo). In that case topk_group is not used by the kernel.

    Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/9b4e9788e4a3a731f7567338ed15d3ec549ce03b/inference/model.py#L566.

    Parameters:

    • expert_scores (TensorValue) – The scores for each expert for each token. Shape: [num_tokens, n_routed_experts].
    • expert_bias (TensorValue) – The bias for each expert. Shape: [n_routed_experts].
    • n_routed_experts (int) – The total number of experts. Must be divisible by n_groups.
    • n_experts_per_tok (int) – The number of experts to be selected per token.
    • n_groups (int) – The total number of expert groups. Must evenly divide n_routed_experts.
    • topk_group (int) – The maximum number of expert groups that a token will be routed to.
    • norm_weights (bool) – Whether to normalize the selected expert weights when n_groups > 1. When n_groups == 1, normalization is currently always enabled (norm_weights is treated as True), matching the previous graph path, which always divided the weights by their per-token sum.
    • routed_scaling_factor (float)

    Returns:

    • expert_indices: The indices of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
    • expert_weights: The weights of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].

    Return type:

    A tuple of two tensors
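    A simplified NumPy sketch of group-limited routing follows. It is illustrative only: group_limited_router_ref is a hypothetical name, group ranking here uses each group's maximum biased score (DeepSeek-V3 uses a richer group score), and weights are always normalized:

```python
import numpy as np

def group_limited_router_ref(expert_scores, expert_bias, n_groups,
                             topk_group, n_experts_per_tok,
                             routed_scaling_factor=1.0):
    num_tokens, n_routed = expert_scores.shape
    biased = expert_scores + expert_bias       # selection uses biased scores
    grouped = biased.reshape(num_tokens, n_groups, n_routed // n_groups)
    group_scores = grouped.max(axis=-1)        # simplified per-group score
    # Keep only the topk_group best groups for each token.
    keep = np.argsort(-group_scores, axis=-1)[:, :topk_group]
    allowed = np.zeros((num_tokens, n_groups), dtype=bool)
    np.put_along_axis(allowed, keep, True, axis=-1)
    allowed = np.repeat(allowed, n_routed // n_groups, axis=-1)
    masked = np.where(allowed, biased, -np.inf)
    # Pick experts inside the surviving groups; weights use unbiased scores.
    expert_indices = np.argsort(-masked, axis=-1)[:, :n_experts_per_tok]
    expert_weights = np.take_along_axis(expert_scores, expert_indices, axis=-1)
    expert_weights = expert_weights / expert_weights.sum(-1, keepdims=True)
    return expert_indices, expert_weights * routed_scaling_factor
```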

    mxfp4_dequant()

    max.nn.kernels.mxfp4_dequant(packed_weights, scales, out_type=bfloat16)

    source

    Dequantizes MXFP4 packed weights to BF16 or FP8 on GPU.

    Supports rank 2 [N, K//2] and rank 3 [E, N, K//2] inputs. For rank 3, leading dims are flattened to 2D, dequantized, and reshaped back.

    Parameters:

    • packed_weights (TensorValue) – Packed weights in uint8 (2 FP4 values per byte). Shape [N, K//2] or [E, N, K//2].
    • scales (TensorValue) – Block scales in float8_e8m0fnu. Shape [N, K//32] or [E, N, K//32].
    • out_type (DType) – Output dtype (bfloat16 or float8_e4m3fn).

    Returns:

    Dequantized tensor [N, K] or [E, N, K] in out_type.

    Return type:

    TensorValue
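    The decode step can be sketched in NumPy for a rank-2 input (illustrative only: mxfp4_dequant_ref is a hypothetical name, the nibble packing order is an assumption, and scales are given here as plain floats rather than float8_e8m0fnu):

```python
import numpy as np

# FP4 E2M1 decode table: codes 0-7 are positive, 8-15 are the negated values.
_E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
_LUT = np.concatenate([_E2M1, -_E2M1])

def mxfp4_dequant_ref(packed, scales, block=32):
    """Sketch for a [N, K//2] uint8 input. Assumes the low nibble of each
    byte holds the even-indexed element; `scales` holds one float per
    `block` output values."""
    lo = _LUT[packed & 0x0F]
    hi = _LUT[packed >> 4]
    vals = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)  # [N, K]
    # Broadcast one scale over each contiguous block of `block` values.
    return vals * np.repeat(scales, block, axis=-1)
```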

    needs_fp8_fnuz_conversion()

    max.nn.kernels.needs_fp8_fnuz_conversion()

    source

    Checks if FP8 E4M3FN to FNUZ conversion is needed for AMD GPUs.

    Returns:

    True if running on AMD GPU with CDNA3 architecture, False otherwise.

    Return type:

    bool

    normalize_e4m3fn_to_e4m3fnuz()

    max.nn.kernels.normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale)

    source

    Converts E4M3FN weights to E4M3FNUZ format for AMD GPUs.

    This conversion is necessary because AMD GPUs use the E4M3FNUZ format while NVIDIA GPUs use E4M3FN. The key differences are:

    1. The bit pattern 10000000 (-128) represents zero in E4M3FN but NaN in E4M3FNUZ
    2. For the same bit representation, E4M3FNUZ values are half of E4M3FN values

    Parameters:

    • weight (TensorValue) – The weight tensor in E4M3FN format.
    • weight_scale (TensorValue) – The weight scale factor.

    Returns:

    Tuple of (converted_weight, adjusted_weight_scale).

    Return type:

    tuple[TensorValue, TensorValue]
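    The two differences above suggest the following NumPy sketch on raw uint8 bit patterns (illustrative only; normalize_e4m3fn_to_e4m3fnuz_ref is a hypothetical name, and the real op works on fp8 tensors rather than bit arrays):

```python
import numpy as np

def normalize_e4m3fn_to_e4m3fnuz_ref(weight_bits, weight_scale):
    """E4M3FN's negative-zero pattern 0x80 is NaN in E4M3FNUZ, so it is
    remapped to +0. Every E4M3FNUZ bit pattern decodes to half the E4M3FN
    value, so the scale is doubled to keep the dequantized product equal."""
    out = weight_bits.copy()
    out[out == 0x80] = 0x00          # -0.0 (FN) would decode as NaN (FNUZ)
    return out, weight_scale * 2.0   # halved values * doubled scale
```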

    quantize_dynamic_block_scaled_fp4()

    max.nn.kernels.quantize_dynamic_block_scaled_fp4(input, tensor_sf, sf_vector_size=16, scales_type=float8_e4m3fn, out_type=uint8)

    source

    Dynamically quantize the input tensor to fp4-e2m1fn.

    Parameters:

    • input (TensorValue) – The input tensor to quantize. Shape: [seq_len, hidden_size]
    • tensor_sf (TensorValue | float) – The tensor-wise scale factor (inverted as per quantization kernel requirement).
    • sf_vector_size (int) – The block size for the scaling factors. 16 for NVFP4, 32 for MXFP4.
    • out_type (DType) – The type of the output tensor.
    • scales_type (DType) – The type of the scales tensor. float8_e4m3fn for NVFP4, float8_e8m0fnu for MXFP4.

    Returns:

    The quantized tensor and the scales. The scales layout depends on hardware: rank-5 interleaved on NVIDIA SM100, rank-2 [M, K // sf_vector_size] otherwise.

    Return type:

    A tuple of two tensors

    quantize_dynamic_block_scaled_mxfp4()

    max.nn.kernels.quantize_dynamic_block_scaled_mxfp4(input, scales_type=float8_e8m0fnu, out_type=uint8)

    source

    Dynamically quantize the input tensor to fp4-e2m1fn.

    Parameters:

    • input (TensorValue) – The input tensor to quantize. Shape: [seq_len, hidden_size]
    • out_type (DType) – The type of the output tensor.
    • scales_type (DType) – The type of the scales tensor.

    Returns:

    The quantized tensor in [seq_len, hidden_size // 2] layout and the scales in [seq_len, hidden_size // 32] layout.

    Return type:

    tuple[TensorValue, TensorValue]

    quantize_dynamic_scaled_float8()

    max.nn.kernels.quantize_dynamic_scaled_float8(input, input_scale_spec, weight_scale_spec, scale_ub=1200.0, group_size_or_per_token=-1, out_type=float8_e4m3fn, scales_type=bfloat16)

    source

    Dynamically quantize the input tensor to fp8.

    Parameters:

    • input (TensorValue) – The input tensor to quantize.
    • scale_ub (float) – The upper bound of the scale factor.
    • group_size_or_per_token (int) – The group size for quantization. When set to -1, the quantization is column-wise.
    • out_type (DType) – The type of the output tensor.
    • scales_type (DType) – The type of the scales tensor.
    • input_scale_spec (InputScaleSpec)
    • weight_scale_spec (WeightScaleSpec)

    Returns:

    The quantized tensor and the scales.

    Return type:

    tuple[TensorValue, TensorValue]

    quantize_static_scaled_float8()

    max.nn.kernels.quantize_static_scaled_float8(x, scale, scale_is_inverted=True, out_type=float8_e4m3fn)

    source

    Quantizes a rank-2 tensor to float8 using a static per-tensor scale.

    Parameters:

    • x (TensorValue) – Input tensor to quantize. Must be rank 2 with dtype float16, bfloat16, or float32.
    • scale (TensorValue) – Scalar scale factor (shape [] or [1]) residing on CPU.
    • scale_is_inverted (bool) – When True (default), scale is interpreted as 1 / max_val (inverted). When False, it is the raw absolute-max scale.
    • out_type (DType) – Output dtype. Defaults to DType.float8_e4m3fn.

    Returns:

    A quantized TensorValue with shape equal to x and dtype out_type.

    Raises:

    ValueError – If scale is not a scalar, x is not rank 2, x dtype is unsupported, or scale is not on CPU.

    Return type:

    TensorValue
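    The scale handling can be sketched in NumPy (illustrative only; quantize_static_scaled_float8_ref is a hypothetical name, and since NumPy has no fp8 dtype the sketch returns float values clamped to the e4m3fn range instead of casting):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_static_scaled_float8_ref(x, scale, scale_is_inverted=True):
    """Static per-tensor fp8 quantization sketch: scale, then clamp to the
    representable range. The real kernel casts the result to fp8."""
    scaled = x * scale if scale_is_inverted else x / scale
    return np.clip(scaled, -FP8_E4M3_MAX, FP8_E4M3_MAX)
```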

    quantize_tensor_dynamic_scaled_float8()

    max.nn.kernels.quantize_tensor_dynamic_scaled_float8(input, input_scale_spec, weight_scale_spec, scale_ub=1200.0, group_size_or_per_token=-1, out_type=float8_e4m3fn, scales_type=bfloat16)

    source

    Quantizes a rank-2 tensor to float8 using a dynamic per-tensor scale.

    Parameters:

    • input (TensorValue) – The input tensor to quantize.
    • scale_ub (float) – The upper bound of the scale factor.
    • group_size_or_per_token (int) – The group size for quantization. When set to -1, the quantization is column-wise.
    • out_type (DType) – The type of the output tensor.
    • scales_type (DType) – The type of the scales tensor.
    • input_scale_spec (InputScaleSpec)
    • weight_scale_spec (WeightScaleSpec)

    Returns:

    The quantized tensor and the scales.

    Return type:

    tuple[TensorValue, TensorValue]

    repack_gguf_quantized_weights()

    max.nn.kernels.repack_gguf_quantized_weights(weight, quantization_encoding)

    source

    Repacks GGUF quantized weights for the given encoding.

    Parameters:

    Return type:

    TensorValue

    rms_norm_key_cache()

    max.nn.kernels.rms_norm_key_cache(kv_params, kv_collection, gamma, epsilon, layer_idx, total_seq_len, input_row_offsets, weight_offset, rms_norm_cols=None, multiply_before_cast=True, per_head_norm=True)

    source

    This function applies RMSNorm to the _new_ entries in the KVCache.

    When per_head_norm=True (default), RMSNorm is applied separately to each head. In this mode, gamma should have size [head_dim] and normalization occurs across the head_dim dimensions within each head.

    When per_head_norm=False, RMSNorm is applied per token across all heads. In this mode, gamma should have size [n_kv_heads * head_dim] and normalization occurs across all dimensions for each token.

    The size of the gamma tensor determines how many dimensions will be normalized. If gamma’s size doesn’t match the expected size based on per_head_norm setting, rms_norm_cols must be explicitly specified to confirm the intention to normalize only a subset of dimensions.

    Currently, the KVCacheT class itself is not aware of the new cache entries until the cache length is incremented, which happens after the model forward pass, so input_row_offsets is used for this bookkeeping.

    Parameters:

    Return type:

    None
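    The two normalization modes can be sketched in NumPy (illustrative only; rms_norm_ref is a hypothetical name operating on a dense [tokens, n_kv_heads, head_dim] array rather than the cache):

```python
import numpy as np

def rms_norm_ref(x, gamma, epsilon, per_head_norm=True):
    """Sketch of per-head vs. per-token RMSNorm over [tokens, heads, dim]."""
    if per_head_norm:
        # gamma: [head_dim]; normalize within each head independently.
        ms = np.mean(x * x, axis=-1, keepdims=True)
        return x / np.sqrt(ms + epsilon) * gamma
    # gamma: [n_kv_heads * head_dim]; normalize across all dims per token.
    t, h, d = x.shape
    flat = x.reshape(t, h * d)
    ms = np.mean(flat * flat, axis=-1, keepdims=True)
    return (flat / np.sqrt(ms + epsilon) * gamma).reshape(t, h, d)
```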

    rms_norm_value_cache()

    max.nn.kernels.rms_norm_value_cache(kv_params, kv_collection, gamma, epsilon, layer_idx, total_seq_len, input_row_offsets, weight_offset, rms_norm_cols=None, multiply_before_cast=True, per_head_norm=True)

    source

    Applies RMSNorm in place to the _new_ entries in the value cache. Semantics match rms_norm_key_cache(), but updates the value tensor for the layer instead of the key tensor.

    Parameters:

    Return type:

    None

    rope_ragged()

    max.nn.kernels.rope_ragged(input, input_row_offsets, start_pos, freqs_cis, *, interleaved=True)

    source

    Applies RoPE to ragged input using the standard rope kernel.

    Parameters:

    Return type:

    TensorValue
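    The interleaved RoPE rotation at the core of this op can be sketched for a single head in NumPy (illustrative only; rope_interleaved_ref is a hypothetical name, and freqs_cis is given here as complex rotation factors):

```python
import numpy as np

def rope_interleaved_ref(x, freqs_cis):
    """x: [seq, head_dim]; freqs_cis: [seq, head_dim // 2] complex factors.
    Adjacent (even, odd) pairs of x are rotated as complex numbers."""
    xc = x[:, 0::2] + 1j * x[:, 1::2]
    rotated = xc * freqs_cis
    out = np.empty_like(x)
    out[:, 0::2] = rotated.real
    out[:, 1::2] = rotated.imag
    return out
```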

    rope_ragged_with_position_ids()

    max.nn.kernels.rope_ragged_with_position_ids(input, freqs_cis, position_ids, *, mrope_section=None, interleaved=True)

    source

    Applies RoPE using explicit position_ids (no KV cache coupling).

    Parameters:

    Return type:

    TensorValue

    rope_split_store_ragged()

    max.nn.kernels.rope_split_store_ragged(kv_params, qkv, input_row_offsets, freqs_cis, kv_collection, layer_idx, n_heads, interleaved=True, position_ids=None, mrope_section=None, fuse=True)

    source

    Apply rope to Q and K from flat QKV buffer, store K/V to cache.

    Reads from a flat QKV matmul output, applies RoPE to Q and K regions, stores K/V to the paged KV cache, and writes roped Q to the output.

    Parameters:

    • kv_params (KVCacheParams) – KV cache parameters.
    • qkv (TensorValue) – Flat QKV matmul output [total_seq_len, q_dim + k_dim + v_dim].
    • input_row_offsets (TensorValue) – Ragged offsets [batch_size + 1].
    • freqs_cis (TensorValue) – RoPE frequencies [max_seq_len, head_dim].
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache.
    • layer_idx (TensorValue) – Layer index.
    • n_heads (int) – Number of query attention heads.
    • interleaved (bool) – Whether freqs_cis uses interleaved (re, im) format.
    • position_ids (TensorValue | None) – Optional ragged 2D array of position IDs. If None, defaults to cache_length + token_idx for each token. When num_sections > 1, mrope_section must be provided. Shape: [num_sections, total_seq_len].
    • mrope_section (list[int] | None) – Optional list of ints indicating the section of the head_dim to apply RoPE to. Must be used with position_ids.
    • fuse (bool) – If True (default), emit a single fused custom op. If False, emit separate split, rope, and store ops for testing graph compiler fusion.

    Returns:

    Roped Q output [total_seq_len, n_heads * head_dim].

    Return type:

    TensorValue

    scatter_nd_skip_oob_indices()

    max.nn.kernels.scatter_nd_skip_oob_indices(input, updates, indices)

    source

    Creates a new symbolic tensor where the updates are scattered into input at specified indices.

    This differs from scatter_nd in that out-of-bounds (OOB) indices are handled by skipping the corresponding update. OOB indices are those that fall outside the range [-dim, dim).

    Parameters:

    Returns:

    A new symbolic tensor representing the result of the scatter_nd operation.

    Return type:

    TensorValue
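    The skip-on-out-of-bounds behavior can be sketched in NumPy for scalar indices into the first dimension (illustrative only; scatter_nd_skip_oob_ref is a hypothetical name, and the real op supports general index tuples):

```python
import numpy as np

def scatter_nd_skip_oob_ref(input, updates, indices):
    """Like scatter_nd, but any index outside [-dim, dim) is silently
    skipped instead of raising or wrapping."""
    out = input.copy()
    dim = input.shape[0]
    for idx, upd in zip(indices, updates):
        if -dim <= idx < dim:
            out[idx] = upd
    return out
```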

    scatter_set_constant()

    max.nn.kernels.scatter_set_constant(data, indices, fill_val)

    source

    Scatters values into a tensor at specified indices.

    Parameters:

    Return type:

    None

    sgmv_kernel()

    max.nn.kernels.sgmv_kernel(input, lora, lora_ids, lora_ranks, input_row_offsets, max_lora_seq_len, lora_end_idx=None, bias=None)

    source

    Performs the SGMV kernel for LoRA. The kernel is LoRA-matrix agnostic: the same call can apply either the LoRA A or the LoRA B matrix.

    Parameters:

    • input (TensorValue) – The input tensor.
    • lora (TensorValue) – The LoRA tensor.
    • lora_ids (TensorValue) – Ids of the LoRAs used for each sequence
    • lora_ranks (TensorValue) – The ranks of the LoRAs in the batch.
    • input_row_offsets (TensorValue) – The sequence offsets that use LoRA
    • max_lora_seq_len (int) – The maximum sequence length of any given LoRA in the batch
    • bias (TensorValue | None) – The LoRA bias
    • lora_end_idx (TensorValue | None)

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    sgmv_lora_kernel()

    max.nn.kernels.sgmv_lora_kernel(input, lora_a, lora_b, lora_ids, lora_ranks, grouped_row_offsets, lora_end_idx, max_lora_seq_len, bias=None)

    source

    Computes the SGMV LoRA kernel for some number of LoRAs A and B given the input.

    out = Wx + xAB

    SGMV can be explained as two independent kernels:

    • shrink -> projects the high-dimensional input down to a low-rank tensor
    • expand -> projects the low-rank tensor back up to a high-dimensional tensor

    With v a zero-initialized low-rank intermediate and y the output tensor:

    SGMV-shrink: v += xA
    SGMV-expand: y += vB

    Parameters:

    • input (TensorValue) – The input tensor
    • lora_a (TensorValue) – The LoRA tensor for A
    • lora_b (TensorValue) – The LoRA tensor for B
    • lora_ids (TensorValue) – Ids of the LoRAs used for each sequence
    • lora_ranks (TensorValue) – The ranks of the LoRAs in the batch.
    • grouped_row_offsets (TensorValue) – The grouped sequence offsets that use LoRA
    • max_lora_seq_len (int) – The maximum sequence length of any given LoRA in the batch
    • bias (TensorValue | None) – The LoRA bias
    • lora_end_idx (TensorValue)

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    TensorValue
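    The segmented shrink + expand passes can be sketched in NumPy (illustrative only; sgmv_lora_ref is a hypothetical name, and it returns only the xAB term since the base Wx matmul is computed separately):

```python
import numpy as np

def sgmv_lora_ref(x, lora_a, lora_b, lora_ids, row_offsets):
    """lora_a: [G, rank, K], lora_b: [G, N, rank]; segment i of x (rows
    row_offsets[i]:row_offsets[i+1]) uses adapter lora_ids[i]."""
    out = np.zeros((x.shape[0], lora_b.shape[1]), dtype=x.dtype)
    for i, g in enumerate(lora_ids):
        s, e = row_offsets[i], row_offsets[i + 1]
        v = x[s:e] @ lora_a[g].T          # shrink: [seg, K] -> [seg, rank]
        out[s:e] = v @ lora_b[g].T        # expand: [seg, rank] -> [seg, N]
    return out
```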

    sgmv_lora_qkv_shrink()

    max.nn.kernels.sgmv_lora_qkv_shrink(input, lora_a, lora_ids, lora_grouped_offsets, lora_end_idx, max_lora_seq_len, max_rank)

    source

    LoRA shrink grouped matmul with planar Q/K/V output.

    Performs the LoRA ‘shrink’ operation for routed tokens using SGMV (segmented grouped matrix-vector multiplication). Computes [M, K] @ [G, 3*rank, K]^T per active LoRA adapter, then permutes the flat [M, 3*rank] result into a planar layout [3, M, rank] representing separate Q, K, V projections.

    Parameters:

    • input (TensorValue) – Routed activation matrix with shape (M, K), where M is the total number of tokens and K is the hidden dimension.
    • lora_a (TensorValue) – Shrink weights for all LoRA adapters, shape (G, 3*rank, K) where G is the number of adapters and rank is the LoRA rank.
    • lora_ids (TensorValue) – Expert/adapter indices for each active group, shape (num_active,). Values in range [0, G). May use -1 to indicate inactive slots.
    • lora_grouped_offsets (TensorValue) – Inclusive prefix sums of tokens per active adapter, shape (num_active + 1,). Defines per-adapter [start, end) ranges in input. Must be non-decreasing with offsets[0] == 0.
    • max_lora_seq_len (int) – Upper bound on tokens for any active adapter. Used for kernel tuning and memory allocation.
    • max_rank (int) – The maximum LoRA rank, determines output shape.
    • lora_end_idx (TensorValue)

    Returns:

    Output tensor with planar Q/K/V layout, shape (3, M, max_rank).

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    TensorValue

    sgmv_qkv_lora_kernel()

    max.nn.kernels.sgmv_qkv_lora_kernel(input, lora_a, lora_b_q, lora_b_kv, lora_ids, lora_ranks, input_row_offsets, lora_grouped_offsets, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv, kv_collection, kv_params, layer_idx, max_lora_seq_len, max_rank, bias=None)

    source

    Computes the SGMV QKV LoRA kernel for Q, K, V projections with LoRA.

    Parameters:

    • input (TensorValue) – The input tensor.
    • lora_a (TensorValue) – The LoRA A tensor.
    • lora_b_q (TensorValue) – The LoRA B tensor for Q projection.
    • lora_b_kv (TensorValue) – The LoRA B tensor for K and V projections (stacked).
    • lora_ids (TensorValue) – IDs of the LoRAs used for each sequence.
    • lora_ranks (TensorValue) – The ranks of the LoRAs in the batch.
    • input_row_offsets (TensorValue) – The sequence offsets that use LoRA.
    • lora_grouped_offsets (TensorValue) – Grouped offsets for LoRA sequences.
    • lora_end_idx (TensorValue) – End index of LoRA tokens in the batch.
    • batch_seq_len (TensorValue) – Total sequence length of the batch.
    • lora_ids_kv (TensorValue) – LoRA IDs for KV projections (with offset for V portion).
    • lora_grouped_offsets_kv (TensorValue) – Grouped offsets for KV LoRA sequences.
    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The KV cache.
    • kv_params (KVCacheParams) – The key-value cache configuration parameters.
    • layer_idx (TensorValue) – The layer index to retrieve the KV cache.
    • max_lora_seq_len (int) – The maximum sequence length of any given LoRA in the batch.
    • max_rank (int) – The maximum rank for the LoRAs.
    • bias (TensorValue | None) – Optional LoRA bias.

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    TensorValue

    sleep()

    max.nn.kernels.sleep(duration_sec, device_ref)

    source

    Sleep for the given duration in seconds.

    This kernel is supported on CPUs and GPUs. However, timing may be highly inaccurate on AMD GPUs due to limitations of the current time.sleep(…) implementation.

    Parameters:

    Return type:

    None

    sliced_add()

    max.nn.kernels.sliced_add(x, y, lora_end_idx)

    source

    Adds tensors x and y element-wise for rows < lora_end_idx, otherwise copies x.

    This is used for LoRA where only some sequences have LoRA applied:

    • For rows in [0, lora_end_idx): c = x + y
    • For rows in [lora_end_idx, batch_seq_len): c = x

    Parameters:

    • x (TensorValue) – First input tensor.
    • y (TensorValue) – Second input tensor.
    • lora_end_idx (TensorValue) – End index of LoRA token portion (rows to apply add).

    Return type:

    TensorValue
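    The semantics reduce to a one-line NumPy sketch (illustrative only; sliced_add_ref is a hypothetical name):

```python
import numpy as np

def sliced_add_ref(x, y, lora_end_idx):
    """Add y only for rows before lora_end_idx; copy x for the rest."""
    out = x.copy()
    out[:lora_end_idx] += y[:lora_end_idx]
    return out
```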

    spatial_merge()

    max.nn.kernels.spatial_merge(input, grid_thw, hidden_size, merge_size)

    source

    Performs spatial merge operation on ragged input tensors.

    This operation merges spatial dimensions of input patches according to the grid dimensions specified in grid_thw.

    Parameters:

    • input (TensorValue) – Input tensor of shape [total_patches_in_grid, hidden_size]
    • grid_thw (TensorValue) – Grid dimensions tensor of shape [batch_size, 3] containing [t, h, w] for each batch item, where:
      • t: temporal/frame dimension
      • h: height dimension
      • w: width dimension
    • hidden_size (int) – Hidden dimension size
    • merge_size (int) – Size of spatial merge blocks (typically 2)

    Returns:

    Output tensor of shape [total_patches_in_grid, hidden_size]

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Return type:

    TensorValue

    store_k_cache_padded()

    max.nn.kernels.store_k_cache_padded(kv_collection, x_k, valid_lengths, layer_idx)

    source

    Stores the key tensor into the paged KV cache for padded inputs.

    Parameters:

    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
    • x_k (TensorValue) – The key tensor of rank 4 containing the new key projections.
    • valid_lengths (TensorValue) – Buffer of shape [batch] (dtype uint32) indicating the actual (non-padded) sequence length for each batch element.
    • layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.

    Return type:

    None

    store_k_cache_ragged()

    max.nn.kernels.store_k_cache_ragged(kv_collection, x_k, input_row_offsets, layer_idx)

    source

    Stores the key tensor into the paged KV cache for ragged inputs.

    Parameters:

    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
    • x_k (TensorValue) – The key tensor of rank 3 containing the new key projections.
    • input_row_offsets (TensorValue) – Ragged tensor row offsets of shape [batch + 1] indicating where each sequence starts and ends. Must have dtype uint32.
    • layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.

    Return type:

    None
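
The ragged layout can be sketched in NumPy as follows (hypothetical sequence lengths): variable-length sequences are packed back to back along the first axis, and `input_row_offsets` delimits each one:

```python
import numpy as np

# Sketch of the ragged layout store_k_cache_ragged consumes: sequences of
# length 3, 1, and 4 packed contiguously, with input_row_offsets of shape
# [batch + 1] marking each sequence's start and end rows.
seq_lens = [3, 1, 4]
n_kv_heads, head_dim = 2, 4
x_k = np.random.rand(sum(seq_lens), n_kv_heads, head_dim).astype(np.float32)  # rank 3
input_row_offsets = np.cumsum([0] + seq_lens).astype(np.uint32)

# Sequence i occupies rows input_row_offsets[i] : input_row_offsets[i + 1].
for i in range(len(seq_lens)):
    start, end = input_row_offsets[i], input_row_offsets[i + 1]
    assert x_k[start:end].shape[0] == seq_lens[i]
print(input_row_offsets)  # [0 3 4 8]
```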

    store_k_scale_cache_ragged()

    max.nn.kernels.store_k_scale_cache_ragged(kv_collection, x_k_scale, input_row_offsets, layer_idx, quantization_granularity)

    source

Stores the key scale tensor into the paged KV cache for ragged inputs.

Parameters:

• kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
• x_k_scale (TensorValue) – The key scale tensor to store.
• input_row_offsets (TensorValue) – Ragged tensor row offsets of shape [batch + 1] indicating where each sequence starts and ends. Must have dtype uint32.
• layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.
• quantization_granularity – The granularity at which quantization scales are applied.

    Return type:

    None

    store_v_cache_padded()

    max.nn.kernels.store_v_cache_padded(kv_collection, x_v, valid_lengths, layer_idx)

    source

    Stores the value tensor into the paged KV cache for padded inputs.

    Parameters:

    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
    • x_v (TensorValue) – The value tensor of rank 4 containing the new value projections.
    • valid_lengths (TensorValue) – Buffer of shape [batch] (dtype uint32) indicating the actual (non-padded) sequence length for each batch element.
    • layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.

    Return type:

    None

    store_v_cache_ragged()

    max.nn.kernels.store_v_cache_ragged(kv_collection, x_v, input_row_offsets, layer_idx)

    source

    Stores the value tensor into the paged KV cache for ragged inputs.

    Parameters:

    • kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
    • x_v (TensorValue) – The value tensor of rank 3 containing the new value projections.
    • input_row_offsets (TensorValue) – Ragged tensor row offsets of shape [batch + 1] indicating where each sequence starts and ends. Must have dtype uint32.
    • layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.

    Return type:

    None

    topk_fused_sampling()

    max.nn.kernels.topk_fused_sampling(logits, top_k, *, temperature=1.0, max_k=None, min_top_p=None, top_p=1.0, min_p=None, seed=0)

    source

Performs fused top-k sampling with temperature scaling and optional top-p / min-p filtering.

    Parameters:

    Returns:

    Sampled tokens tensor of shape [batch_size, 1].

    Raises:

    ValueError – If input validation fails.

    Return type:

    TensorValue
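
A minimal NumPy reference for the core of this operation, temperature-scaled top-k sampling (sketch only; the real kernel fuses this on device and also applies the top-p / min-p filters):

```python
import numpy as np

# Reference top-k sampling: scale logits by temperature, keep the k highest,
# softmax over them, and draw one token per batch row.
def topk_sample(logits, top_k, temperature=1.0, seed=0):
    rng = np.random.default_rng(seed)
    out = np.empty((logits.shape[0], 1), dtype=np.int64)
    for b, row in enumerate(logits):
        scaled = row / temperature
        top_idx = np.argsort(scaled)[-top_k:]           # k highest logits
        probs = np.exp(scaled[top_idx] - scaled[top_idx].max())
        probs /= probs.sum()                            # softmax over the top-k
        out[b, 0] = rng.choice(top_idx, p=probs)
    return out

logits = np.array([[0.1, 2.0, -1.0, 3.0]], dtype=np.float32)
tokens = topk_sample(logits, top_k=2)
print(tokens.shape)  # (1, 1)
```

With `top_k=2`, the sampled token is always drawn from indices 1 and 3, the two highest logits.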

    tpool_patch_merger()

    max.nn.kernels.tpool_patch_merger(input, grid_thws, kH, kW, max_h, max_w)

    source

    Performs temporal pooling patch merger on ragged video tokens.

    For each video in the batch, averages the input across the temporal (T) dimension and rearranges the result according to the spatial merge kernel (kH, kW). Each video’s T*H*W input tokens are reduced to H*W output tokens. All videos are concatenated contiguously in the output.

    Parameters:

    • input (TensorValue) – Input tensor of shape [total_input_tokens, D] where total_input_tokens = sum(T_i * H_i * W_i) over all videos.
    • grid_thws (TensorValue) – Grid dimensions tensor of shape [n_videos, 3] with (T, H, W) per video. Must have dtype int64.
    • kH (int) – Merge kernel height.
    • kW (int) – Merge kernel width.
    • max_h (int | TensorValue) – Maximum H across all videos in the batch (for grid sizing). May be a Python int (baked as a graph constant) or a TensorValue computed at runtime (e.g. via ops.max).
    • max_w (int | TensorValue) – Maximum W across all videos in the batch (for grid sizing). May be a Python int or a TensorValue.

    Returns:

    Output tensor of shape [sum(H_i * W_i), D].

    Raises:

    ValueError – On invalid input shapes or dtypes.

    Return type:

    TensorValue
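
The temporal-pooling step can be sketched in NumPy (hypothetical sizes; the spatial (kH, kW) kernel rearrangement is omitted here for brevity): each video's T*H*W packed tokens are averaged across T, leaving H*W tokens per video, with all videos kept concatenated:

```python
import numpy as np

# Average each video's tokens over the temporal axis. Input rows are packed
# ragged per video, exactly as described for tpool_patch_merger above.
def temporal_pool(x, grid_thws):
    outs, start = [], 0
    for t, h, w in grid_thws:
        chunk = x[start : start + t * h * w]            # this video's tokens
        outs.append(chunk.reshape(t, h * w, -1).mean(axis=0))
        start += t * h * w
    return np.concatenate(outs, axis=0)

D = 8
grid_thws = np.array([[2, 3, 4], [4, 2, 2]], dtype=np.int64)
total = int(grid_thws.prod(axis=1).sum())               # 24 + 16 = 40 input tokens
x = np.random.rand(total, D).astype(np.float32)
out = temporal_pool(x, grid_thws)
print(out.shape)  # (16, 8), since sum(H_i * W_i) = 12 + 4 = 16
```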

    unfused_qkv_ragged_matmul_gguf_quantized()

    max.nn.kernels.unfused_qkv_ragged_matmul_gguf_quantized(kv_params, input, input_row_offsets, n_heads, q_weight, k_weight, v_weight, quantization_encoding_q, quantization_encoding_k, quantization_encoding_v, kv_collection, layer_idx)

    source

Computes query, key, and value projections for ragged input using separate (unfused) GGUF-quantized weight matrices. A quantization encoding must be provided for each of the query, key, and value weight matrices.

input and input_row_offsets together implement the ragged tensor: input_row_offsets indicates where each sequence starts and ends in input.

    Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.

    Parameters:

    Return type:

    TensorValue
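
Conceptually, the projection is position-wise, so ragged boundaries do not change the math. A plain float32 NumPy sketch (hypothetical sizes, no GGUF dequantization; the real kernel also writes K and V into the paged cache):

```python
import numpy as np

# Every packed token is multiplied by three separate weight matrices;
# input_row_offsets only delimits sequences and does not affect the matmul.
seq_lens = [2, 3]
hidden, n_heads, head_dim = 8, 2, 4
x = np.random.rand(sum(seq_lens), hidden).astype(np.float32)
input_row_offsets = np.cumsum([0] + seq_lens).astype(np.uint32)
q_weight = np.random.rand(hidden, n_heads * head_dim).astype(np.float32)
k_weight = np.random.rand(hidden, n_heads * head_dim).astype(np.float32)
v_weight = np.random.rand(hidden, n_heads * head_dim).astype(np.float32)

q, k, v = x @ q_weight, x @ k_weight, x @ v_weight
print(q.shape)  # (5, 8)
```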

    update_frequency_data()

    max.nn.kernels.update_frequency_data(frequency_data, frequency_offsets, tokens)

    source

Updates per-sequence token frequency data in place with newly generated tokens.

    Parameters:

• frequency_data (BufferValue) – 2D tensor of shape [unique_tokens, 2], where the first column holds the token id and the second column holds that token’s frequency count.
• frequency_offsets (TensorValue) – 1D tensor of shape [batch_size + 1], indicating the start of each sequence’s data within frequency_data.
• tokens (TensorValue) – The newly generated tokens to add to the frequency data.

    Return type:

    None
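
A NumPy sketch of the update, following the parameter layout above (exact kernel semantics, such as handling of unseen tokens, may differ): `frequency_data` holds `[token_id, count]` rows per sequence, `frequency_offsets` delimits each sequence's slice, and each new token bumps its matching count in place:

```python
import numpy as np

# Per-sequence [token_id, count] rows, delimited by frequency_offsets.
frequency_data = np.array([[7, 2], [9, 1],     # sequence 0's rows
                           [7, 5]],            # sequence 1's rows
                          dtype=np.int64)
frequency_offsets = np.array([0, 2, 3], dtype=np.uint32)
tokens = np.array([7, 7], dtype=np.int64)      # one new token per sequence

for i, tok in enumerate(tokens):
    start, end = frequency_offsets[i], frequency_offsets[i + 1]
    rows = frequency_data[start:end]
    match = np.nonzero(rows[:, 0] == tok)[0]
    if match.size:                             # token seen before: bump count
        frequency_data[start + match[0], 1] += 1

print(frequency_data[:, 1])  # [3 1 6]
```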