Python module
max.nn.kernels
Helper functions for wrapping custom kv cache/attention related ops.
Any
class max.nn.kernels.Any(*args, **kwargs)
Bases: object
Special type indicating an unconstrained type.
- Any is compatible with every type.
- Any is assumed to have all methods.
- All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
AttentionMaskVariant
class max.nn.kernels.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Defines the string mask variant identifiers used in attention configuration.
CAUSAL_MASK
CAUSAL_MASK = 'causal'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = 'chunked_causal'
NULL_MASK
NULL_MASK = 'null'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = 'sliding_window_causal'
TENSOR_MASK
TENSOR_MASK = 'tensor_mask'
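To make the rule-based variants concrete, here is a pure-Python sketch of the mask semantics (an illustration only, not the kernel implementation; `TENSOR_MASK` instead supplies an explicit mask tensor rather than a rule, so it is omitted):

```python
def allowed(i: int, j: int, variant: str, window: int = 0, chunk: int = 0) -> bool:
    """Whether query position i may attend to key position j under a mask variant.
    Pure-Python illustration of the mask semantics, not the kernel code."""
    if variant == "null":
        return True                                  # no masking at all
    if variant == "causal":
        return j <= i                                # no attending to the future
    if variant == "sliding_window_causal":
        return j <= i and i - j < window             # causal, within a local window
    if variant == "chunked_causal":
        return j <= i and j // chunk == i // chunk   # causal, within the same chunk
    raise ValueError(variant)

assert allowed(3, 1, "causal")
assert not allowed(1, 3, "causal")                   # future position is masked
assert not allowed(5, 1, "sliding_window_causal", window=3)
assert allowed(5, 4, "sliding_window_causal", window=3)
assert not allowed(5, 3, "chunked_causal", chunk=4)  # 3 and 5 fall in different chunks
```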
BufferValue
class max.nn.kernels.BufferValue(value)
Bases: Value[BufferType]
Represents a mutable semantic tensor within a Graph.
Initializes a BufferValue from another value.
Parameters:
- value (Value[Any] | _Value[mo.BufferType] | HasBufferValue) – The value to wrap, either an MLIR value of buffer type or another BufferValue.
device
property device: DeviceRef
Returns the device of the BufferValue.
dtype
property dtype: DType
Returns the tensor data type.
from_mlir()
classmethod from_mlir(value)
Creates a BufferValue from an MLIR buffer value.
Parameters:
- value (Value[BufferType]) – The MLIR buffer value to wrap.
Return type: BufferValue
print()
print(label='debug_buffer')
Prints detailed information about the buffer.
Parameters:
- label (str)
Return type: None
rank
property rank: int
Returns the rank (number of dims) of the buffer.
shape
property shape: Shape
Returns the shape of the BufferValue.
type
property type: BufferType
Returns the type of the BufferValue as a BufferType.
DType
class max.nn.kernels.DType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
The tensor data type.
align
property align
Returns the alignment requirement of the data type in bytes.
The alignment specifies the memory boundary that values of this data type must be aligned to for optimal performance and correctness.
bfloat16
bfloat16 = 80
16-bit bfloat16 (Brain Float) format. 1 sign bit, 8 exponent bits, 7 mantissa bits.
bool
bool = 1
Boolean data type. Stores True or False values.
float16
float16 = 79
16-bit IEEE 754 half-precision floating-point. 1 sign bit, 5 exponent bits, 10 mantissa bits.
float32
float32 = 81
32-bit IEEE 754 single-precision floating-point. 1 sign bit, 8 exponent bits, 23 mantissa bits.
float4_e2m1fn
float4_e2m1fn = 64
4-bit floating-point with 2 exponent bits and 1 mantissa bits, finite values only.
float64
float64 = 82
64-bit IEEE 754 double-precision floating-point. 1 sign bit, 11 exponent bits, 52 mantissa bits.
float8_e4m3fn
float8_e4m3fn = 75
8-bit floating-point with 4 exponent bits and 3 mantissa bits, finite values only.
float8_e4m3fnuz
float8_e4m3fnuz = 76
8-bit floating-point with 4 exponent bits and 3 mantissa bits, finite values only, no negative zero.
float8_e5m2
float8_e5m2 = 77
8-bit floating-point with 5 exponent bits and 2 mantissa bits.
float8_e5m2fnuz
float8_e5m2fnuz = 78
8-bit floating-point with 5 exponent bits and 2 mantissa bits, finite values only, no negative zero.
float8_e8m0fnu
float8_e8m0fnu = 73
8-bit floating-point with 8 exponent bits and 0 mantissa bits, finite values only.
from_numpy()
from_numpy(dtype)
Converts a NumPy dtype to the corresponding DType.
Parameters:
- dtype (np.dtype) – The NumPy dtype to convert.
Returns: The corresponding DType enum value.
Return type: DType
Raises:
- ValueError – If the input dtype is not supported.
int16
int16 = 137
16-bit signed integer, range -32,768 to 32,767.
int32
int32 = 139
32-bit signed integer, range -2,147,483,648 to 2,147,483,647.
int64
int64 = 141
64-bit signed integer, range -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807.
int8
int8 = 135
8-bit signed integer, range -128 to 127.
is_float()
is_float(self) → bool
Checks if the data type is a floating-point type.
is_float8()
is_float8(self) → bool
Checks if the data type is an 8-bit floating-point type.
is_half()
is_half(self) → bool
Checks if the data type is a half-precision floating-point type.
is_integral()
is_integral(self) → bool
Checks if the data type is an integer type.
is_signed_integral()
is_signed_integral(self) → bool
Checks if the data type is a signed integer type.
is_unsigned_integral()
is_unsigned_integral(self) → bool
Checks if the data type is an unsigned integer type.
size_in_bits
property size_in_bits
Returns the size of the data type in bits.
This indicates how many bits are required to store a single value of this data type in memory.
size_in_bytes
property size_in_bytes
Returns the size of the data type in bytes.
This indicates how many bytes are required to store a single value of this data type in memory.
to_numpy()
to_numpy()
Converts this DType to the corresponding NumPy dtype.
Parameters:
- self (DType)
Returns: The corresponding NumPy dtype object.
Return type: np.dtype
Raises:
- ValueError – If the dtype is not supported.
uint16
uint16 = 136
16-bit unsigned integer, range 0 to 65,535.
uint32
uint32 = 138
32-bit unsigned integer, range 0 to 4,294,967,295.
uint64
uint64 = 140
64-bit unsigned integer, range 0 to 18,446,744,073,709,551,615.
uint8
uint8 = 134
8-bit unsigned integer, range 0 to 255.
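The floating-point dtype names above encode their bit layout (eXmY means X exponent bits and Y mantissa bits). A quick pure-Python check, independent of the max package, confirms that sign + exponent + mantissa bits always sum to the storage width (float8_e8m0fnu is unsigned, so it carries no sign bit):

```python
# Bit layouts (sign, exponent, mantissa) for the float dtypes documented above.
LAYOUTS = {
    "float4_e2m1fn": (1, 2, 1),
    "float8_e4m3fn": (1, 4, 3),
    "float8_e5m2": (1, 5, 2),
    "float8_e8m0fnu": (0, 8, 0),   # unsigned: no sign bit
    "float16": (1, 5, 10),
    "bfloat16": (1, 8, 7),
    "float32": (1, 8, 23),
    "float64": (1, 11, 52),
}

def total_bits(name: str) -> int:
    """Sum of sign + exponent + mantissa bits for a dtype name."""
    sign, exp, mant = LAYOUTS[name]
    return sign + exp + mant

for name, width in [("float16", 16), ("bfloat16", 16), ("float32", 32),
                    ("float64", 64), ("float8_e4m3fn", 8), ("float8_e8m0fnu", 8),
                    ("float4_e2m1fn", 4)]:
    assert total_bits(name) == width
```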
DeviceKind
class max.nn.kernels.DeviceKind(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
A device type representation.
CPU
CPU = 'cpu'
GPU
GPU = 'gpu'
from_string()
static from_string(txt)
Parses a device kind from its string representation.
Parameters:
- txt (str)
Return type: DeviceKind
DeviceRef
class max.nn.kernels.DeviceRef(device_type, id=0)
Bases: object
A symbolic device representation.
DeviceRef type representation consists of a DeviceKind and an id. This is a direct representation of the device attribute in MLIR.
The following example demonstrates how to create and use device references:
from max.graph import DeviceRef
# Create a GPU device reference (default id=0)
gpu_device = DeviceRef.GPU()
print(gpu_device) # Outputs: gpu:0
# Create a CPU device with specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device) # Outputs: cpu:1

Parameters:
- device_type (DeviceKind)
- id (int)
CPU()
static CPU(id=0)
Creates a CPU device reference.
GPU()
static GPU(id=0)
Creates a GPU device reference.
device_type
device_type: DeviceKind
from_device()
static from_device(device)
Converts a Device or DeviceRef to a DeviceRef.
from_mlir()
static from_mlir(attr)
Returns a device reference from an MLIR attribute.
Parameters:
- attr (DeviceRefAttr)
Return type: DeviceRef
id
id: int
is_cpu()
is_cpu()
Returns True if the device is a CPU device.
Return type: bool
is_gpu()
is_gpu()
Returns True if the device is a GPU device.
Return type: bool
to_device()
to_device()
Converts a device reference to a concrete driver Device.
Return type: Device
to_mlir()
to_mlir()
Returns an MLIR attribute representing the device.
Return type: DeviceRefAttr
Dim
class max.nn.kernels.Dim(value)
Bases: object
A tensor dimension.
Dims describe the shape of tensors in a Graph. In most cases, you don’t
need to construct a Dim directly. Instead, you pass dimension values
directly to TensorType or BufferType constructors:
from max.dtype import DType
from max.graph import Dim, TensorType, DeviceRef
# Create a TensorType with a symbolic "batch" dimension and a static dimension of size 10
tensor_type = TensorType(DType.int64, ("batch", 10), device=DeviceRef.CPU())

A tensor dimension can be one of three types:
- Static: A known size. See StaticDim.
- Symbolic: An unknown size identified by name. See SymbolicDim.
- Algebraic: An expression derived from symbolic dimensions. See AlgebraicDim.
Static dimensions let the graph compiler resolve shapes at compile time. This enables more aggressive optimizations than symbolic or algebraic dimensions allow. That said, when tensors share a named symbolic dimension, the compiler can leverage the implied shape equality to optimize some operations.
Converts valid input values to Dim.
Parameters:
- value (DimLike)
from_mlir()
static from_mlir(attr)
Constructs a dimension from an mlir.Attribute.
Parameters:
- attr (TypedAttr) – The MLIR Attribute to parse into a dimension.
Returns: The dimension represented by the MLIR Attr value.
Return type: Dim
parameters
property parameters: Iterable[SymbolicDim]
Lists the symbolic dimension names on which this dim depends.
to_mlir()
to_mlir()
Creates an mlir.Attribute representing this dimension.
This is used internally when constructing tensor MLIR types.
Returns: An mlir.Attribute in the context representing the dimension.
Return type: TypedAttr
InputScaleSpec
class max.nn.kernels.InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None, block_size=None)
Bases: object
Specifies how input activations are scaled for scaled quantization.
Parameters:
- granularity (ScaleGranularity)
- origin (ScaleOrigin)
- dtype (DType)
- activation_scale_ub (float | None)
- block_size (tuple[int, int] | None)
activation_scale_ub
activation_scale_ub: float | None = None
An optional upper bound for dynamic activation scaling.
block_size
block_size: tuple[int, int] | None = None
The block size for block-wise scaling, as a tuple[int, int].
dtype
dtype: DType
The DType of the input scale factor(s).
granularity
granularity: ScaleGranularity
The ScaleGranularity of the input scale factor application.
is_block
property is_block: bool
Whether the input scale granularity is block-wise.
is_colwise
property is_colwise: bool
Whether the input scale granularity is column-wise.
is_rowwise
property is_rowwise: bool
Whether the input scale granularity is row-wise.
is_tensor
property is_tensor: bool
Whether the input scale granularity is per-tensor.
origin
origin: ScaleOrigin
The ScaleOrigin (static or dynamic) of the input scale factor.
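To make the granularity options concrete, the following pure-Python sketch (illustrative only, not the max API) counts how many scale factors each granularity implies for an (M, K) matrix:

```python
def num_scales(m: int, k: int, granularity: str, block_size=None) -> int:
    """Count scale factors for an (m, k) matrix under a given granularity.
    Illustrative sketch; not part of the max API."""
    if granularity == "tensor":    # one scale for the whole tensor
        return 1
    if granularity == "rowwise":   # one scale per row
        return m
    if granularity == "colwise":   # one scale per column
        return k
    if granularity == "block":     # one scale per (bm, bk) block
        bm, bk = block_size
        assert m % bm == 0 and k % bk == 0, "dims must divide evenly into blocks"
        return (m // bm) * (k // bk)
    raise ValueError(granularity)

# A 4x32 matrix with (1, 16) blocks needs 4 * 2 = 8 scales.
assert num_scales(4, 32, "block", block_size=(1, 16)) == 8
assert num_scales(4, 32, "tensor") == 1
assert num_scales(4, 32, "rowwise") == 4
assert num_scales(4, 32, "colwise") == 32
```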
KVCacheParams
class max.nn.kernels.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, devices, enable_prefix_caching=False, kv_connector=None, kv_connector_config=None, host_kvcache_swap_space_gb=None, page_size=128, is_mla=False, num_q_heads=None, data_parallel_degree=1, n_kv_heads_per_device=0, num_q_heads_per_device=None, kvcache_quant_config=None, num_eagle_speculative_tokens=0)
Bases: KVCacheParamInterface
Configuration parameters for key-value cache management in transformer models.
This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, and memory management.
Parameters:
- dtype (DType)
- n_kv_heads (int)
- head_dim (int)
- num_layers (int)
- devices (Sequence[DeviceRef])
- enable_prefix_caching (bool)
- kv_connector (KVConnectorType | None)
- kv_connector_config (Any)
- host_kvcache_swap_space_gb (float | None)
- page_size (int)
- is_mla (bool)
- num_q_heads (int | None)
- data_parallel_degree (int)
- n_kv_heads_per_device (int)
- num_q_heads_per_device (int | None)
- kvcache_quant_config (KVCacheQuantizationConfig | None)
- num_eagle_speculative_tokens (int)
allocate_buffers()
allocate_buffers(total_num_pages)
Allocates the buffers for the KV cache.
Parameters:
- total_num_pages (int)
Return type:
bytes_per_block
property bytes_per_block: int
Returns the number of bytes per cache block.
When TP>1, each block is sharded across the devices in the tensor parallel group. This method returns the total memory needed to store a block across these devices. Includes memory needed for scales if quantization is enabled.
Returns: The number of bytes per cache block.
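As a first-order estimate of that footprint, a block must hold one key tensor and one value tensor per layer for page_size tokens. This is a back-of-envelope sketch only (it ignores the quantization-scale storage and TP sharding that the real property accounts for):

```python
def estimate_bytes_per_block(num_layers: int, page_size: int,
                             n_kv_heads: int, head_dim: int,
                             dtype_bytes: int) -> int:
    """Rough bytes per KV cache block: K and V for every layer, page_size tokens.
    Sketch only; the real bytes_per_block also covers quantization scales."""
    kv = 2  # one key tensor and one value tensor
    return kv * num_layers * page_size * n_kv_heads * head_dim * dtype_bytes

# Example: 32 layers, 128-token pages, 8 KV heads, head_dim 128, bfloat16 (2 B)
# -> 2 * 32 * 128 * 8 * 128 * 2 = 16,777,216 bytes (16 MiB) per block.
assert estimate_bytes_per_block(32, 128, 8, 128, 2) == 16 * 1024 * 1024
```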
copy_as_dp_1()
copy_as_dp_1(replica_idx=0)
Creates a copy of the KVCacheParams with data parallelism disabled.
This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.
Parameters:
- replica_idx (int)
Returns: A new KVCacheParams instance with data_parallel_degree set to 1.
Return type: KVCacheParams
Raises:
- ValueError – If n_devices is not evenly divisible by data_parallel_degree.
data_parallel_degree
data_parallel_degree: int = 1
Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).
devices
devices: Sequence[DeviceRef]
Devices to use for the KV cache.
dtype
dtype: DType
Data type for storing key and value tensors in the cache.
dtype_shorthand
property dtype_shorthand: str
Returns a shorthand textual representation of the data type.
Returns: “bf16” for bfloat16 dtype, “f32” otherwise.
enable_prefix_caching
enable_prefix_caching: bool = False
Whether to enable prefix caching for efficient reuse of common prompt prefixes.
get_symbolic_inputs()
get_symbolic_inputs(prefix='')
Computes the symbolic inputs for the KV cache.
This method returns a list of KVCacheInputs for each replica. This is used when constructing the model graph.
Parameters:
- prefix (str)
Returns: The symbolic inputs for the KV cache.
Return type:
head_dim
head_dim: int
Dimensionality of each attention head.
host_kvcache_swap_space_gb
host_kvcache_swap_space_gb: float | None = None
Amount of host memory (in GB) to reserve for KV cache swapping. Required when a local or tiered connector is used.
is_fp8_kv_dtype
property is_fp8_kv_dtype: bool
Whether the KV cache stores FP8 data, for dispatch resolution.
Unlike quantized_kv_cache (which also requires a valid scale config), this checks only the storage dtype, matching the compile-time detection in the MLA decode kernel.
TODO(SERVOPT-1094): Once SnapMLA uses a valid scale_dtype, this can be replaced by quantized_kv_cache.
is_mla
is_mla: bool = False
Whether the model uses Multi-Latent Attention (MLA) architecture.
kv_connector
kv_connector: KVConnectorType | None = None
Type of KV cache connector to use (null, local, tiered, lmcache).
kv_connector_config
kv_connector_config: Any = None
Connector-specific configuration (KVConnectorConfig from the pipelines layer).
kvcache_quant_config
kvcache_quant_config: KVCacheQuantizationConfig | None = None
KVCache quantization config. Currently only FP8 quantization supported.
n_devices
property n_devices: int
Returns the number of devices.
Returns: The number of devices.
n_kv_heads
n_kv_heads: int
Total number of key-value attention heads across all devices.
n_kv_heads_per_device
n_kv_heads_per_device: int = 0
Number of KV heads allocated to each device. Computed automatically in __post_init__.
num_eagle_speculative_tokens
num_eagle_speculative_tokens: int = 0
Number of draft tokens to generate for EAGLE speculative decoding.
num_layers
num_layers: int
Number of layers in the model.
num_q_heads
num_q_heads: int | None = None
Number of query attention heads. Required when is_mla is True so that the attention dispatch resolver can call the MLA-specific kernel.
num_q_heads_per_device
num_q_heads_per_device: int | None = None
Number of query heads per device. Computed automatically in __post_init__ from num_q_heads and the parallelism configuration.
page_size
page_size: int = 128
Number of tokens per page (block).
This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration.
Current constraints: the page size must be a multiple of 128 and at least 128.
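The constraint above is easy to check up front (validate_page_size is an illustrative helper, not part of the max API):

```python
def validate_page_size(page_size: int) -> None:
    """Enforce the documented constraint: a multiple of 128, and at least 128."""
    if page_size < 128 or page_size % 128 != 0:
        raise ValueError(
            f"page_size must be a positive multiple of 128, got {page_size}")

validate_page_size(128)   # minimum allowed value
validate_page_size(256)   # any larger multiple of 128 is fine
try:
    validate_page_size(100)
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError for page_size=100")
```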
quantized_kv_cache
property quantized_kv_cache: bool
Returns whether FP8 KV cache quantization is enabled.
Returns: True when the cache dtype is float8_e4m3fn or float8_e4m3fnuz and a valid quantization scale dtype is configured; False otherwise.
shape_per_block
Returns the shape of each cache block.
Returns: The shape of the cache block.
shape_per_scale_block
Returns the shape of each scale block used for KVCache quantization.
Returns: The shape of the KVCache quantization scales block.
tensor_parallel_degree
property tensor_parallel_degree: int
Returns the tensor parallel degree.
Returns: The tensor parallel degree.
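The parallelism degrees relate through the device count: as described under copy_as_dp_1, each data-parallel replica uses n_devices divided by data_parallel_degree devices for tensor parallelism. A sketch of that arithmetic (illustrative helper, not the max implementation):

```python
def tensor_parallel_degree(n_devices: int, data_parallel_degree: int) -> int:
    """Devices per data-parallel replica. Mirrors the documented requirement
    that n_devices divide evenly by data_parallel_degree."""
    if n_devices % data_parallel_degree != 0:
        raise ValueError("n_devices must be evenly divisible by data_parallel_degree")
    return n_devices // data_parallel_degree

assert tensor_parallel_degree(8, 2) == 4   # 2 replicas of TP=4
assert tensor_parallel_degree(8, 1) == 8   # pure tensor parallelism
```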
MHAMaskVariant
class max.nn.kernels.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Defines the integer mask variant codes used by multihead attention kernels.
CAUSAL_MASK
CAUSAL_MASK = '0'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = '3'
NULL_MASK
NULL_MASK = '2'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = '4'
MutableSequence
class max.nn.kernels.MutableSequence
Bases: Sequence
All the operations on a read-write sequence.
Concrete subclasses must provide __new__ or __init__, __getitem__, __setitem__, __delitem__, __len__, and insert().
append()
append(value)
S.append(value) – append value to the end of the sequence
clear()
clear() → None -- remove all items from S
extend()
extend(values)
S.extend(iterable) – extend sequence by appending elements from the iterable
insert()
abstract insert(index, value)
S.insert(index, value) – insert value before index
pop()
pop() → item -- remove and return item at index (default last).
Raise IndexError if list is empty or index is out of range.
remove()
remove(value)
S.remove(value) – remove first occurrence of value. Raise ValueError if the value is not present.
reverse()
reverse()
S.reverse() – reverse IN PLACE
QuantConfig
class max.nn.kernels.QuantConfig(input_scale, weight_scale, mlp_quantized_layers, attn_quantized_layers, format, embedding_output_dtype=None, bias_dtype=None, can_use_fused_mlp=False, scales_pre_interleaved=False)
Bases: object
Configures scaled quantization settings for a layer or model section.
For example, to configure NVFP4 block-scaled quantization for all layers in a 19-layer model:
from max.dtype import DType
from max.nn import QuantConfig, QuantFormat
from max.nn.quant_config import (
InputScaleSpec,
ScaleGranularity,
ScaleOrigin,
WeightScaleSpec,
)
all_layers = set(range(19))
input_spec = InputScaleSpec(
granularity=ScaleGranularity.BLOCK,
origin=ScaleOrigin.STATIC,
dtype=DType.float32,
block_size=(1, 16),
)
weight_spec = WeightScaleSpec(
granularity=ScaleGranularity.BLOCK,
dtype=DType.float8_e4m3fn,
block_size=(1, 8),
)
config = QuantConfig(
input_scale=input_spec,
weight_scale=weight_spec,
mlp_quantized_layers=all_layers,
attn_quantized_layers=all_layers,
format=QuantFormat.NVFP4,
)

Parameters:
- input_scale (InputScaleSpec)
- weight_scale (WeightScaleSpec)
- mlp_quantized_layers (set[int])
- attn_quantized_layers (set[int])
- format (QuantFormat)
- embedding_output_dtype (DType | None)
- bias_dtype (DType | None)
- can_use_fused_mlp (bool)
- scales_pre_interleaved (bool)
attn_quantized_layers
attn_quantized_layers: set[int]
Set of layer indices with quantized attention projections.
Attention projections are quantized on an all-or-nothing basis per layer: either all of q_proj, k_proj, v_proj, and o_proj are quantized, or all four remain in bfloat16.
bias_dtype
bias_dtype: DType | None = None
The DType of bias weights.
can_use_fused_mlp
can_use_fused_mlp: bool = False
Whether the quantization scales can be used with fused MLP operations.
embedding_output_dtype
embedding_output_dtype: DType | None = None
The DType of the output from the embedding layer.
format
format: QuantFormat
The QuantFormat identifying the quantization format.
input_scale
input_scale: InputScaleSpec
InputScaleSpec for input activation scaling.
is_dynamic
property is_dynamic: bool
True if this input scale is dynamic.
is_fp4
property is_fp4: bool
True if this config represents any FP4 variant (NVFP4 or MXFP4).
is_mxfp4
property is_mxfp4: bool
Returns True if this config represents MXFP4 quantization.
is_nvfp4
property is_nvfp4: bool
True if this config represents modelopt NVFP4.
is_static
property is_static: bool
True if this input scale is static.
mlp_quantized_layers
mlp_quantized_layers: set[int]
Set of layer indices with quantized MLPs.
MLPs are quantized on an all-or-nothing basis per layer: either all of gate_proj, down_proj, and up_proj are quantized, or all three remain in bfloat16.
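The all-or-nothing rule can be sketched as a single set-membership test (layer_dtypes is an illustrative helper, not part of the max API):

```python
# Per-layer quantization is all-or-nothing, so one membership test answers
# for every projection in a layer. Illustrative sketch only.
MLP_PROJS = ("gate_proj", "down_proj", "up_proj")
ATTN_PROJS = ("q_proj", "k_proj", "v_proj", "o_proj")

def layer_dtypes(layer: int, quantized_layers: set[int], projs) -> dict:
    """Map each projection name to its effective weight dtype."""
    dtype = "quantized" if layer in quantized_layers else "bfloat16"
    return {p: dtype for p in projs}

mlp_layers = {0, 1, 2}
# Layer 1 is in the set: every MLP projection is quantized.
assert layer_dtypes(1, mlp_layers, MLP_PROJS) == {
    "gate_proj": "quantized", "down_proj": "quantized", "up_proj": "quantized"}
# Layer 5 is not: all four attention projections stay in bfloat16.
assert set(layer_dtypes(5, mlp_layers, ATTN_PROJS).values()) == {"bfloat16"}
```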
quantized_scales_type()
quantized_scales_type(quantized_shape, device_ref)
The TensorType of the scales tensor after dynamic quantization.
Parameters:
- quantized_shape
- device_ref
Return type: TensorType
scales_granularity_mnk
The weight and input scale granularities on the M, N, and K axes.
scales_pre_interleaved
scales_pre_interleaved: bool = False
Whether weight scales in the checkpoint are already stored in the 5D
TCGEN-interleaved layout expected by the FP4 matmul kernel (NVFP4 only).
Note that scales in the 5D TCGEN-interleaved layout are typically flattened
to 2D [M, K//16] in the checkpoint.
weight_scale
weight_scale: WeightScaleSpec
WeightScaleSpec for weight scaling.
QuantizationConfig
class max.nn.kernels.QuantizationConfig(quant_method, bits, group_size, desc_act=False, sym=False)
Bases: object
Configuration for specifying quantization parameters that affect inference.
These parameters control how tensor values are quantized, including the method, bit precision, grouping, and other characteristics that affect the trade-off between model size, inference speed, and accuracy.
bits
bits: int
The number of bits used to represent each quantized weight element.
desc_act
desc_act: bool = False
Whether to use activation ordering (descending activation order). Defaults to False.
group_size
group_size: int
The number of weight elements that share a single set of quantization parameters.
quant_method
quant_method: str
The quantization method name (for example, gptq or awq).
sym
sym: bool = False
Whether to use symmetric quantization. Defaults to False.
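To illustrate group_size: with group_size=128, each run of 128 consecutive weight elements shares one scale. The following is a simplified symmetric per-group quantization sketch (pure Python, for intuition only; it is not the GPTQ or AWQ algorithm itself):

```python
def quantize_grouped(weights, group_size: int, bits: int):
    """Symmetric per-group quantization sketch: each group of `group_size`
    weights shares one scale. Returns (integer codes, per-group scales)."""
    qmax = 2 ** (bits - 1) - 1          # e.g. 7 for 4-bit symmetric
    codes, scales = [], []
    for g in range(0, len(weights), group_size):
        group = weights[g:g + group_size]
        scale = max(abs(w) for w in group) / qmax or 1.0  # avoid zero scale
        scales.append(scale)
        codes.extend(round(w / scale) for w in group)
    return codes, scales

codes, scales = quantize_grouped([0.1, -0.7, 0.35, 0.7], group_size=2, bits=4)
assert len(scales) == 2                 # one scale per group of 2 elements
assert codes[1] == -7                   # -0.7 maps to the most negative code
```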
QuantizationEncoding
class max.nn.kernels.QuantizationEncoding(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
Quantization encodings supported by MAX Graph.
Quantization reduces the precision of neural network weights to decrease memory usage and potentially improve inference speed. Each encoding represents a different compression method with specific trade-offs between model size, accuracy, and computational efficiency. These encodings are commonly used with pre-quantized model checkpoints (especially GGUF format) or can be applied during weight allocation.
The following example shows how to create a quantized weight using the Q4_K encoding:
from max.dtype import DType
from max.graph import DeviceRef, Weight
from max.graph.quantization import QuantizationEncoding
# Create a quantized weight using Q4_K encoding
encoding = QuantizationEncoding.Q4_K
quantized_weight = Weight(
name="linear.weight",
dtype=DType.uint8,
shape=[4096, 4096],
device=DeviceRef.GPU(0),
quantization_encoding=encoding
)

MAX supports several quantization formats optimized for different use cases.
Q4_0
Q4_0 = 'Q4_0'
Basic 4-bit quantization with 32 elements per block.
Q4_K
Q4_K = 'Q4_K'
4-bit K-quantization with 256 elements per block.
Q5_K
Q5_K = 'Q5_K'
5-bit K-quantization with 256 elements per block.
Q6_K
Q6_K = 'Q6_K'
6-bit K-quantization with 256 elements per block.
GPTQ
GPTQ = 'GPTQ'
Group-wise Post-Training Quantization for large language models.
block_parameters
property block_parameters: BlockParameters
Gets the block parameters for this quantization encoding.
Returns: The parameters describing how elements are organized and encoded in blocks for this quantization encoding.
Return type: BlockParameters
block_size
property block_size: int
Number of bytes in encoded representation of block.
All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of bytes resulting after encoding a single block.
Returns: Size in bytes of each encoded quantization block.
Return type: int
elements_per_block
property elements_per_block: int
Number of elements per block.
All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of elements gathered into a block.
Returns: Number of original tensor elements in each quantized block.
Return type: int
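For example, the standard GGUF Q4_0 layout packs 32 elements into an 18-byte block (one float16 scale plus 32 four-bit codes), for an effective cost of 4.5 bits per element. A quick sketch of that arithmetic (pure Python; this describes the GGUF Q4_0 format, not max internals):

```python
def bits_per_element(elements_per_block: int, block_bytes: int) -> float:
    """Effective storage cost of a block-quantized encoding."""
    return block_bytes * 8 / elements_per_block

# GGUF Q4_0 block: one float16 scale (2 bytes) + 32 four-bit codes (16 bytes).
q4_0_block_bytes = 2 + 32 * 4 // 8
assert q4_0_block_bytes == 18
assert bits_per_element(32, q4_0_block_bytes) == 4.5  # vs 32 bits for float32
```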
is_gguf
property is_gguf: bool
Checks if this quantization encoding is compatible with GGUF format.
GGUF is a format for storing large language models and compatible quantized weights.
Returns: True if this encoding is compatible with GGUF, False otherwise.
Return type: bool
name
property name: str
Gets the lowercase name of the quantization encoding.
Returns: Lowercase string representation of the quantization encoding.
Return type: str
StaticDim
class max.nn.kernels.StaticDim(value)
Bases: Dim
A static tensor dimension with a fixed size.
Because a static dimension’s size is fixed, related computation can be optimized at compile time. This is key to good model performance.
The following example creates static dimensions implicitly by passing
integer values to TensorType:
from max.graph import TensorType
from max.dtype import DType
tensor = TensorType(DType.int64, (4, 5))
# This creates a tensor with 2 static dimensions: 4 and 5 respectively

Converts valid input values to Dim.
Parameters:
- dim (int)
dim
dim: int
The size of the static dimension.
from_mlir()
static from_mlir(attr)
Constructs a StaticDim from a builtin.IntegerAttr.
Parameters:
- attr (TypedAttr) – The builtin.IntegerAttr to parse into a StaticDim.
Returns: The StaticDim represented by the builtin.IntegerAttr.
Return type: StaticDim
parameters
property parameters: Iterable[SymbolicDim]
Lists the symbolic dimension names on which this dim depends.
to_mlir()
to_mlir()
Creates an mlir.Attribute representing this dimension.
This is used internally when constructing tensor MLIR types.
Returns: An mlir.Attribute in the context representing the dimension.
Return type: IntegerAttr
TensorType
class max.nn.kernels.TensorType(dtype, shape, device, _layout=None)
Bases: _TensorTypeBase[TensorType]
A symbolic tensor type.
Use TensorType to declare the expected dtype, shape, and target
device of tensor values that flow through a graph during model
execution. Unlike an eager tensor, a TensorType holds no data. It is a
purely symbolic description of a value’s type at a specific point in the
computation. The graph compiler uses this information for shape inference
and optimization during graph construction.
The following example shows how to create a tensor type and access its properties:
from max.graph import TensorType, DeviceRef
from max.dtype import DType
# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
print(tensor_type.dtype) # Outputs: DType.float32
print(tensor_type.shape) # Outputs: [2, 3]

A shape’s dimensions can be static (integers), symbolic (strings), or algebraic (expressions over symbolic dimensions). In each case the rank is known at graph construction time.
Pass TensorType instances to load()
or Module.compile() (experimental) to define the input types of a
graph or model.
Parameters:
- dtype (DType) – The data type of the tensor elements.
- shape (Shape) – The shape of the tensor, expressed as a Shape.
- device (DeviceRef) – The device the tensor is located on. Use DeviceRef.CPU() or DeviceRef.GPU() to create a device reference.
- _layout (FilterLayout | None)
as_buffer()
as_buffer()
Returns the analogous buffer type.
Return type: BufferType
from_mlir()
classmethod from_mlir(type)
Constructs a tensor type from an MLIR type.
Parameters:
- type (TensorType) – The MLIR Type to parse into a tensor type.
Returns: The tensor type represented by the MLIR Type value.
Return type: TensorType
to_mlir()
to_mlir()
Converts to an mlir.Type instance.
Returns: An mlir.Type in the specified context.
Return type: TensorType
TensorValue
class max.nn.kernels.TensorValue(value)
Bases: Value[TensorType]
Represents a value semantic tensor within a Graph.
It provides various methods and properties to manipulate and query tensor
attributes such as shape, data type (dtype), device placement
(device), and more.
The following example demonstrates how to create and manipulate tensor values in a graph:
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a sample matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("tensor_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor properties
print(f"Shape: {tensor.shape}") # Output: [2, 2]
print(f"Data type: {tensor.dtype}") # Output: DType.float32
# Perform operations on the tensor
transposed = tensor.T
doubled = tensor * 2
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Transposed shape: {transposed.shape}") # Output: [2, 2]

Initializes a TensorValue from a tensor-like value.
Parameters:
- value (TensorValueLike) – The value to wrap. Can be an MLIR tensor value, another TensorValue, a Dim, or a Shape.
T
property T: TensorValue
Returns the transposed tensor.
T is the shorthand notation for transposing.
For more information, see transpose().
Returns: A new TensorValue with swapped dimensions.
argmax()
argmax(axis=-1)
Reduces the tensor using an argmax operation along axis.
When the result is ambiguous, i.e. there are multiple maxima, selects one index arbitrarily.
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("argmax_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Argmax along axis 1 (last dimension of each row)
indices = x.argmax(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Argmax shape: {indices.shape}") # [2, 1]

Parameters:
- axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example, -1 is the last dimension).
Returns: A TensorValue of dtype DType.int64 with the same rank as the input, and the same shape except along axis, which will have size 1.
Return type: TensorValue
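The rank-preserving behavior (the reduced axis keeps size 1, like NumPy's keepdims=True) can be mimicked in plain Python (illustration of the semantics only, not the graph op):

```python
def argmax_last_axis(rows):
    """Argmax over the last axis of a 2-D nested list, keeping the reduced
    axis as size 1, mirroring the rank-preserving semantics described above.
    Ties resolve to the first maximum here; the graph op may pick arbitrarily."""
    return [[max(range(len(row)), key=row.__getitem__)] for row in rows]

x = [[0.1, 2.0, 0.3],
     [5.0, 4.0, 4.5]]
assert argmax_last_axis(x) == [[1], [0]]   # shape (2, 3) -> (2, 1)
```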
broadcast_to()
broadcast_to(shape)
Broadcasts the tensor to a new shape.
The following example demonstrates how to broadcast a tensor to a larger shape:
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("broadcast_to_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Broadcast tensor to a 3x2x2 tensor (add a new dimension of size 3)
broadcasted_tensor = tensor.broadcast_to((3, 2, 2))
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Broadcasted shape: {broadcasted_tensor.shape}") # Output: [3, 2, 2]

cast()
cast(dtype)
Casts a symbolic tensor to a different data type.
The following example demonstrates how to cast a tensor from one data type to another:
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("cast_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Cast tensor to integer type
casted_tensor = tensor.cast(DType.int32)
print(f"Original dtype: {tensor.dtype}") # Output: DType.float32
print(f"Casted dtype: {casted_tensor.dtype}") # Output: DType.int32-
Parameters:
-
dtype (DType) – The target data type (for example,
DType.int32,DType.float64). -
Returns:
-
A new
TensorValuewith the casted data type. -
Return type:
device
property device: DeviceRef
Returns the device of the TensorValue.
dtype
property dtype: DType
Returns the tensor data type.
The following example demonstrates how to access the data type of a tensor:
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor data type
print(f"Data type: {tensor.dtype}") # Output: DType.float32
flatten()
flatten(start_dim=0, end_dim=-1)
Flattens the specified dims of a symbolic tensor.
The number and order of the elements in the tensor is unchanged.
All dimensions from start_dim to end_dim (inclusive) are merged into a single output dim.
The following example demonstrates how to flatten a multi-dimensional tensor:
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("flatten_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Flatten the tensor to a 1D array
flattened_tensor = tensor.flatten()
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Flattened shape: {flattened_tensor.shape}") # Output: [4]
-
Parameters:
-
Returns:
-
A new
TensorValue with the flattened dimensions. -
Return type:
from_mlir()
classmethod from_mlir(value)
Creates a TensorValue from an MLIR tensor value.
-
Parameters:
-
value (Value[TensorType]) – The MLIR tensor value to wrap.
-
Return type:
max()
max(axis=-1)
Reduces the tensor using a max operation along axis.
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("max_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Max along axis 1 (last dimension of each row)
m = x.max(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Max shape: {m.shape}") # [2, 1]
-
Parameters:
-
axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example,
-1 is the last dimension). -
Returns:
-
A
TensorValue with the same rank as the input and the same shape except along axis, which will have size 1. -
Return type:
mean()
mean(axis=-1)
Reduces the tensor using a mean operation along axis.
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("mean_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Mean along axis 1 (last dimension of each row)
mu = x.mean(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Mean shape: {mu.shape}") # [2, 1]
-
Parameters:
-
axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example,
-1 is the last dimension). -
Returns:
-
A
TensorValue with the same rank as the input and the same shape except along axis, which will have size 1. -
Return type:
min()
min(axis=-1)
Reduces the tensor using a min operation along axis.
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("min_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Min along axis 1 (last dimension of each row)
mn = x.min(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Min shape: {mn.shape}") # [2, 1]
-
Parameters:
-
axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example,
-1 is the last dimension). -
Returns:
-
A
TensorValue with the same rank as the input and the same shape except along axis, which will have size 1. -
Return type:
permute()
permute(dims)
Permutes the tensor’s dimensions based on provided indices.
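The permutation semantics match numpy.transpose with an explicit axes argument; the following numpy sketch (an analogy for the shape behavior, not MAX graph code) illustrates the result:

```python
import numpy as np

# A 2x3x4 array stands in for a symbolic tensor of the same shape.
x = np.zeros((2, 3, 4))

# Reordering dims with [2, 0, 1] moves the last dimension to the front,
# mirroring tensor.permute([2, 0, 1]) on a TensorValue.
y = np.transpose(x, axes=(2, 0, 1))
print(y.shape)  # (4, 2, 3)
```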
-
Parameters:
-
dims (list[int]) – A list of integers specifying the new order of dimensions.
-
Returns:
-
A new
TensorValue with permuted dimensions. -
Return type:
print()
print(label='debug_tensor')
Prints detailed information about the tensor.
-
Parameters:
-
label (str) – A string label for the printed output. Defaults to
debug_tensor. -
Return type:
-
None
rank
property rank: int
Returns the rank (number of dims) of the tensor.
The following example demonstrates how to access the rank of a tensor:
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix (2-dimensional array)
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("rank_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor rank (number of dimensions)
print(f"Rank: {tensor.rank}") # Output: 2
rebind()
rebind(shape, message='')
Rebinds the tensor to a new shape with error handling.
-
Parameters:
-
Returns:
-
A new
TensorValue with the updated shape. -
Return type:
reshape()
reshape(shape)
Creates a new tensor with the same data but reshaped.
The following example demonstrates how to reshape a tensor to change its dimensions:
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("reshape_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Reshape tensor to a 1x4 matrix
reshaped_tensor = tensor.reshape((1, 4))
print(f"Original shape: {tensor.shape}") # Output: [2, 2]
print(f"Reshaped shape: {reshaped_tensor.shape}") # Output: [1, 4]
shape
property shape: Shape
Returns the shape of the TensorValue.
The following example demonstrates how to access the shape of a tensor:
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
# Create a constant tensor from the matrix
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Access tensor shape
print(f"Shape: {tensor.shape}") # Shape: [Dim(2), Dim(2)]
stdev()
stdev(axis=-1)
Reduces the tensor using a standard deviation operation along axis.
The standard deviation is computed as the square root of the population variance along the specified axis.
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("stdev_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Standard deviation along axis 1 (last dimension of each row)
sd = x.stdev(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Stdev shape: {sd.shape}") # [2, 1]
-
Parameters:
-
axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example,
-1 is the last dimension). -
Returns:
-
A
TensorValue with the same rank as the input and the same shape except along axis, which will have size 1. -
Return type:
to()
to(device)
Inserts a graph-level transfer to device into the compiled graph.
This is a graph execution-time operation: it records a transfer
node during graph tracing that moves this symbolic tensor to device
when the compiled graph runs. It is equivalent to calling
transfer_to() and is typically used inside
forward() to route activation tensors between devices.
This is distinct from to(), which is a
pre-compilation host-side operation that moves stored weight
tensors before the graph is built. If you want to place a module’s
weights and computation on a device, use Module.to(device) before
calling compile().
The following example demonstrates how to move a tensor from one device to another:
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
with Graph("to_device_example") as graph:
# Create a tensor on the default device
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Move the tensor to a GPU device
gpu_tensor = tensor.to(DeviceRef.GPU())
print(f"Original device: {tensor.device}") # Output depends on default device
print(f"New device: {gpu_tensor.device}") # Output: gpu:0
-
Parameters:
-
device (DeviceRef) – A
DeviceRef object specifying the target device. -
Returns:
-
A new
TensorValue on the specified device. -
Return type:
transpose()
transpose(dim_1, dim_2)
Swaps two dimensions of the tensor.
The following example demonstrates how to transpose a tensor by swapping its dimensions:
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops, DeviceRef
# Create a 2x3 matrix
matrix = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
with Graph("transpose_demo") as graph:
tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())
# Transpose the tensor (swap dimensions 0 and 1)
transposed_tensor = tensor.transpose(dim_1=0, dim_2=1)
print(f"Original shape: {tensor.shape}") # Output: [2, 3]
print(f"Transposed shape: {transposed_tensor.shape}") # Output: [3, 2]
-
Parameters:
-
Returns:
-
A new
TensorValue with swapped dimensions. -
Return type:
type
property type: TensorType
Returns the type of the TensorValue as a TensorType.
var()
var(axis=-1)
Reduces the tensor using a variance operation along axis.
The variance is computed as the mean of squared deviations from the mean (population variance, i.e., without Bessel’s correction) along the specified axis.
from max.dtype import DType
from max.graph import Graph, TensorType, DeviceRef
# Define a 2x3 float32 input tensor for the graph
input_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
with Graph("var_demo", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
# Variance along axis 1 (last dimension of each row)
vr = x.var(axis=1)
print(f"Input shape: {x.shape}") # [2, 3]
print(f"Var shape: {vr.shape}") # [2, 1]
-
Parameters:
-
axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension (for example,
-1 is the last dimension). -
Returns:
-
A
TensorValue with the same rank as the input and the same shape except along axis, which will have size 1. -
Return type:
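As described above, var computes the population variance (no Bessel's correction), which corresponds to ddof=0 in numpy terms. A numpy sketch of the equivalent reduction (an analogy, not MAX graph code), keeping the reduced axis as size 1:

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 6.0, 8.0]], dtype=np.float32)

# Population variance (ddof=0) along the last axis with keepdims=True
# mirrors TensorValue.var(axis=-1): the reduced axis keeps size 1.
v = x.var(axis=-1, ddof=0, keepdims=True)
print(v.shape)  # (2, 1)
```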
Type
class max.nn.kernels.Type
Bases: Generic[MlirType]
The type of any value in a MAX graph.
Every value in the graph has a type, and that type is represented by a Type.
This type may be inspected to get finer-grained types and learn more
about an individual Value.
The following example shows how to work with types in a graph:
from max.graph import Graph, TensorType
from max.dtype import DType
with Graph() as g:
# Create a tensor constant with a specific type
tensor_type = TensorType(DType.float32, [2, 3])
# The type can be inspected to get information about the value
print(f"Tensor element type: {tensor_type.dtype}") # Outputs: DType.float32
print(f"Tensor shape: {tensor_type.shape}") # Outputs: [2, 3]
from_mlir()
static from_mlir(t)
Constructs a type from an MLIR type.
to_mlir()
to_mlir()
Converts to an mlir.Type instance.
-
Returns:
-
An
mlir.Type in the specified Context. -
Return type:
-
MlirType
Value
class max.nn.kernels.Value
Bases: Generic[MlirType]
Represents a symbolic value within a Graph.
A Value can represent the output of a node, the arguments of a
Graph (as seen from within its body), and more generally any symbolic
value available within the Graph. Other nodes receive Value
values as inputs to form a computation graph.
A Value may also refer to an existing input or output of a node,
and you can modify it, for example by swapping in a new Value.
Conceptually, think of a Value as an edge in the dataflow graph,
with the other end being the user of that value.
The following example shows how to work with Values in a graph to create a simple computation:
from max.graph import Graph, ops, Value, DeviceRef
from max.dtype import DType
import numpy as np
# Create a graph context
with Graph("value_example") as graph:
# Create input values
a = ops.constant(np.array([1, 2, 3]), dtype=DType.float32, device=DeviceRef.CPU())
b = ops.constant(np.array([4, 5, 6]), dtype=DType.float32, device=DeviceRef.CPU())
# Use values to perform operations
c = a + b # c is a Value representing the addition
# Demonstrate that the result is a Value
print(f"Type of c: {type(c)}")
print(f"Is c a Value? {isinstance(c, Value)}")
Similar to a regular variable, a Value has a data type.
Value is abstract; it shouldn't be constructed directly.
buffer
property buffer: BufferValue
Returns the Value as a BufferValue.
Raises an exception if the Value is not a BufferValue.
from_mlir()
classmethod from_mlir(value)
Creates a Value from an MLIR value.
opaque
property opaque: _OpaqueValue
Returns the Value as an _OpaqueValue.
Raises an exception if the Value is not a _OpaqueValue.
tensor
property tensor: TensorValue
Returns the Value as a TensorValue.
Raises an exception if the Value is not a TensorValue.
to_mlir()
to_mlir()
Converts the Value to an MLIR value.
-
Return type:
-
Value[MlirType]
type
property type: Type[MlirType]
WeightScaleSpec
class max.nn.kernels.WeightScaleSpec(granularity, dtype, block_size=None)
Bases: object
Specifies how weights are scaled for scaled quantization.
-
Parameters:
-
- granularity (ScaleGranularity)
- dtype (DType)
- block_size (tuple[int, int] | None)
block_size
block_size: tuple[int, int] | None
The block size for block-wise scaling, as a tuple[int, int], or None if scaling is not block-wise.
dtype
dtype: DType
The DType of the weight scale factor(s).
granularity
granularity: ScaleGranularity
The ScaleGranularity of the weight scale factor application.
is_block
property is_block: bool
Whether the weight scale granularity is block-wise.
is_colwise
property is_colwise: bool
Whether the weight scale granularity is column-wise.
is_rowwise
property is_rowwise: bool
Whether the weight scale granularity is row-wise.
is_tensor
property is_tensor: bool
Whether the weight scale granularity is per-tensor.
accelerator_architecture_name()
max.nn.kernels.accelerator_architecture_name()
Returns the architecture name of the accelerator device.
-
Return type:
apply_penalties_to_logits()
max.nn.kernels.apply_penalties_to_logits(logits_buffer, frequency_data, frequency_offsets, *, frequency_penalty=0.0, presence_penalty=0.0, repetition_penalty=1.0)
Applies penalties to the logits.
-
Parameters:
-
- logits_buffer (BufferValue) – The buffer to apply penalties to.
- frequency_data (TensorValue) – 2d tensor of shape [unique_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token.
- frequency_offsets (TensorValue) – 1d tensor of shape [batch_size + 1], indicating start of each sequence’s data.
- frequency_penalty (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text: tokens will receive a penalty proportional to the count of appearances.
- presence_penalty (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once by applying a constant penalty.
- repetition_penalty (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in prompt and generated text at least once by dividing the logits by the repetition penalty.
-
Return type:
-
None
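The penalty arithmetic described above can be sketched for a single sequence in plain numpy. This is a hypothetical reference implementation of the described behavior, not the kernel itself, and the argument layout (separate token_ids and counts lists instead of the packed frequency_data tensor) is simplified for illustration:

```python
import numpy as np

def apply_penalties(logits, token_ids, counts,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    repetition_penalty=1.0):
    """Penalize previously seen tokens for one sequence (illustrative)."""
    logits = logits.copy()
    for tok, count in zip(token_ids, counts):
        # Frequency penalty: proportional to the token's appearance count.
        logits[tok] -= frequency_penalty * count
        # Presence penalty: a constant once the token has appeared at all.
        logits[tok] -= presence_penalty
        # Repetition penalty: divide the logit, following the text above.
        # (Some implementations divide positive logits and multiply
        # negative ones; this sketch follows the description literally.)
        logits[tok] /= repetition_penalty
    return logits

logits = np.array([2.0, 1.0, 0.5])
# Token 0 appeared 3 times in the generated text.
out = apply_penalties(logits, token_ids=[0], counts=[3],
                      frequency_penalty=0.5, presence_penalty=0.25,
                      repetition_penalty=2.0)
# Token 0's logit: (2.0 - 0.5*3 - 0.25) / 2.0 = 0.125; others unchanged.
```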
assert_same_device()
max.nn.kernels.assert_same_device(*values, **named_values)
Raises ValueError if any of the given values are not on the same device.
-
Parameters:
-
- values (TensorValue | BufferValue)
- named_values (TensorValue | BufferValue)
-
Return type:
-
None
batched_dynamic_scaled_fp8_matmul()
max.nn.kernels.batched_dynamic_scaled_fp8_matmul(a, b, a_scales, b_scales, input_scale_spec, weight_scale_spec, out_type=bfloat16)
Performs a batched blockwise scaled matmul of two tensors with scaling factors.
-
Parameters:
-
- a (TensorValue) – The first tensor to multiply (3D tensor).
- b (TensorValue) – The second tensor to multiply, must be transposed (3D tensor).
- a_scales (TensorValue) – The scaling factors for the first tensor (3D tensor).
- b_scales (TensorValue) – The scaling factors for the second tensor (3D tensor).
- input_scale_spec (InputScaleSpec)
- weight_scale_spec (WeightScaleSpec)
- out_type (DType)
-
Returns:
-
The result of the matmul operation.
-
Return type:
block_scales_interleave()
max.nn.kernels.block_scales_interleave(scales, sf_vector_size=16)
Interleaves rank-2 FP4 block scales into the rank-5 TCGEN layout.
-
Parameters:
-
- scales (TensorValue) – Rank-2 block scales in [M, K // sf_vector_size] layout. Supported dtypes are float8_e4m3fn for NVFP4 and float8_e8m0fnu for MXFP4.
- sf_vector_size (int) – Scale-factor vector size: 16 for NVFP4 or 32 for MXFP4.
-
Returns:
-
The interleaved scales tensor in
[ceildiv(M, 128), ceildiv(K // sf_vector_size, 4), 32, 4, 4] layout. -
Return type:
ceildiv()
max.nn.kernels.ceildiv(n, d)
Ceiling division.
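A minimal pure-Python equivalent for positive integers (a sketch of the arithmetic, not the library's source):

```python
def ceildiv(n: int, d: int) -> int:
    # Ceiling division without floating point: -(n // -d) == ceil(n / d).
    return -(n // -d)

print(ceildiv(7, 4))  # 2
print(ceildiv(8, 4))  # 2
```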
compute_mha_decode_num_partitions()
max.nn.kernels.compute_mha_decode_num_partitions(batch_size, max_cache_valid_length, n_kv_heads, device)
Computes the MHA decode partition count inside a graph.
Wraps the mo.mha.decode.get_num_partitions kernel as a graph op so
that the partition heuristic can be evaluated dynamically during graph
execution rather than only at graph-build time.
-
Parameters:
-
- batch_size (TensorValue) – Scalar int64 tensor with the current batch size.
- max_cache_valid_length (TensorValue) – Scalar int64 tensor with the maximum valid cache length across all requests.
- n_kv_heads (int) – Number of key-value attention heads per device (compile-time constant).
- device (DeviceRef) – The
DeviceRefwhose hardware info determines the partition heuristic.
-
Returns:
-
A CPU
TensorValue of shape [1] and dtype int64 containing the computed partition count. -
Return type:
compute_mla_dispatch_args_scalar()
max.nn.kernels.compute_mla_dispatch_args_scalar(batch_size, max_cache_valid_length, q_max_seq_len, num_heads, device, is_fp8_kv=False)
Computes scalar dispatch arguments for the MLA decode kernel.
Produces a CPU tensor of shape [3] containing pre-computed integer
arguments used by the capturable MLA decode kernel variant to enable CUDA
graph capture.
-
Parameters:
-
- batch_size (TensorValue) – Scalar tensor indicating the current batch size.
- max_cache_valid_length (TensorValue) – Scalar tensor with the maximum valid cache sequence length across all requests in the batch.
- q_max_seq_len (TensorValue) – Scalar tensor with the maximum query sequence length in the current batch.
- num_heads (int) – Number of query attention heads.
- device (DeviceRef) – The
DeviceRefon which to run the op. - is_fp8_kv (bool)
-
Returns:
-
A CPU
TensorValue of shape [3] and dtype int64 containing the dispatch scalar arguments. -
Return type:
convert_weights_to_fp8_fnuz_if_needed()
max.nn.kernels.convert_weights_to_fp8_fnuz_if_needed(weight, weight_scale)
Converts weights and scales to FP8 FNUZ format if needed for AMD GPUs.
This utility function checks whether FP8 FNUZ conversion is needed (currently only on AMD MI300 GPUs) and performs the conversion if required. This centralizes conversion logic that was previously duplicated across multiple files.
-
Parameters:
-
- weight (TensorValue) – The weight tensor to potentially convert.
- weight_scale (TensorValue) – The weight scale factor.
-
Returns:
-
Tuple of (weight, weight_scale) - converted if needed, original otherwise.
-
Return type:
cross_attention_ragged()
max.nn.kernels.cross_attention_ragged(kv_params, input, input_row_offsets, kv_collection, layer_idx, mask_variant, kv_input_row_offsets, q_max_seq_len, scale, local_window_size=-1)
Computes cross attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input. Note that this is cross
attention, so kv_input_row_offsets represents the KV sequence length.
-
Parameters:
-
- kv_params (KVCacheParams)
- input (TensorValue)
- input_row_offsets (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
- mask_variant (MHAMaskVariant)
- kv_input_row_offsets (TensorValue)
- q_max_seq_len (TensorValue)
- scale (float)
- local_window_size (int)
-
Return type:
dynamic_block_scaled_matmul_fp4()
max.nn.kernels.dynamic_block_scaled_matmul_fp4(a, b, a_scales, b_scales, tensor_sf, sf_vector_size=16, out_type=bfloat16)
Performs a matmul of two FP4 tensors with 1D-block scaled scaling factors.
-
Parameters:
-
- a (TensorValue) – The first tensor to multiply.
- b (TensorValue) – The second tensor to multiply, must be transposed.
- a_scales (TensorValue) – The scaling factors for the first tensor.
- b_scales (TensorValue) – The scaling factors for the second tensor.
- tensor_sf (TensorValue | float) – Buffer-wise scaling factor equal to weight_scale_2 * input_scale (non-inverted).
- sf_vector_size (int)
- out_type (DType)
-
Returns:
-
The result of the matmul operation.
-
Return type:
dynamic_block_scaled_matmul_mxfp4()
max.nn.kernels.dynamic_block_scaled_matmul_mxfp4(a, b, a_scales, b_scales, out_type=bfloat16)
Performs a matmul of two FP4 tensors with 1D-block scaled scaling factors.
-
Parameters:
-
- a (TensorValue) – The first tensor to multiply.
- b (TensorValue) – The second tensor to multiply, must be transposed.
- a_scales (TensorValue) – The scaling factors for the first tensor.
- b_scales (TensorValue) – The scaling factors for the second tensor.
- out_type (DType)
-
Returns:
-
The result of the matmul operation.
-
Return type:
dynamic_scaled_matmul()
max.nn.kernels.dynamic_scaled_matmul(a, b, a_scales, b_scales, input_scale_spec, weight_scale_spec, out_type=bfloat16)
Performs a matmul of two tensors with scaling factors. Currently only supports channel-wise scaling for weights and per-token scaling for inputs.
-
Parameters:
-
- a (TensorValue) – The first tensor to multiply.
- b (TensorValue) – The second tensor to multiply, must be transposed.
- a_scales (TensorValue) – The scaling factors for the first tensor.
- b_scales (TensorValue) – The scaling factors for the second tensor.
- input_scale_spec (InputScaleSpec)
- weight_scale_spec (WeightScaleSpec)
- out_type (DType)
-
Returns:
-
The result of the matmul operation.
-
Return type:
eagle_prefill_shift_tokens()
max.nn.kernels.eagle_prefill_shift_tokens(tokens, offsets, shift_next_tokens)
Shifts ragged tokens left by 1 per request, appending bonus tokens.
-
Parameters:
-
- tokens (TensorValue) – Flat ragged token sequence of shape
[total_seq_len], dtype int64. - offsets (TensorValue) – Row offsets of shape
[batch_size + 1], dtype uint32. - shift_next_tokens (TensorValue) – One token per request of shape
[batch_size], dtype int64, to append after shifting.
-
Returns:
-
Shifted (or copied) tokens with the same shape as
tokens. -
Return type:
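The shift described above can be illustrated with a numpy sketch (illustrative only; the actual op runs inside the compiled graph):

```python
import numpy as np

def shift_tokens(tokens, offsets, shift_next_tokens):
    """For each request [offsets[i], offsets[i+1]), drop the first token
    and append that request's bonus token at the end."""
    out = tokens.copy()
    for i in range(len(offsets) - 1):
        start, end = int(offsets[i]), int(offsets[i + 1])
        out[start:end - 1] = tokens[start + 1:end]
        out[end - 1] = shift_next_tokens[i]
    return out

tokens = np.array([10, 11, 12, 20, 21], dtype=np.int64)  # two requests
offsets = np.array([0, 3, 5], dtype=np.uint32)           # rows [0, 3) and [3, 5)
bonus = np.array([99, 88], dtype=np.int64)               # one bonus token each
print(shift_tokens(tokens, offsets, bonus))  # [11 12 99 21 88]
```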
flare_mla_decode_ragged()
max.nn.kernels.flare_mla_decode_ragged(kv_params, input, input_row_offsets, kv_collection, layer_idx, mask_variant, scale, scalar_args, *, qk_rope_dim=64)
Computes flash (self) attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
-
Parameters:
-
- kv_params (KVCacheParams)
- input (TensorValue)
- input_row_offsets (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
- mask_variant (MHAMaskVariant)
- scale (float)
- scalar_args (TensorValue)
- qk_rope_dim (int)
-
Return type:
flare_mla_decode_ragged_scaled()
max.nn.kernels.flare_mla_decode_ragged_scaled(kv_params, input, input_row_offsets, kv_collection, kv_scales, q_scales, layer_idx, mask_variant, scale, scalar_args, qk_rope_dim=64, per_token_scale_rope_aware=False, quantization_granularity=640)
MLA decode with explicit per-token KV and Q scale tensors.
Like flare_mla_decode_ragged but accepts explicit scale tensors so the
per-token-scale rope-aware kernel receives real (non-identity) scales.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache parameters.
- input (TensorValue) – Query tensor [total_tokens, num_heads, head_dim].
- input_row_offsets (TensorValue) – Ragged row offsets [batch_size + 1].
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection.
- kv_scales (BufferValue) – Per-token KV scales buffer [num_blocks, 1, 1, page_size, 1, 1] float32.
- q_scales (TensorValue) – Per-token Q scales tensor [total_tokens] float32.
- layer_idx (TensorValue) – Layer index (uint32, on CPU).
- mask_variant (MHAMaskVariant) – Attention mask variant.
- scale (float) – Softmax scale (typically 1/sqrt(d_qk)).
- qk_rope_dim (int) – Rope head dimension (default 64).
- per_token_scale_rope_aware (bool) – Use FP8+BF16 interleaved layout.
- quantization_granularity (int) – Granularity for KV scale quantization. Should equal the KV cache head_dim (640 for rope-aware).
- scalar_args (TensorValue)
-
Returns:
-
Output tensor [total_tokens, num_heads, output_dim].
-
Return type:
flare_mla_decompress_k_cache()
max.nn.kernels.flare_mla_decompress_k_cache(kv_params, buffer_row_offsets_1d, cache_offsets_1d, buffer_length, weight, kv_collection, layer_idx, buffer_size)
This kernel decompresses the key cache by up-projecting latent representations into the KV space using a weight matrix.
The process involves:
- Copying buffer_length latent vectors from the key cache into a contiguous buffer (k_latent)
- Computing k = k_latent @ weight.T to obtain the decompressed keys
-
Returns:
-
A tensor of shape [buffer_size, weight.shape[0]] containing the decompressed keys. Note that only the first buffer_length tokens are valid.
-
Parameters:
-
- kv_params (KVCacheParams)
- buffer_row_offsets_1d (TensorValue)
- cache_offsets_1d (TensorValue)
- buffer_length (TensorValue)
- weight (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
- buffer_size (int)
-
Return type:
flare_mla_prefill_plan()
max.nn.kernels.flare_mla_prefill_plan(kv_params, input_row_offsets, kv_collection, layer_idx, buffer_size, max_chunks=16)
This kernel plans how to process a batch of sequences with varying lengths using a fixed-size buffer.
Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer_size.
For each chunk (iteration), it calculates:
- Buffer offsets for each sequence in each chunk
- Cache offsets for each sequence in each chunk
- Total buffer lengths for each processing iteration
-
Parameters:
-
- kv_params (KVCacheParams)
- input_row_offsets (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
- buffer_size (int)
- max_chunks (int)
-
Return type:
flare_mla_prefill_ragged()
max.nn.kernels.flare_mla_prefill_ragged(kv_params, input, k, v, input_row_offsets, buffer_row_offsets, cache_offsets, kv_collection, layer_idx, mask_variant, scale, qk_rope_dim=64)
Performs MLA prefill. MLA prefill must decompress the KV tensors, since the KV cache stores latent representations. The KV tensors are decompressed into a fixed-size buffer to avoid out-of-memory errors; if the total cache length exceeds the buffer size, the attention calculation is processed in chunks.
This kernel returns the output tensor and the softmax info tensor for this iteration. Both are consumed by the next iteration of the MLA prefill kernel to continue the attention calculation.
-
Parameters:
-
- kv_params (KVCacheParams) – KVCacheParams
- input (TensorValue) – Input tensor
- k (TensorValue) – Key tensor
- v (TensorValue) – Value tensor
- input_row_offsets (TensorValue) – Indicates where each batch starts and ends in input
- buffer_row_offsets (TensorValue) – Indicates where each batch starts and ends in the buffer
- cache_offsets (TensorValue) – Indicates where each batch starts and ends in the KV cache
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – KV collection
- layer_idx (TensorValue) – Layer index tensor
- mask_variant (MHAMaskVariant) – Mask variant
- scale (float) – Scale
- qk_rope_dim (int) – QK rope dimension
-
Returns:
-
The output tensor for this iteration
-
Return type:
flash_attention_gpu()
max.nn.kernels.flash_attention_gpu(q, k, v, mask_variant, scale, local_window_size=-1, valid_length=None)
Computes flash attention using GPU-optimized kernel.
-
Parameters:
-
- q (TensorValue) – Query tensor of shape [batch, seq_len, num_heads, head_dim]
- k (TensorValue) – Key tensor of shape [batch, seq_len, num_heads, head_dim]
- v (TensorValue) – Value tensor of shape [batch, seq_len, num_heads, head_dim]
- mask_variant (MHAMaskVariant) – The mask variant to use for attention
- scale (float) – Scaling factor for attention scores
- local_window_size (int) – Local window size for sliding window attention
- valid_length (TensorValue | None) – Optional tensor of shape [batch] with dtype uint32. When provided, uses the padded kernel variant that respects the valid sequence lengths for each batch element.
-
Returns:
-
Output tensor of shape [batch, seq_len, num_heads, head_dim]
-
Return type:
flash_attention_padded_kv_cache()
max.nn.kernels.flash_attention_padded_kv_cache(kv_params, q, kv_collection, layer_idx, valid_lengths, mask_variant, scale, local_window_size=-1)
Computes flash attention with padded inputs and paged KV cache.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache parameters
- q (TensorValue) – Query tensor of shape [batch, seq_len, num_heads, head_dim]
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection
- layer_idx (TensorValue) – Layer index for cache lookup
- valid_lengths (TensorValue) – Buffer of shape [batch] with dtype uint32 indicating actual (non-padded) sequence lengths for each batch element
- mask_variant (MHAMaskVariant) – The mask variant to use for attention
- scale (float) – Scaling factor for attention scores
- local_window_size (int) – Local window size for sliding window attention
-
Returns:
-
Output tensor of shape [batch, seq_len, num_heads, head_dim]
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
flash_attention_ragged()
max.nn.kernels.flash_attention_ragged(kv_params, input, input_row_offsets, kv_collection, layer_idx, mask_variant, scale, local_window_size=-1, sink_weights=None)
Computes flash (self) attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
-
Parameters:
-
- kv_params (KVCacheParams) – KVCacheParams object containing key-value cache parameters.
- input (TensorValue) – TensorValue representing the input tensor with shape [total_seq_len, hidden_dim].
- input_row_offsets (TensorValue) – TensorValue indicating the start and end of each batch in the input tensor with shape [batch_size + 1].
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – PagedCacheValues object for managing key-value cache.
- layer_idx (TensorValue) – TensorValue representing the layer index, expected to have dtype uint32.
- mask_variant (MHAMaskVariant) – MHAMaskVariant specifying the type of attention mask to use.
- scale (float) – float value used to scale the attention scores.
- local_window_size (int) – int specifying the size of the local attention window, default is -1 for no local window.
- sink_weights (TensorValue | None) – Optional tensor of shape [num_heads] containing learnable sink weights for each attention head.
-
Return type:
flash_attention_ragged_gpu()
max.nn.kernels.flash_attention_ragged_gpu(q, k, v, input_row_offsets, max_seq_len, mask_variant, scale, local_window_size=-1)
Computes flash attention for ragged inputs using GPU-optimized kernel without a KV cache.
-
Parameters:
-
- q (TensorValue) – Query tensor of shape [total_seq_len, num_heads, head_dim] (ragged)
- k (TensorValue) – Key tensor of shape [total_seq_len, num_heads, head_dim] (ragged)
- v (TensorValue) – Value tensor of shape [total_seq_len, num_heads, head_dim] (ragged)
- input_row_offsets (TensorValue) – Buffer of shape [batch_size + 1] with dtype uint32. Indicates where each sequence starts and ends in the ragged tensors. The values should be a prefix sum (cumulative sum) of sequence lengths.
- mask_variant (MHAMaskVariant) – The mask variant to use for attention
- scale (float) – Scaling factor for attention scores
- local_window_size (int) – Local window size for sliding window attention
- max_seq_len (TensorValue) – Maximum sequence length across the sequences in the batch
-
Returns:
-
Output tensor of shape [total_seq_len, num_heads, head_dim]
-
Return type:
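The prefix-sum convention for input_row_offsets can be sketched in plain NumPy (illustrative only; the actual arguments are graph TensorValues):

```python
import numpy as np

# Sequence lengths for a batch of 3 ragged sequences.
seq_lens = np.array([2, 4, 1], dtype=np.uint32)

# input_row_offsets is the cumulative sum of lengths with a leading 0,
# giving shape [batch_size + 1].
input_row_offsets = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.uint32)

# Sequence i occupies rows [input_row_offsets[i], input_row_offsets[i + 1]).
print(input_row_offsets)  # [0 2 6 7]
```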
fused_qk_padded_rope()
max.nn.kernels.fused_qk_padded_rope(kv_params, input, kv_collection, freqs_cis, layer_idx, valid_lengths, interleaved=True)
Computes fused query-key RoPE with padded inputs and paged KV cache.
This function applies Rotary Positional Embeddings (RoPE) to both Q and K tensors, where K is stored in the paged KV cache. This is the padded equivalent of fused_qk_ragged_rope.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache parameters.
- input (TensorValue) – Query tensor of shape [batch, seq_len, n_heads, head_dim].
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection.
- freqs_cis (TensorValue) – Frequency tensor of shape (max_seq_len * 2, head_dim).
- layer_idx (TensorValue) – Layer index for KV cache (must be uint32 on CPU).
- valid_lengths (TensorValue) – Buffer of shape [batch] containing the valid length for each sequence (must be uint32). RoPE is only applied to positions within these lengths.
- interleaved (bool) – Whether to use interleaved RoPE pattern.
-
Returns:
-
Query tensor with RoPE applied, same shape as input.
-
Return type:
fused_qk_ragged_rope()
max.nn.kernels.fused_qk_ragged_rope(kv_params, input, input_row_offsets, kv_collection, freqs_cis, layer_idx, interleaved=True, position_ids=None, mrope_section=None)
Computes fused query-key attention with rotary positional encodings and ragged inputs.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache parameters
- input (TensorValue) – [batch_size * seq_len, n_heads, head_dim]
- input_row_offsets (TensorValue) – Ragged tensor offsets indicating where each batch starts and ends
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – KV cache collection
- freqs_cis (TensorValue) – tensor of shape (max_seq_len * 2, head_dim)
- layer_idx (TensorValue) – Layer index for KV cache
- interleaved (bool) – Whether to use interleaved RoPE pattern
- position_ids (TensorValue | None) – Optional ragged 2D array of position IDs. If None, defaults to cache_length + token_idx for each token. When num_sections > 1, mrope_section must be provided to indicate each section of the head_dim to apply RoPE to. Shape: [num_sections, total_seq_len]
- mrope_section (list[int] | None) – Optional list of integers indicating the sections of the head_dim to apply RoPE to. Must be used in conjunction with position_ids.
-
Return type:
input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input. If input is not of the same dtype as freqs_cis, it will be cast to the dtype of freqs_cis for the computation, and cast back to the original dtype after the computation is finished.
When position_ids and mrope_section are provided, it replaces the default position calculation (cache_length + token_idx) with explicit position values. This is useful for 3D RoPE in models like Qwen2.5-VL that need custom position encoding.
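The default position computation (cache_length + token_idx for each token of each ragged sequence) can be sketched in NumPy (a sketch of the documented semantics, not the kernel itself):

```python
import numpy as np

# Two ragged sequences: lengths 2 and 3, packed into one tensor.
input_row_offsets = np.array([0, 2, 5], dtype=np.uint32)
# Tokens already present in the KV cache for each sequence.
cache_lengths = np.array([10, 0], dtype=np.uint32)

# Default position for each new token: cache_length + index within sequence.
positions = np.concatenate([
    cache_lengths[i]
    + np.arange(input_row_offsets[i + 1] - input_row_offsets[i])
    for i in range(len(cache_lengths))
])
print(positions)  # [10 11  0  1  2]
```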
fused_qkv_padded_matmul()
max.nn.kernels.fused_qkv_padded_matmul(kv_params, input, wqkv, kv_collection, layer_idx, valid_lengths, n_heads)
Computes fused query, key, and value projections with padded input.
This is for non-ragged (padded batch) inputs where sequences may have different actual lengths but are padded to a uniform shape.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache parameters.
- input (TensorValue) – Input tensor with shape [batch_size, seq_len, hidden_dim].
- wqkv (TensorValue) – Weight tensor for Q, K, V projections.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection.
- layer_idx (TensorValue) – Layer index for cache lookup (must be uint32).
- valid_lengths (TensorValue) – Buffer of shape [batch] containing the valid length for each sequence (must be uint32). K and V are only written to cache for positions within these lengths.
- n_heads (int) – Number of attention heads.
-
Returns:
-
Query projections tensor. K and V projections are written to cache.
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
fused_qkv_ragged_matmul()
max.nn.kernels.fused_qkv_ragged_matmul(kv_params, input, input_row_offsets, wqkv, kv_collection, layer_idx, n_heads, bias=None, _output_dim=None)
Computes fused query, key, and value projections with ragged input.
-
Parameters:
-
- kv_params (KVCacheParams) – KVCacheParams object containing key-value cache parameters.
- input (TensorValue) – TensorValue representing the input tensor with shape [total_seq_len, hidden_dim].
- input_row_offsets (TensorValue) – TensorValue indicating the start and end of each request in the input tensor with shape [batch_size + 1].
- wqkv (TensorValue) – The concatenated Q, K and V projection weights.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – PagedCacheValues object for managing key-value cache.
- layer_idx (TensorValue) – TensorValue representing the layer index, expected to have dtype uint32.
- n_heads (int) – Number of Query attention heads.
- bias (TensorValue | None) – Optional bias vector concatenated as [q, k, v].
- _output_dim (int | None) – Optional output dimension. If not provided, the output dimension will be [n_heads * head_dim].
-
Returns:
-
Query projection tensor.
-
Return type:
fused_qkv_ragged_matmul_quantized()
max.nn.kernels.fused_qkv_ragged_matmul_quantized(kv_params, input, input_row_offsets, wqkv, kv_collection, layer_idx, n_heads, quantization_config, perm_idx=None, bias=None)
Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.
input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Parameters:
-
- kv_params (KVCacheParams)
- input (TensorValue)
- input_row_offsets (TensorValue)
- wqkv (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
- n_heads (int)
- quantization_config (QuantizationConfig)
- perm_idx (TensorValue | None)
- bias (TensorValue | None)
-
Return type:
grouped_dynamic_scaled_fp8_matmul()
max.nn.kernels.grouped_dynamic_scaled_fp8_matmul(hidden_states, weight, a_scales, b_scales, expert_start_indices, expert_ids, expert_usage_stats_host, input_scale_spec, weight_scale_spec, out_type=bfloat16, tokens_padded_per_expert=False)
Grouped blockwise scaled matmul used in MoE layer.
Performs a grouped blockwise scaled matmul of two tensors with scaling factors. hidden_states and expert_start_indices are used together to implement the ragged tensor.
-
Parameters:
-
- hidden_states (TensorValue) – The first tensor to multiply. (2D tensor)
- weight (TensorValue) – The second tensor to multiply, must be transposed. (3D tensor)
- a_scales (TensorValue) – The scaling factors for the first tensor. (2D tensor)
- b_scales (TensorValue) – The scaling factors for the second tensor. (3D tensor)
- expert_start_indices (TensorValue) – indicates where each group starts and ends in hidden_states.
- expert_ids (TensorValue) – The id of the expert for each group in hidden_states.
- expert_usage_stats_host (TensorValue) – The maximum number of tokens assigned to any expert, and the number of active experts.
- input_scale_spec (InputScaleSpec) – The scaling granularity for the input tensor.
- weight_scale_spec (WeightScaleSpec) – The scaling granularity for the weight tensor.
- tokens_padded_per_expert (bool) – If True, it is guaranteed that the number of tokens for each local expert is padded so that a_scales is aligned to 16 bytes. This is required by the optimized grouped matmul kernel.
- out_type (DType)
-
Returns:
-
The result of the matmul operation.
-
Return type:
grouped_dynamic_scaled_mxfp4_matmul()
max.nn.kernels.grouped_dynamic_scaled_mxfp4_matmul(hidden_states, weight, a_scales, b_scales, expert_start_indices, expert_ids, expert_usage_stats_host, out_type=bfloat16, estimated_total_m=None)
Performs grouped MXFP4 matmul for MoE layers.
Performs a grouped matmul with MXFP4 (4-bit) quantized inputs and weights. The inputs are packed as uint8 (2 MXFP4 values per byte) with float8_e8m0fnu scaling factors. MXFP4 uses fixed 1D block scaling with 32 elements per scale factor along the K dimension.
hidden_states and expert_start_indices together implement the ragged
tensor representation for variable-length expert inputs.
-
Parameters:
-
- hidden_states (TensorValue) – The input activations with shape [total_tokens, K/2], where K is the unpacked hidden dimension. Dtype must be uint8 (packed MXFP4).
- weight (TensorValue) – The expert weights with shape [num_experts, N, K/2]. Dtype must be uint8 (packed MXFP4).
- a_scales (TensorValue) – Scaling factors for inputs with shape [num_scale_rows, K/32]. Dtype must be float8_e8m0fnu.
- b_scales (TensorValue) – Scaling factors for weights with shape [num_experts, N, K/32]. Dtype must be float8_e8m0fnu.
- expert_start_indices (TensorValue) – Indices indicating where each expert’s tokens start in hidden_states.
- expert_ids (TensorValue) – The expert ID for each group.
- expert_usage_stats_host (TensorValue) – A tensor containing [max_tokens_per_expert, num_active_experts].
- out_type (DType) – Output dtype. Defaults to bfloat16.
- estimated_total_m (TensorValue | None) – The estimated total number of tokens.
-
Returns:
-
The matmul result with shape [total_tokens, N] and dtype out_type.
-
Return type:
grouped_matmul_block_scaled()
max.nn.kernels.grouped_matmul_block_scaled(hidden_states, weight, a_scales, b_scales, expert_start_indices, a_scale_offsets, expert_ids, expert_scales, expert_usage_stats_host, out_type=bfloat16, estimated_total_m=None)
Performs grouped NVFP4 matmul for MoE layers.
Performs a grouped matmul with NVFP4 (4-bit) quantized inputs and weights. The inputs are packed as uint8 (2 NVFP4 values per byte) with float8_e4m3fn scaling factors. NVFP4 uses fixed 1D block scaling with 16 elements per scale factor along the K dimension.
hidden_states and expert_start_indices together implement the ragged
tensor representation for variable-length expert inputs.
-
Parameters:
-
- hidden_states (TensorValue) – The input activations with shape [total_tokens, K/2], where K is the unpacked hidden dimension. Dtype must be uint8 (packed NVFP4).
- weight (TensorValue) – The expert weights with shape [num_experts, N, K/2]. Dtype must be uint8 (packed NVFP4).
- a_scales (TensorValue) – Scaling factors for inputs with shape [num_scale_rows, K_groups, 32, 4, 4]. Dtype must be float8_e4m3fn.
- b_scales (TensorValue) – Scaling factors for weights with shape [num_experts, N_groups, K_groups, 32, 4, 4]. Dtype must be float8_e4m3fn.
- expert_start_indices (TensorValue) – Indices indicating where each expert’s tokens start in hidden_states.
- a_scale_offsets (TensorValue) – The offsets of the input scale tiles for each expert.
- expert_ids (TensorValue) – The expert ID for each group.
- expert_scales (TensorValue) – Per-expert scaling factors with shape [num_experts]. Dtype must be float32. Multiplied with the matmul output in the epilogue.
- expert_usage_stats_host (TensorValue) – A tensor containing [max_tokens_per_expert, num_active_experts].
- out_type (DType) – Output dtype. Defaults to bfloat16.
- estimated_total_m (TensorValue | None) – The estimated total number of tokens.
-
Returns:
-
The matmul result with shape [total_tokens, N] and dtype out_type.
-
Return type:
grouped_matmul_ragged()
max.nn.kernels.grouped_matmul_ragged(hidden_states, weight, expert_start_indices, expert_ids, expert_usage_stats_host)
Grouped matmul used in MoE layers.
hidden_states and expert_start_indices are used together to implement the ragged tensor. expert_start_indices indicates where each group starts and ends in hidden_states.
expert_ids is the id of the expert for each group in hidden_states.
expert_usage_stats_host contains the maximum number of tokens assigned to any expert and the number of active experts.
-
Parameters:
-
- hidden_states (TensorValue)
- weight (TensorValue)
- expert_start_indices (TensorValue)
- expert_ids (TensorValue)
- expert_usage_stats_host (TensorValue)
-
Return type:
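The grouping semantics can be sketched with a NumPy reference loop (an illustration of the documented behavior, not the fused kernel): rows [expert_start_indices[g], expert_start_indices[g+1]) of hidden_states are multiplied by the weight of expert expert_ids[g].

```python
import numpy as np

hidden_states = np.arange(12, dtype=np.float32).reshape(4, 3)   # [tokens, K]
weight = np.stack([np.eye(3, dtype=np.float32),
                   2 * np.eye(3, dtype=np.float32)])            # [experts, N, K]
expert_start_indices = np.array([0, 1, 4])  # group g = rows [start[g], start[g+1])
expert_ids = np.array([1, 0])               # group 0 -> expert 1, group 1 -> expert 0

out = np.empty_like(hidden_states)
for g in range(len(expert_ids)):
    lo, hi = expert_start_indices[g], expert_start_indices[g + 1]
    out[lo:hi] = hidden_states[lo:hi] @ weight[expert_ids[g]].T
print(out[0])  # first token uses expert 1, whose weights scale by 2
```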
grouped_quantize_dynamic_block_scaled_fp4()
max.nn.kernels.grouped_quantize_dynamic_block_scaled_fp4(input, row_offsets, scales_offsets, expert_ids, sf_tensor, sf_vector_size=16, scales_type=float8_e4m3fn, out_type=uint8)
Grouped dynamic FP4 quantization for MoE experts.
Quantizes a concatenated token tensor where different row ranges belong to different experts, each with its own tensor-wise scale factor.
-
Parameters:
-
- input (TensorValue) – The concatenated input tensor. Shape: [total_tokens, K], dtype bfloat16.
- row_offsets (TensorValue) – Cumulative token offsets per expert. Shape: [num_experts + 1], dtype uint32.
- scales_offsets (TensorValue) – Per-expert scale tile offset corrections. Shape: [num_experts], dtype uint32.
- expert_ids (TensorValue) – Expert ID mapping (typically identity). Shape: [num_experts], dtype int32.
- sf_tensor (TensorValue) – Per-expert tensor-wise scale factors. Shape: [num_experts], dtype float32.
- sf_vector_size (int) – The block size for the scaling factors.
- scales_type (DType) – Scale factor dtype. float8_e4m3fn for NVFP4.
- out_type (DType) – Output dtype. uint8 for packed FP4.
-
Returns:
-
The quantized tensor of shape [total_tokens, K // 2] and scales in rank-5 interleaved layout [total_m_tiles, K_tiles, 32, 4, 4].
-
Return type:
kv_cache_copy_pages_d2h()
max.nn.kernels.kv_cache_copy_pages_d2h(device_kv_collection, device_page_ids, host_kv_blocks, host_page_ids, layer_idx, device_ref)
Copy KV cache pages from GPU to CPU for a single layer.
Performs async GPU->CPU copy of specified pages for layer-wise KV cache offloading.
-
Parameters:
-
- device_kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Source KV cache on GPU.
- device_page_ids (TensorValue) – Source page IDs to read from GPU.
- host_kv_blocks (BufferValue) – Destination KV cache blocks on CPU.
- host_page_ids (TensorValue) – Destination page IDs to write to CPU. Must have same length as device_page_ids.
- layer_idx (int) – Which layer to copy.
- device_ref (DeviceRef) – Device for the GPU context.
-
Return type:
-
None
kv_cache_ragged_2m_iadd()
max.nn.kernels.kv_cache_ragged_2m_iadd(kv_params, a, kv_collection, input_row_offsets, lora_end_idx, batch_seq_len, layer_idx)
In-place add to paged KV cache with interleaved K/V layout.
Performs an in-place addition of new key-value projections to paged KV cache. The input tensor a uses a “2M” layout where keys and values are interleaved: rows [0, m) contain keys and rows [m, 2m) contain values, where m is the number of tokens.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache configuration parameters.
- a (TensorValue) – Input tensor with interleaved K/V data, shape (2*m, hidden_size) where m is the number of tokens. Rows [0, m) are keys, rows [m, 2m) are values.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection containing cache blocks, cache lengths, lookup tables, and max lengths tensors.
- input_row_offsets (TensorValue) – Ragged tensor offsets indicating where each batch starts and ends
- lora_end_idx (TensorValue) – End index of LoRA token portion. Marks the boundary between LoRA sequences and base model sequences in the batch.
- batch_seq_len (TensorValue) – Total sequence length in the batch. Used for indexing into the value portion of a.
- layer_idx (TensorValue) – The transformer layer index to update in the KV cache.
-
Raises:
-
- ValueError – If a does not have rank 2.
- ValueError – If input_row_offsets does not have rank 1.
-
Return type:
-
None
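The "2M" interleaved layout can be illustrated with NumPy slicing (a sketch of the layout described above; the real argument is a graph TensorValue):

```python
import numpy as np

# m tokens; rows [0, m) hold keys, rows [m, 2m) hold values.
m, hidden_size = 3, 4
a = np.arange(2 * m * hidden_size, dtype=np.float32).reshape(2 * m, hidden_size)

keys = a[:m]          # K projections for all m tokens
values = a[m:2 * m]   # V projections for the same m tokens
assert keys.shape == values.shape == (m, hidden_size)
```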
kv_cache_ragged_radd()
max.nn.kernels.kv_cache_ragged_radd(kv_params, a, kv_collection, input_row_offsets, batch_offset, layer_idx)
This function adds a tensor to a slice of the KVCache, sliced on the batch dimension.
This expects that the requests which should be sliced out are contiguous and in the front of the tensor, and we’re only adding to the last requests in the batch.
-
Parameters:
-
- a (TensorValue) – The tensor to add to the KVCache.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The KVCache collection to add to.
- input_row_offsets (TensorValue) – The offsets of the input tensor.
- batch_offset (TensorValue) – The batch to start applying the r-add to.
- layer_idx (int) – The layer index to add to.
- kv_params (KVCacheParams)
-
Return type:
-
None
kv_cache_store_paged_padded()
max.nn.kernels.kv_cache_store_paged_padded(kv_collection, x_cache, valid_lengths, layer_idx, *, key_or_value)
Stores key or value tensor into the paged KV cache (padded inputs).
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- x_cache (TensorValue)
- valid_lengths (TensorValue)
- layer_idx (TensorValue)
- key_or_value (int)
-
Return type:
-
None
kv_cache_store_paged_ragged()
max.nn.kernels.kv_cache_store_paged_ragged(kv_collection, x_cache, input_row_offsets, layer_idx, *, key_or_value)
Stores key or value tensor into the paged KV cache (ragged inputs).
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- x_cache (TensorValue)
- input_row_offsets (TensorValue)
- layer_idx (TensorValue)
- key_or_value (int)
-
Return type:
-
None
learnable_2d_interp_pos_emb()
max.nn.kernels.learnable_2d_interp_pos_emb(x, weight, grid_thws, time_weight)
Applies learnable 2D interpolated position embedding (Kimi K2.5).
For each video described by grid_thws, bicubic-interpolates weight
from (H, W) to (h, w), optionally adds temporal sincos embedding when
t > 1, and adds the result element-wise to x.
-
Parameters:
-
- x (TensorValue) – Patch embeddings of shape (L, dim).
- weight (TensorValue) – Learnable 2D grid of shape (H, W, dim).
- grid_thws (TensorValue) – Per-video (t, h, w) of shape (N, 3), dtype int64.
- time_weight (TensorValue) – 1D sincos temporal embedding of shape (num_frames, dim), dtype float32.
-
Returns:
-
Tensor of shape (L, dim) with position embeddings added.
-
Raises:
-
ValueError – On invalid input shapes or dtypes.
-
Return type:
lmcache_offload()
max.nn.kernels.lmcache_offload(output, paged_cache, slot_mapping, start_token, end_token, page_size, num_kv_heads, head_dim, kv_dim, device_ref)
Offload KV cache data from paged format to external contiguous format.
Used by LMCache connector to copy KV data from MAX’s paged cache layout to LMCache’s contiguous KV_2LTD format for external storage.
-
Parameters:
-
- output (BufferValue) – Output buffer [kv_dim, num_layers, num_tokens, hidden_dim] where hidden_dim = num_kv_heads * head_dim.
- paged_cache (TensorValue) – Input tensor [total_num_blocks, kv_dim, num_layers, page_size, num_kv_heads, head_dim].
- slot_mapping (TensorValue) – Token to slot mapping [total_tokens].
- start_token (TensorValue) – Starting token index scalar [1].
- end_token (TensorValue) – Ending token index (exclusive) scalar [1].
- page_size (int) – Number of tokens per page in the paged cache.
- num_kv_heads (int) – Number of KV attention heads.
- head_dim (int) – Dimension of each attention head.
- kv_dim (int) – KV dimension (2 for standard K/V, 1 for MLA).
- device_ref (DeviceRef) – Device reference for the operation.
-
Return type:
-
None
lmcache_onload()
max.nn.kernels.lmcache_onload(paged_cache, input_tensor, slot_mapping, start_token, end_token, page_size, num_kv_heads, head_dim, kv_dim, device_ref)
Onload KV cache data from external contiguous format to paged format.
Used by LMCache connector to copy KV data from LMCache’s contiguous KV_2LTD format into MAX’s paged cache layout.
-
Parameters:
-
- paged_cache (BufferValue) – Output buffer [total_num_blocks, kv_dim, num_layers, page_size, num_kv_heads, head_dim].
- input_tensor (TensorValue) – Input tensor [kv_dim, num_layers, num_tokens, hidden_dim] where hidden_dim = num_kv_heads * head_dim.
- slot_mapping (TensorValue) – Token to slot mapping [total_tokens].
- start_token (TensorValue) – Starting token index scalar [1].
- end_token (TensorValue) – Ending token index (exclusive) scalar [1].
- page_size (int) – Number of tokens per page in the paged cache.
- num_kv_heads (int) – Number of KV attention heads.
- head_dim (int) – Dimension of each attention head.
- kv_dim (int) – KV dimension (2 for standard K/V, 1 for MLA).
- device_ref (DeviceRef) – Device reference for the operation.
-
Return type:
-
None
masked_flash_attention_gpu()
max.nn.kernels.masked_flash_attention_gpu(q, k, v, mask, scale)
Computes flash attention using a materialized additive mask.
-
Parameters:
-
- q (TensorValue) – Query tensor of shape [batch, q_seq_len, num_heads, head_dim]
- k (TensorValue) – Key tensor of shape [batch, kv_seq_len, num_heads, head_dim]
- v (TensorValue) – Value tensor of shape [batch, kv_seq_len, num_heads, head_dim]
- mask (TensorValue) – Additive mask tensor. Rank 3 of shape [batch, q_seq_len, kv_seq_len] broadcasts across attention heads. Rank 4 of shape [batch, num_heads, q_seq_len, kv_seq_len] applies a per-head bias.
- scale (float) – Scaling factor for attention scores.
-
Returns:
-
Output tensor of shape [batch, q_seq_len, num_heads, head_dim]
-
Return type:
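The additive-mask math, including the rank-3 broadcast across heads, can be written as a NumPy reference (a sketch of the semantics, not the GPU kernel):

```python
import numpy as np

def attention_ref(q, k, v, mask, scale):
    """Reference additive-mask attention. q/k/v: [B, L, H, D]; mask: [B, Lq, Lkv]."""
    qt, kt, vt = (x.transpose(0, 2, 1, 3) for x in (q, k, v))  # [B, H, L, D]
    scores = qt @ kt.transpose(0, 1, 3, 2) * scale             # [B, H, Lq, Lkv]
    scores = scores + mask[:, None, :, :]                      # broadcast over heads
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    return (probs @ vt).transpose(0, 2, 1, 3)                  # back to [B, Lq, H, D]

B, Lq, Lkv, H, D = 1, 2, 2, 1, 4
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, L, H, D)).astype(np.float32)
           for L in (Lq, Lkv, Lkv))
# Causal additive mask: -inf above the diagonal, 0 elsewhere.
causal = np.triu(np.full((Lq, Lkv), -np.inf, dtype=np.float32), k=1)[None]
out = attention_ref(q, k, v, causal, scale=D ** -0.5)
assert out.shape == (B, Lq, H, D)
```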
matmul_k_cache_ragged()
max.nn.kernels.matmul_k_cache_ragged(kv_params, hidden_states, input_row_offsets, weight, kv_collection, layer_idx)
Computes key projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
-
Parameters:
-
- kv_params (KVCacheParams)
- hidden_states (TensorValue)
- input_row_offsets (TensorValue)
- weight (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
-
Return type:
-
None
matmul_k_cache_ragged_scaled_float8()
max.nn.kernels.matmul_k_cache_ragged_scaled_float8(kv_params, hidden_states, input_row_offsets, weight, input_scale, weight_scale, kv_collection, scales_granularity_mnk, layer_idx)
Computes key projections with ragged input with FP8 block scaling.
-
Parameters:
-
- kv_params (KVCacheParams) – KVCacheParams object containing key-value cache parameters.
- hidden_states (TensorValue) – TensorValue representing the input tensor with shape [M=total_seq_len, K=hidden_dim].
- input_row_offsets (TensorValue) – TensorValue indicating the start and end of each batch in the input tensor with shape [batch_size + 1].
- weight (TensorValue) – TensorValue representing the weight tensor with shape [N=num_heads, K=hidden_dim].
- input_scale (TensorValue) – TensorValue representing the input scale tensor with shape [ceildiv(K / BLOCK_SIZE_K), ceildiv(M / BLOCK_SIZE_M)].
- weight_scale (TensorValue) – TensorValue representing the weight scale tensor with shape [ceildiv(N / BLOCK_SIZE_N), ceildiv(K / BLOCK_SIZE_K)].
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – PagedCacheValues object for managing key-value cache.
- scales_granularity_mnk (tuple[int, int, int]) – tuple[int, int, int] representing the scaling (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K).
- layer_idx (TensorValue) – TensorValue representing the layer index, expected to have dtype uint32.
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
-
None
matmul_kv_cache_ragged()
max.nn.kernels.matmul_kv_cache_ragged(kv_params, hidden_states, input_row_offsets, weight, kv_collection, layer_idx)
Computes key and value projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
-
Parameters:
-
- kv_params (KVCacheParams)
- hidden_states (TensorValue)
- input_row_offsets (TensorValue)
- weight (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
-
Return type:
-
None
matmul_static_scaled_float8()
max.nn.kernels.matmul_static_scaled_float8(input, weight, input_scale, weight_scale)
Performs a static-scaled float8 matrix multiplication.
Computes input @ weight.T where both tensors are float8, dequantized
using the provided per-tensor CPU scalar scales before accumulation.
The output is always bfloat16.
-
Parameters:
-
- input (TensorValue) – Input tensor of rank 2 and dtype float8_e4m3fn or float8_e4m3fnuz.
- weight (TensorValue) – Weight tensor of rank 2 and matching float8 dtype, laid out so that the K dimension matches input.shape[1].
- input_scale (TensorValue) – Scalar scale factor for input (shape [] or [1]), must reside on CPU.
- weight_scale (TensorValue) – Scalar scale factor for weight (shape [] or [1]), must reside on CPU.
-
Returns:
-
A TensorValue of shape [input.shape[0], weight.shape[0]] and dtype bfloat16.
-
Raises:
-
ValueError – If scale shapes are not scalar, input or weight are not rank 2, K dimensions do not match, or scales are not on CPU.
-
Return type:
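The dequantize-then-matmul math can be sketched in NumPy (float32 stands in for the float8 dtypes here; an illustration of the formula, not the kernel):

```python
import numpy as np

input_q = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)   # [M, K]
weight_q = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)  # [N, K]
input_scale, weight_scale = np.float32(0.5), np.float32(2.0)     # per-tensor scalars

# Dequantize with the per-tensor scales, then compute input @ weight.T.
out = (input_q * input_scale) @ (weight_q * weight_scale).T      # [M, N]
print(out)  # [[1. 2.] [3. 4.]]
```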
merge_ragged_tensors()
max.nn.kernels.merge_ragged_tensors(a, a_row_offsets, b, b_row_offsets)
Merges two ragged tensors into a single ragged tensor.
Both ragged tensors must have the same batch size (same number of row offsets). This function interleaves the rows from each tensor based on their row offsets.
-
Parameters:
-
- a (TensorValue) – The first ragged tensor of shape [total_a_rows, …].
- a_row_offsets (TensorValue) – The row offsets of the first ragged tensor, indicating where each batch starts and ends in a.
- b (TensorValue) – The second ragged tensor of shape [total_b_rows, …].
- b_row_offsets (TensorValue) – The row offsets of the second ragged tensor, indicating where each batch starts and ends in b.
-
Returns:
-
- The merged ragged tensor with shape [total_a_rows + total_b_rows, …].
- The merged row offsets with the same shape as input row offsets.
-
Return type:
-
A tuple of two tensors
Example:
a = [1, 2, 3, 4, 5, 6]
a_row_offsets = [0, 2, 6]
b = [7, 8, 9, 10]
b_row_offsets = [0, 3, 4]
merged_tensor, merged_row_offsets = merge_ragged_tensors(
    a, a_row_offsets, b, b_row_offsets)
merged_tensor = [1, 2, 7, 8, 9, 3, 4, 5, 6, 10]
merged_row_offsets = [0, 5, 10]
mla_decode_graph()
max.nn.kernels.mla_decode_graph(q, kv, input_row_offsets, freqs_cis, kv_norm_gamma, w_uk, w_uv, kv_params, kv_collection, layer_idx, mask_variant, scale, epsilon, v_head_dim, scalar_args, *, w_uk_scale=None, w_uv_scale=None, quant_config=None)
This is a manually fused kernel that performs the following operations:
- Apply RoPE to the query and the key cache (in-place).
- Apply RMSNorm to the non-rope portion of the key cache (in-place).
- Project q_nope to kv_latent_dim through a fp8 batched matmul: q_nope_proj = q_nope_t @ w_uk
- Concatenate q_nope_proj and q_rope: q_full = concat(q_nope_proj, q_rope, axis=2)
- Perform MLA decode
- Project raw_output to v_head_dim through another fp8 batched matmul: output = raw_output_t @ w_uv
-
Parameters:
-
- q (TensorValue) – Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
- kv (TensorValue) – KV latent tensor from the first projection. Shape: [num_tokens, cache_head_dim] where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
- input_row_offsets (TensorValue) – Indicates where each request starts and ends in input. This is a 1D tensor of shape [num_batches + 1].
- freqs_cis (TensorValue) – Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
- kv_norm_gamma (TensorValue) – RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
- w_uk (TensorValue) – Weight matrix for projecting q_nope to kv_latent_dim. Shape: [num_heads, kv_latent_dim, qk_nope_head_dim].
- w_uv (TensorValue) – Weight matrix for projecting MLA decode output to v_head_dim. Shape: [num_heads, v_head_dim, kv_latent_dim].
- kv_params (KVCacheParams) – KVCacheParams
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV Cache object.
- layer_idx (TensorValue) – Layer index.
- mask_variant (MHAMaskVariant) – The attention mask variant controlling masking behavior.
- scale (float) – Scale for the attention calculation.
- epsilon (float) – Small constant for numerical stability in RMSNorm.
- v_head_dim (int) – Dimension of the V heads.
- scalar_args (TensorValue) – Pre-computed dispatch scalar args (GPU buffer) for CUDA graph capture.
- w_uk_scale (TensorValue | None) – Optional FP8 scale tensor for w_uk.
- w_uv_scale (TensorValue | None) – Optional FP8 scale tensor for w_uv.
- quant_config (QuantConfig | None) – Optional quantization config. When set, scales are required.
-
Returns:
-
Tensor of shape [total_seq_len, num_heads, v_head_dim].
-
Return type:
mla_fp8_index_top_k()
max.nn.kernels.mla_fp8_index_top_k(q, q_s, input_row_offsets, k_collection, layer_idx, top_k, quantization_granularity, mask_variant=MHAMaskVariant.CAUSAL_MASK)
Computes top-k indices for MLA FP8 indexed attention scores.
This function computes FP8 matmul between queries and cached keys (with scales), applies masking, and returns the indices of the top-k highest-scoring keys per token. Scores are aggregated (summed) across all attention heads.
-
Parameters:
-
- q (TensorValue) – Query tensor of shape [total_seq_len, num_heads, head_dim] in FP8.
- q_s (TensorValue) – Query scales tensor of shape [total_seq_len, num_heads] in float32.
- input_row_offsets (TensorValue) – Input row offsets tensor of shape [batch_size + 1].
- k_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache collection. Must be FP8 quantized with scales.
- layer_idx (TensorValue) – Layer index for cache lookup.
- top_k (int) – Requested number of top indices per token.
- quantization_granularity (int) – Quantization granularity for the K cache.
- mask_variant (MHAMaskVariant) – The mask variant to use (NULL or CAUSAL_MASK).
-
Returns:
-
Output tensor of shape [total_seq_len, effective_k] containing top-k key indices per token, where effective_k = min(top_k, max_num_keys). Invalid positions are filled with -1.
-
Return type:
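As an illustrative, pure-Python sketch (not the FP8 GPU kernel, which operates on quantized tensors, scales, and the paged cache), the per-token selection described above can be modeled as: sum scores across heads, mask out non-causal keys, and keep the indices of the highest-scoring keys, padding with -1. The function name `topk_key_indices` is hypothetical.

```python
def topk_key_indices(scores_per_head, token_pos, top_k):
    """scores_per_head: [num_heads][num_keys] scores for one query token.

    Keys at positions > token_pos are masked out (causal mask). Returns
    the indices of the top-k surviving keys, padded with -1 as in the
    kernel's output convention.
    """
    num_keys = len(scores_per_head[0])
    # Aggregate (sum) scores across all attention heads.
    summed = [sum(h[k] for h in scores_per_head) for k in range(num_keys)]
    # Causal mask: only keys at positions <= token_pos are valid.
    valid = [k for k in range(num_keys) if k <= token_pos]
    ranked = sorted(valid, key=lambda k: summed[k], reverse=True)[:top_k]
    # Invalid/padded positions are filled with -1.
    return ranked + [-1] * (top_k - len(ranked))
```

With a NULL mask, the `valid` set would simply be all keys.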
mla_prefill_decode_graph()
max.nn.kernels.mla_prefill_decode_graph(q, kv, input_row_offsets, freqs_cis, kv_norm_gamma, buffer_row_offsets, cache_offsets, buffer_length, w_k, w_uk, w_uv, kv_params, kv_collection, layer_idx, mask_variant, scale, epsilon, v_head_dim, scalar_args, *, w_k_scale=None, w_uk_scale=None, w_uv_scale=None, quant_config=None)
Fused MLA prefill/decode kernel for FP8.
Switches between prefill and decode based on the maximum sequence length in the batch. See mla_prefill_graph and mla_decode_graph for the dedicated paths.
-
Parameters:
-
- q (TensorValue) – Combined query tensor with nope+rope parts.
- kv (TensorValue) – KV latent tensor for current sequence.
- input_row_offsets (TensorValue) – Row offsets for the batch.
- freqs_cis (TensorValue) – RoPE frequencies tensor.
- kv_norm_gamma (TensorValue) – RMSNorm gamma for KV cache.
- buffer_row_offsets (TensorValue) – One-shot prefill buffer row offsets.
- cache_offsets (TensorValue) – One-shot prefill cache offsets.
- buffer_length (TensorValue) – One-shot prefill buffer length tensor.
- w_k (TensorValue) – Prefill K up-projection weights.
- w_uk (TensorValue) – Decode query-projection weights.
- w_uv (TensorValue) – Decode output-projection / prefill V-projection weights.
- kv_params (KVCacheParams) – KV cache parameters.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache values.
- layer_idx (TensorValue) – Layer index (uint32).
- mask_variant (MHAMaskVariant) – Attention mask variant.
- scale (float) – Attention scale.
- epsilon (float) – RMSNorm epsilon.
- v_head_dim (int) – Value head dimension for output tensor shape.
- scalar_args (TensorValue) – Pre-computed dispatch scalar args (GPU buffer) for CUDA graph capture.
- w_k_scale (TensorValue | None) – Optional FP8 scale tensor for w_k.
- w_uk_scale (TensorValue | None) – Optional FP8 scale tensor for w_uk.
- w_uv_scale (TensorValue | None) – Optional FP8 scale tensor for w_uv.
- quant_config (QuantConfig | None) – Optional quantization config. When set, scales are required.
-
Returns:
-
Tensor of shape [total_seq_len, num_heads, v_head_dim].
-
Return type:
mla_prefill_graph()
max.nn.kernels.mla_prefill_graph(q, kv, input_row_offsets, freqs_cis, kv_norm_gamma, buffer_row_offsets, cache_offsets, buffer_length, w_k, w_uv, kv_params, kv_collection, layer_idx, mask_variant, scale, epsilon, v_head_dim, *, w_k_scale=None, w_uv_scale=None, quant_config=None)
This is a manually fused kernel that performs the following operations:
- Apply RoPE to the query and the key cache (in-place).
- Apply RMSNorm to the non-rope portion of the key cache (in-place).
- Copy the KV latent values from PagedKVCache to a contiguous buffer.
- Quantize the KV latent values to fp8.
- Up-project the latent KV values to full K and V through two matmuls.
- Perform MLA prefill.
-
Parameters:
-
- q (TensorValue) – Combined query tensor containing both nope and rope parts. Shape: [tot_seq_len, num_heads, qk_nope_head_dim + qk_rope_head_dim].
- kv (TensorValue) – KV latent tensor from the first projection. Shape: [num_tokens, cache_head_dim] where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
- input_row_offsets (TensorValue) – Indicates where each request starts and ends in input. This is a 1D tensor of shape [num_batches + 1].
- freqs_cis (TensorValue) – Precomputed RoPE frequency values for rotary position embeddings. Shape: [max_seq_len, qk_rope_head_dim].
- kv_norm_gamma (TensorValue) – RMSNorm gamma weights for normalizing the KV cache. Shape: [kv_lora_rank].
- buffer_row_offsets (TensorValue) – Indicates where each request’s KV latent values should be stored in the contiguous buffer. This is a 1D tensor of shape [num_batches + 1].
- cache_offsets (TensorValue) – Indicates the starting token position in the KV cache from which to copy KV latent values for each request. This is a 1D tensor of shape [num_batches + 1].
- buffer_length (TensorValue) – The total number of tokens in the KV cache. Scalar.
- w_k (TensorValue) – Weight matrix for up-projecting latent KV values to full K. Shape: [num_heads * qk_nope_head_dim, kv_latent_dim].
- w_uv (TensorValue) – Weight tensor for up-projecting latent KV values to full V. Shape: [num_heads, v_head_dim, kv_latent_dim].
- kv_params (KVCacheParams) – KV cache parameters.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV Cache object.
- layer_idx (TensorValue) – Layer index.
- mask_variant (MHAMaskVariant) – The attention mask variant controlling masking behavior.
- scale (float) – Scale for the attention calculation.
- epsilon (float) – Small constant for numerical stability in RMSNorm.
- v_head_dim (int) – Dimension of the V heads.
- w_k_scale (TensorValue | None) – Optional FP8 scale tensor for w_k.
- w_uv_scale (TensorValue | None) – Optional FP8 scale tensor for w_uv.
- quant_config (QuantConfig | None) – Optional quantization config. When set, scales are required.
-
Returns:
-
Tensor of shape [total_seq_len, num_heads, v_head_dim].
-
Return type:
moe_create_indices()
max.nn.kernels.moe_create_indices(topk_ids, num_local_experts, *, needs_scales_offset=False, scales_alignment=128)
Creates indices for the MoE layer.
-
Parameters:
-
- topk_ids (TensorValue) – The expert assignments for each token from the router.
- num_local_experts (int) – The number of experts on this device.
- needs_scales_offset (bool)
- scales_alignment (int)
-
Returns:
-
- token_expert_order: The reordered token indices, grouped by assigned expert.
- expert_start_indices: The starting index for each expert’s token group in the reordered sequence.
- restore_token_order: The indices to restore original token ordering after expert computation.
- expert_ids: The IDs of the active experts selected for tokens.
- expert_usage_stats: The maximum number of tokens assigned to any expert, and the number of active experts.
-
Return type:
-
A tuple of five tensors
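The five outputs above can be illustrated with a pure-Python sketch (a hypothetical reference version, not the GPU op; it handles one expert assignment per token and ignores the scales-offset options):

```python
def moe_create_indices_ref(topk_ids, num_local_experts):
    """topk_ids: flat list of expert assignments, one per token."""
    # Group token indices by assigned expert (stable within each expert).
    buckets = [[] for _ in range(num_local_experts)]
    for tok, eid in enumerate(topk_ids):
        buckets[eid].append(tok)
    token_expert_order = [t for b in buckets for t in b]
    # Start offset of each expert's contiguous group, plus a final end offset.
    expert_start_indices, off = [], 0
    for b in buckets:
        expert_start_indices.append(off)
        off += len(b)
    expert_start_indices.append(off)
    # Inverse permutation restores the original token order.
    restore_token_order = [0] * len(token_expert_order)
    for pos, tok in enumerate(token_expert_order):
        restore_token_order[tok] = pos
    expert_ids = [e for e, b in enumerate(buckets) if b]
    # (max tokens assigned to any expert, number of active experts)
    expert_usage_stats = (max(len(b) for b in buckets), len(expert_ids))
    return (token_expert_order, expert_start_indices,
            restore_token_order, expert_ids, expert_usage_stats)
```

Each expert's token group is then a contiguous slice of the reordered tokens, which is what enables grouped expert matmuls.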
moe_router_group_limited()
max.nn.kernels.moe_router_group_limited(expert_scores, expert_bias, n_routed_experts, n_experts_per_tok, n_groups, topk_group, norm_weights, routed_scaling_factor)
Group limited MoE router.
When n_groups > 1, selects up to topk_group expert groups, then
picks n_experts_per_tok experts within those groups (DeepSeek-V3 style).
When n_groups == 1, there is only one group, so group selection is
skipped and routing uses the dedicated GPU single-group path
(mo.moe.single.group.router, implemented as single_group_router in
Mojo). In that case topk_group is not used by the kernel.
-
Parameters:
-
- expert_scores (TensorValue) – The scores for each expert for each token. Shape: [num_tokens, n_routed_experts].
- expert_bias (TensorValue) – The bias for each expert. Shape: [n_routed_experts].
- n_routed_experts (int) – The total number of experts. Must be divisible by n_groups.
- n_experts_per_tok (int) – The number of experts to be selected per token.
- n_groups (int) – The total number of expert groups. Must divide n_routed_experts evenly.
- topk_group (int) – The maximum number of expert groups that a token will be routed to.
- norm_weights (bool) – Whether to normalize the selected expert weights when n_groups > 1. When n_groups == 1, normalization is currently always enabled (norm_weights is treated as True) so behavior matches the previous graph path that always divided weights by their sum per token.
- routed_scaling_factor (float)
-
Returns:
-
- expert_indices: The indices of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
- expert_weights: The weights of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
-
Return type:
-
A tuple of two tensors
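A minimal pure-Python sketch of the multi-group path, under stated assumptions: the group score here is the maximum biased score within the group (real implementations may use a different group statistic, e.g. a sum of the top scores), and weights are the normalized original scores of the chosen experts. The function name is hypothetical.

```python
def group_limited_route(scores, bias, n_groups, topk_group, k):
    """scores/bias: per-expert lists for one token; experts form
    n_groups contiguous groups of equal size."""
    n = len(scores)
    gsize = n // n_groups  # n must be divisible by n_groups
    biased = [s + b for s, b in zip(scores, bias)]
    # Keep the topk_group groups with the best (assumed: max) biased score.
    kept_groups = sorted(
        range(n_groups),
        key=lambda g: max(biased[g * gsize:(g + 1) * gsize]),
        reverse=True)[:topk_group]
    allowed = [e for g in kept_groups
               for e in range(g * gsize, (g + 1) * gsize)]
    # Pick the k best experts within the surviving groups.
    chosen = sorted(allowed, key=lambda e: biased[e], reverse=True)[:k]
    # Normalize the selected (unbiased) scores per token.
    total = sum(scores[e] for e in chosen)
    weights = [scores[e] / total for e in chosen]
    return chosen, weights
```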
mxfp4_dequant()
max.nn.kernels.mxfp4_dequant(packed_weights, scales, out_type=bfloat16)
Dequantizes MXFP4 packed weights to BF16 or FP8 on GPU.
Supports rank 2 [N, K//2] and rank 3 [E, N, K//2] inputs.
For rank 3, leading dims are flattened to 2D, dequantized, and reshaped back.
-
Parameters:
-
- packed_weights (TensorValue) – Packed weights in uint8 (2 FP4 values per byte). Shape: [N, K//2] or [E, N, K//2].
- scales (TensorValue) – Block scales in float8_e8m0fnu. Shape: [N, K//32] or [E, N, K//32].
- out_type (DType) – Output dtype (bfloat16 or float8_e4m3fn).
-
Returns:
-
Dequantized tensor of shape [N, K] or [E, N, K] in out_type.
-
Return type:
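The bit-level decode can be sketched in pure Python. FP4 E2M1 has 1 sign, 2 exponent, and 1 mantissa bit, giving magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}; each E8M0 scale is a power-of-two exponent shared by a block of 32 values. The low-nibble-first packing order below is an assumption, and the function name is hypothetical.

```python
# Magnitude table for the 8 non-negative FP4 E2M1 bit patterns.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def dequant_mxfp4_row(packed, scale_exponents, block=32):
    """packed: uint8 bytes, two FP4 values per byte (low nibble first,
    an assumed order). scale_exponents: one power-of-two exponent
    (decoded from float8_e8m0fnu) per block of `block` values."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            mag = FP4_VALUES[nibble & 0x7]
            out.append(-mag if nibble & 0x8 else mag)
    # Apply the per-block power-of-two scale.
    return [v * (2.0 ** scale_exponents[i // block])
            for i, v in enumerate(out)]
```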
needs_fp8_fnuz_conversion()
max.nn.kernels.needs_fp8_fnuz_conversion()
Checks if FP8 E4M3FN to FNUZ conversion is needed for AMD GPUs.
-
Returns:
-
True if running on an AMD GPU with CDNA3 architecture, False otherwise.
-
Return type:
normalize_e4m3fn_to_e4m3fnuz()
max.nn.kernels.normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale)
Converts E4M3FN weights to E4M3FNUZ format for AMD GPUs.
This conversion is necessary because AMD GPUs use the E4M3FNUZ format while NVIDIA GPUs use E4M3FN. The key differences are:
- The bit pattern 10000000 (-128) represents zero in E4M3FN but NaN in E4M3FNUZ
- For the same bit representation, E4M3FNUZ values are half of E4M3FN values
-
Parameters:
-
- weight (TensorValue) – The weight tensor in E4M3FN format.
- weight_scale (TensorValue) – The weight scale factor.
-
Returns:
-
Tuple of (converted_weight, adjusted_weight_scale, adjusted_input_scale).
-
Return type:
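The two format differences listed above suggest a simple bit-level fix-up, sketched here in pure Python on raw byte patterns (a hypothetical reference, not the tensor op, which also adjusts input scales): remap the 0x80 pattern (zero in E4M3FN, NaN in E4M3FNUZ) and double the weight scale to compensate for FNUZ values being half of FN values at the same bits.

```python
def normalize_e4m3fn_to_e4m3fnuz_ref(weight_bytes, weight_scale):
    """weight_bytes: raw uint8 bit patterns of E4M3FN values."""
    # 0x80 means -0 in E4M3FN but NaN in E4M3FNUZ: remap to +0.
    converted = [0x00 if b == 0x80 else b for b in weight_bytes]
    # Same bits decode to half the value in E4M3FNUZ, so double the
    # scale to keep weight * scale numerically unchanged.
    return converted, weight_scale * 2.0
```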
quantize_dynamic_block_scaled_fp4()
max.nn.kernels.quantize_dynamic_block_scaled_fp4(input, tensor_sf, sf_vector_size=16, scales_type=float8_e4m3fn, out_type=uint8)
Dynamically quantize the input tensor to fp4-e2m1fn.
-
Parameters:
-
- input (TensorValue) – The input tensor to quantize. Shape: [seq_len, hidden_size]
- tensor_sf (TensorValue | float) – The tensor-wise scale factor (inverted as per quantization kernel requirement).
- sf_vector_size (int) – The block size for the scaling factors. 16 for NVFP4, 32 for MXFP4.
- out_type (DType) – The type of the output tensor.
- scales_type (DType) – The type of the scales tensor: float8_e4m3fn for NVFP4, float8_e8m0fnu for MXFP4.
-
Returns:
-
The quantized tensor and scales. The scales layout depends on hardware: rank-5 interleaved on NVIDIA SM100, rank-2 [M, K // sf_vector_size] otherwise.
-
Return type:
quantize_dynamic_block_scaled_mxfp4()
max.nn.kernels.quantize_dynamic_block_scaled_mxfp4(input, scales_type=float8_e8m0fnu, out_type=uint8)
Dynamically quantize the input tensor to fp4-e2m1fn.
-
Parameters:
-
- input (TensorValue) – The input tensor to quantize. Shape: [seq_len, hidden_size]
- out_type (DType) – The type of the output tensor.
- scales_type (DType) – The type of the scales tensor.
-
Returns:
-
The quantized tensor in [seq_len, hidden_size // 2] layout and the scales in [seq_len, hidden_size // 32] layout.
-
Return type:
quantize_dynamic_scaled_float8()
max.nn.kernels.quantize_dynamic_scaled_float8(input, input_scale_spec, weight_scale_spec, scale_ub=1200.0, group_size_or_per_token=-1, out_type=float8_e4m3fn, scales_type=bfloat16)
Dynamically quantize the input tensor to fp8.
-
Parameters:
-
- input (TensorValue) – The input tensor to quantize.
- scale_ub (float) – The upper bound of the scale factor.
- group_size_or_per_token (int) – The group size for quantization. When set to -1, the quantization is column-wise.
- out_type (DType) – The type of the output tensor.
- scales_type (DType) – The type of the scales tensor.
- input_scale_spec (InputScaleSpec)
- weight_scale_spec (WeightScaleSpec)
-
Returns:
-
The quantized tensor and the scales.
-
Return type:
quantize_static_scaled_float8()
max.nn.kernels.quantize_static_scaled_float8(x, scale, scale_is_inverted=True, out_type=float8_e4m3fn)
Quantizes a rank-2 tensor to float8 using a static per-tensor scale.
-
Parameters:
-
- x (TensorValue) – Input tensor to quantize. Must be rank 2 with dtype float16, bfloat16, or float32.
- scale (TensorValue) – Scalar scale factor (shape [] or [1]) residing on CPU.
- scale_is_inverted (bool) – When True (default), scale is interpreted as 1 / max_val (inverted). When False, it is the raw absolute-max scale.
- out_type (DType) – Output dtype. Defaults to DType.float8_e4m3fn.
-
Returns:
-
A quantized TensorValue with shape equal to x and dtype out_type.
-
Raises:
-
ValueError – If scale is not a scalar, x is not rank 2, x dtype is unsupported, or scale is not on CPU.
-
Return type:
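The numeric effect of static scaling can be sketched in pure Python (a hypothetical reference; the real kernel additionally rounds each result to the nearest representable FP8 value): multiply by the inverted scale (or divide by the raw scale) and clamp to the finite E4M3FN range of ±448.

```python
def quantize_static_fp8_ref(x, scale, scale_is_inverted=True):
    """Maps float values into the E4M3FN representable range."""
    FP8_MAX = 448.0  # largest finite E4M3FN magnitude
    inv = scale if scale_is_inverted else 1.0 / scale
    # Scale, then clamp; FP8 rounding is omitted in this sketch.
    return [max(-FP8_MAX, min(FP8_MAX, v * inv)) for v in x]
```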
quantize_tensor_dynamic_scaled_float8()
max.nn.kernels.quantize_tensor_dynamic_scaled_float8(input, input_scale_spec, weight_scale_spec, scale_ub=1200.0, group_size_or_per_token=-1, out_type=float8_e4m3fn, scales_type=bfloat16)
Quantizes a rank-2 tensor to float8 using a dynamic per-tensor scale.
-
Parameters:
-
- input (TensorValue) – The input tensor to quantize.
- scale_ub (float) – The upper bound of the scale factor.
- group_size_or_per_token (int) – The group size for quantization. When set to -1, the quantization is column-wise.
- out_type (DType) – The type of the output tensor.
- scales_type (DType) – The type of the scales tensor.
- input_scale_spec (InputScaleSpec)
- weight_scale_spec (WeightScaleSpec)
-
Returns:
-
The quantized tensor and the scales.
-
Return type:
repack_gguf_quantized_weights()
max.nn.kernels.repack_gguf_quantized_weights(weight, quantization_encoding)
Repacks GGUF quantized weights for the given encoding.
-
Parameters:
-
- weight (TensorValue)
- quantization_encoding (QuantizationEncoding)
-
Return type:
rms_norm_key_cache()
max.nn.kernels.rms_norm_key_cache(kv_params, kv_collection, gamma, epsilon, layer_idx, total_seq_len, input_row_offsets, weight_offset, rms_norm_cols=None, multiply_before_cast=True, per_head_norm=True)
This function applies RMSNorm to the _new_ entries in the KVCache.
When per_head_norm=True (default), RMSNorm is applied separately to each head. In this mode, gamma should have size [head_dim] and normalization occurs across the head_dim dimensions within each head.
When per_head_norm=False, RMSNorm is applied per token across all heads. In this mode, gamma should have size [n_kv_heads * head_dim] and normalization occurs across all dimensions for each token.
The size of the gamma tensor determines how many dimensions will be normalized. If gamma’s size doesn’t match the expected size based on per_head_norm setting, rms_norm_cols must be explicitly specified to confirm the intention to normalize only a subset of dimensions.
Currently, the KVCacheT class itself isn't aware of the new cache entries until the cache length is incremented, which happens after the model's forward pass, so input_row_offsets is used to locate the new entries.
-
Parameters:
-
- kv_params (KVCacheParams)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- gamma (TensorValue)
- epsilon (float | floating[Any])
- layer_idx (TensorValue)
- total_seq_len (Dim)
- input_row_offsets (TensorValue)
- weight_offset (float | floating[Any])
- rms_norm_cols (int | None)
- multiply_before_cast (bool)
- per_head_norm (bool)
-
Return type:
-
None
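The normalization applied to each row (one head when per_head_norm=True, one full token otherwise) follows the standard RMSNorm formula, sketched here in pure Python (an illustrative reference, not the in-place cache kernel):

```python
import math

def rms_norm_row(row, gamma, epsilon, weight_offset=0.0):
    """row / rms(row) * (gamma + weight_offset), elementwise.

    gamma's length determines how many leading dimensions are
    normalized, mirroring the rms_norm_cols behavior above.
    """
    rms = math.sqrt(sum(v * v for v in row) / len(row) + epsilon)
    return [v / rms * (g + weight_offset) for v, g in zip(row, gamma)]
```

With per_head_norm=True this runs once per head over head_dim values; with per_head_norm=False it runs once per token over n_kv_heads * head_dim values.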
rms_norm_value_cache()
max.nn.kernels.rms_norm_value_cache(kv_params, kv_collection, gamma, epsilon, layer_idx, total_seq_len, input_row_offsets, weight_offset, rms_norm_cols=None, multiply_before_cast=True, per_head_norm=True)
Applies RMSNorm in place to the _new_ entries in the value cache.
Semantics match rms_norm_key_cache(), but updates the value tensor
for the layer instead of the key tensor.
-
Parameters:
-
- kv_params (KVCacheParams)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- gamma (TensorValue)
- epsilon (float | floating[Any])
- layer_idx (TensorValue)
- total_seq_len (Dim)
- input_row_offsets (TensorValue)
- weight_offset (float | floating[Any])
- rms_norm_cols (int | None)
- multiply_before_cast (bool)
- per_head_norm (bool)
-
Return type:
-
None
rope_ragged()
max.nn.kernels.rope_ragged(input, input_row_offsets, start_pos, freqs_cis, *, interleaved=True)
Applies RoPE to ragged input using the standard rope kernel.
-
Parameters:
-
- input (TensorValue)
- input_row_offsets (TensorValue)
- start_pos (TensorValue)
- freqs_cis (TensorValue)
- interleaved (bool)
-
Return type:
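The interleaved=True rotation can be sketched for a single head vector in pure Python (an illustrative reference; the kernel operates over the ragged batch with per-sequence start positions and precomputed freqs_cis). Pairs (vec[2i], vec[2i+1]) are treated as complex (re, im) components and rotated by a position-dependent angle; the base of 10000.0 is the conventional RoPE default and an assumption here.

```python
import math

def rope_interleaved(vec, pos, base=10000.0):
    """Rotates interleaved (re, im) pairs of one head vector."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)  # frequency decays with dim index
        c, s = math.cos(theta), math.sin(theta)
        re, im = vec[i], vec[i + 1]
        # Complex multiply by e^{i * theta}.
        out += [re * c - im * s, re * s + im * c]
    return out
```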
rope_ragged_with_position_ids()
max.nn.kernels.rope_ragged_with_position_ids(input, freqs_cis, position_ids, *, mrope_section=None, interleaved=True)
Applies RoPE using explicit position_ids (no KV cache coupling).
-
Parameters:
-
- input (TensorValue)
- freqs_cis (TensorValue)
- position_ids (TensorValue)
- mrope_section (list[int] | None)
- interleaved (bool)
-
Return type:
rope_split_store_ragged()
max.nn.kernels.rope_split_store_ragged(kv_params, qkv, input_row_offsets, freqs_cis, kv_collection, layer_idx, n_heads, interleaved=True, position_ids=None, mrope_section=None, fuse=True)
Apply rope to Q and K from flat QKV buffer, store K/V to cache.
Reads from a flat QKV matmul output, applies RoPE to Q and K regions, stores K/V to the paged KV cache, and writes roped Q to the output.
-
Parameters:
-
- kv_params (KVCacheParams) – KV cache parameters.
- qkv (TensorValue) – Flat QKV matmul output [total_seq_len, q_dim + k_dim + v_dim].
- input_row_offsets (TensorValue) – Ragged offsets [batch_size + 1].
- freqs_cis (TensorValue) – RoPE frequencies [max_seq_len, head_dim].
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache.
- layer_idx (TensorValue) – Layer index.
- n_heads (int) – Number of query attention heads.
- interleaved (bool) – Whether freqs_cis uses interleaved (re, im) format.
- position_ids (TensorValue | None) – Optional ragged 2D array of position IDs. If None, defaults to cache_length + token_idx for each token. When num_sections > 1, mrope_section must be provided. Shape: [num_sections, total_seq_len].
- mrope_section (list[int] | None) – Optional list of ints indicating the section of the head_dim to apply RoPE to. Must be used with position_ids.
- fuse (bool) – If True (default), emit a single fused custom op. If False, emit separate split, rope, and store ops for testing graph compiler fusion.
-
Returns:
-
Roped Q output [total_seq_len, n_heads * head_dim].
-
Return type:
scatter_nd_skip_oob_indices()
max.nn.kernels.scatter_nd_skip_oob_indices(input, updates, indices)
Creates a new symbolic tensor where the updates are scattered into input at specified indices.
This differs from scatter_nd in that it handles oob indices by skipping the update for that index. Oob indices are those which fall outside of the range [-dim, dim).
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to write elements to.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of elements to write to input.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A tensor of indices specifying where to write updates. Shape should be [num_updates, rank] for full indexing or [num_updates, k] for partial indexing where k < rank.
-
Returns:
-
A new symbolic tensor representing the result of the scatter_nd operation.
-
Return type:
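The OOB-skipping semantics can be illustrated with a 1-D pure-Python sketch (the real op handles multi-dimensional indices and symbolic tensors): indices inside [-dim, dim) are applied with Python-style negative indexing, everything else is silently skipped rather than raising.

```python
def scatter_nd_skip_oob_ref(values, updates, indices):
    """Copies `values` and writes updates[i] at indices[i], skipping
    any index outside [-len(values), len(values))."""
    out = list(values)
    n = len(out)
    for upd, idx in zip(updates, indices):
        if -n <= idx < n:  # out-of-bounds indices are silently skipped
            out[idx] = upd
    return out
```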
scatter_set_constant()
max.nn.kernels.scatter_set_constant(data, indices, fill_val)
Scatters values into a tensor at specified indices.
-
Parameters:
-
- data (BufferValue | HasBufferValue)
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- fill_val (float)
-
Return type:
-
None
sgmv_kernel()
max.nn.kernels.sgmv_kernel(input, lora, lora_ids, lora_ranks, input_row_offsets, max_lora_seq_len, lora_end_idx=None, bias=None)
Performs the SGMV kernel for LoRA. The kernel is LoRA-matrix agnostic: it can compute either the LoRA A or the LoRA B projection.
-
Parameters:
-
- input (TensorValue) – The input tensor.
- lora (TensorValue) – The LoRA tensor.
- lora_ids (TensorValue) – Ids of the LoRAs used for each sequence
- lora_ranks (TensorValue) – The ranks of the LoRAs in the batch.
- input_row_offsets (TensorValue) – The sequence offsets that use LoRA
- max_lora_seq_len (int) – The maximum sequence length of any given LoRA in the batch
- bias (TensorValue | None) – The LoRA bias
- lora_end_idx (TensorValue | None)
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
sgmv_lora_kernel()
max.nn.kernels.sgmv_lora_kernel(input, lora_a, lora_b, lora_ids, lora_ranks, grouped_row_offsets, lora_end_idx, max_lora_seq_len, bias=None)
Computes the SGMV LoRA kernel for some number of LoRAs A and B given the input.
out = Wx + xAB
SGMV can be explained by two independent kernels, where v is a zero-initialized low-rank buffer and y is the output tensor:
- SGMV-shrink: v += xA (shrinks the high-dimensional input to a low-rank tensor)
- SGMV-expand: y += vB (expands the low-rank tensor back to a high-dimensional tensor)
-
Parameters:
-
- input (TensorValue) – The input tensor
- lora_a (TensorValue) – The LoRA tensor for A
- lora_b (TensorValue) – The LoRA tensor for B
- lora_ids (TensorValue) – Ids of the LoRAs used for each sequence
- lora_ranks (TensorValue) – The ranks of the LoRAs in the batch.
- grouped_row_offsets (TensorValue) – The grouped sequence offsets that use LoRA
- max_lora_seq_len (int) – The maximum sequence length of any given LoRA in the batch
- bias (TensorValue | None) – The LoRA bias
- lora_end_idx (TensorValue)
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
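The shrink/expand composition for one token can be sketched in pure Python with nested lists (a hypothetical reference; the base Wx term and per-sequence adapter grouping are handled elsewhere):

```python
def lora_shrink_expand(x, A, B):
    """Computes (x @ A) @ B for one token.

    x: [K], A: [K][r], B: [r][N], with LoRA rank r much smaller
    than K and N, so the two small matmuls are cheap.
    """
    K, r, N = len(x), len(A[0]), len(B[0])
    # SGMV-shrink: project to the low-rank space (v += xA).
    v = [sum(x[k] * A[k][j] for k in range(K)) for j in range(r)]
    # SGMV-expand: project back to the output space (y += vB).
    return [sum(v[j] * B[j][i] for j in range(r)) for i in range(N)]
```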
sgmv_lora_qkv_shrink()
max.nn.kernels.sgmv_lora_qkv_shrink(input, lora_a, lora_ids, lora_grouped_offsets, lora_end_idx, max_lora_seq_len, max_rank)
LoRA shrink grouped matmul with planar Q/K/V output.
Performs the LoRA ‘shrink’ operation for routed tokens using SGMV (segmented grouped matrix-vector multiplication). Computes [M, K] @ [G, 3*rank, K]^T per active LoRA adapter, then permutes the flat [M, 3*rank] result into a planar layout [3, M, rank] representing separate Q, K, V projections.
-
Parameters:
-
- input (TensorValue) – Routed activation matrix with shape (M, K), where M is the total number of tokens and K is the hidden dimension.
- lora_a (TensorValue) – Shrink weights for all LoRA adapters, shape (G, 3*rank, K) where G is the number of adapters and rank is the LoRA rank.
- lora_ids (TensorValue) – Expert/adapter indices for each active group, shape (num_active,). Values in range [0, G). May use -1 to indicate inactive slots.
- lora_grouped_offsets (TensorValue) – Inclusive prefix sums of tokens per active adapter, shape (num_active + 1,). Defines per-adapter [start, end) ranges in input. Must be non-decreasing with offsets[0] == 0.
- max_lora_seq_len (int) – Upper bound on tokens for any active adapter. Used for kernel tuning and memory allocation.
- max_rank (int) – The maximum LoRA rank, determines output shape.
- lora_end_idx (TensorValue)
-
Returns:
-
Output tensor with planar Q/K/V layout, shape (3, M, max_rank).
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
sgmv_qkv_lora_kernel()
max.nn.kernels.sgmv_qkv_lora_kernel(input, lora_a, lora_b_q, lora_b_kv, lora_ids, lora_ranks, input_row_offsets, lora_grouped_offsets, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv, kv_collection, kv_params, layer_idx, max_lora_seq_len, max_rank, bias=None)
Computes the SGMV QKV LoRA kernel for Q, K, V projections with LoRA.
-
Parameters:
-
- input (TensorValue) – The input tensor.
- lora_a (TensorValue) – The LoRA A tensor.
- lora_b_q (TensorValue) – The LoRA B tensor for Q projection.
- lora_b_kv (TensorValue) – The LoRA B tensor for K and V projections (stacked).
- lora_ids (TensorValue) – IDs of the LoRAs used for each sequence.
- lora_ranks (TensorValue) – The ranks of the LoRAs in the batch.
- input_row_offsets (TensorValue) – The sequence offsets that use LoRA.
- lora_grouped_offsets (TensorValue) – Grouped offsets for LoRA sequences.
- lora_end_idx (TensorValue) – End index of LoRA tokens in the batch.
- batch_seq_len (TensorValue) – Total sequence length of the batch.
- lora_ids_kv (TensorValue) – LoRA IDs for KV projections (with offset for V portion).
- lora_grouped_offsets_kv (TensorValue) – Grouped offsets for KV LoRA sequences.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The KV cache.
- kv_params (KVCacheParams) – The key-value cache configuration parameters.
- layer_idx (TensorValue) – The layer index to retrieve the KV cache.
- max_lora_seq_len (int) – The maximum sequence length of any given LoRA in the batch.
- max_rank (int) – The maximum rank for the LoRAs.
- bias (TensorValue | None) – Optional LoRA bias.
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
sleep()
max.nn.kernels.sleep(duration_sec, device_ref)
Sleep for the given duration in seconds.
This kernel is supported on CPUs and GPUs. However, the timing may be inaccurate on AMD GPUs due to limitations of the current time.sleep(…) implementation.
-
Parameters:
-
- duration_sec (BufferValue) – The duration to sleep in seconds.
- device_ref (DeviceRef)
-
Return type:
-
None
sliced_add()
max.nn.kernels.sliced_add(x, y, lora_end_idx)
Adds tensors x and y element-wise for rows < lora_end_idx, otherwise copies x.
This is used for LoRA where only some sequences have LoRA applied:
- For rows in [0, lora_end_idx): c = x + y
- For rows in [lora_end_idx, batch_seq_len): c = x
-
Parameters:
-
- x (TensorValue) – First input tensor.
- y (TensorValue) – Second input tensor.
- lora_end_idx (TensorValue) – End index of LoRA token portion (rows to apply add).
-
Return type:
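A pure-Python sketch of the row-conditional add (an illustrative reference for the semantics above):

```python
def sliced_add_ref(x, y, lora_end_idx):
    """Rows before lora_end_idx get x + y; the rest pass x through."""
    return [
        [a + b for a, b in zip(xr, yr)] if i < lora_end_idx else list(xr)
        for i, (xr, yr) in enumerate(zip(x, y))
    ]
```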
spatial_merge()
max.nn.kernels.spatial_merge(input, grid_thw, hidden_size, merge_size)
Performs spatial merge operation on ragged input tensors.
This operation merges spatial dimensions of input patches according to the grid dimensions specified in grid_thw.
-
Parameters:
-
- input (TensorValue) – Input tensor of shape [total_patches_in_grid, hidden_size]
- grid_thw (TensorValue) – Grid dimensions tensor of shape [batch_size, 3] containing
[t, h, w] for each batch item, where:
- t: temporal/frame dimension
- h: height dimension
- w: width dimension
- hidden_size (int) – Hidden dimension size
- merge_size (int) – Size of spatial merge blocks (typically 2)
-
Returns:
-
Output tensor of shape [total_patches_in_grid, hidden_size]
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Return type:
store_k_cache_padded()
max.nn.kernels.store_k_cache_padded(kv_collection, x_k, valid_lengths, layer_idx)
Stores the key tensor into the paged KV cache for padded inputs.
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
- x_k (TensorValue) – The key tensor of rank 4 containing the new key projections.
- valid_lengths (TensorValue) – Buffer of shape [batch] (dtype uint32) indicating the actual (non-padded) sequence length for each batch element.
- layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.
-
Return type:
-
None
store_k_cache_ragged()
max.nn.kernels.store_k_cache_ragged(kv_collection, x_k, input_row_offsets, layer_idx)
Stores the key tensor into the paged KV cache for ragged inputs.
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
- x_k (TensorValue) – The key tensor of rank 3 containing the new key projections.
- input_row_offsets (TensorValue) – Ragged tensor row offsets of shape [batch + 1] indicating where each sequence starts and ends. Must have dtype uint32.
- layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.
-
Return type:
-
None
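The ragged layout used by store_k_cache_ragged, store_v_cache_ragged, and the other ragged kernels in this module can be illustrated in pure Python: sequence i occupies rows [offsets[i], offsets[i+1]) of the flat tensor, so batch + 1 offsets describe batch sequences with no padding.

```python
def split_ragged(flat, input_row_offsets):
    """Recovers per-sequence slices from a flat ragged tensor."""
    return [
        flat[input_row_offsets[i]:input_row_offsets[i + 1]]
        for i in range(len(input_row_offsets) - 1)
    ]
```

Note that empty sequences are representable as repeated offsets.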
store_k_scale_cache_ragged()
max.nn.kernels.store_k_scale_cache_ragged(kv_collection, x_k_scale, input_row_offsets, layer_idx, quantization_granularity)
Store key scale tensor into the paged KV cache.
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- x_k_scale (TensorValue)
- input_row_offsets (TensorValue)
- layer_idx (TensorValue)
- quantization_granularity (int)
-
Return type:
-
None
store_v_cache_padded()
max.nn.kernels.store_v_cache_padded(kv_collection, x_v, valid_lengths, layer_idx)
Stores the value tensor into the paged KV cache for padded inputs.
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
- x_v (TensorValue) – The value tensor of rank 4 containing the new value projections.
- valid_lengths (TensorValue) – Buffer of shape [batch] (dtype uint32) indicating the actual (non-padded) sequence length for each batch element.
- layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.
-
Return type:
-
None
store_v_cache_ragged()
max.nn.kernels.store_v_cache_ragged(kv_collection, x_v, input_row_offsets, layer_idx)
Stores the value tensor into the paged KV cache for ragged inputs.
-
Parameters:
-
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – The paged KV cache collection to write into.
- x_v (TensorValue) – The value tensor of rank 3 containing the new value projections.
- input_row_offsets (TensorValue) – Ragged tensor row offsets of shape [batch + 1] indicating where each sequence starts and ends. Must have dtype uint32.
- layer_idx (TensorValue) – The scalar layer index (dtype uint32) identifying which transformer layer’s cache to update.
-
Return type:
-
None
topk_fused_sampling()
max.nn.kernels.topk_fused_sampling(logits, top_k, *, temperature=1.0, max_k=None, min_top_p=None, top_p=1.0, min_p=None, seed=0)
Performs top-k sampling with temperature scaling.
-
Parameters:
-
- logits (TensorValue) – Input logits tensor of shape [batch_size, vocab_size].
- top_k (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Number of top tokens to consider for sampling. Can be a scalar (which will be expanded to batch_size) or a tensor of shape [batch_size].
- temperature (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Temperature for scaling logits before sampling.
- max_k (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Maximum value of k across the batch. Required when top_k is a tensor.
- top_p (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Top-p (nucleus) sampling threshold. Can be a scalar or tensor.
- min_p (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Per-row min_p probability filtering threshold of shape [batch_size]. Tokens with probability below min_p * max_prob are zeroed before sampling.
- seed (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Seed for the random number generator. Can be a scalar or tensor.
- min_top_p (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None)
-
Returns:
-
Sampled tokens tensor of shape [batch_size, 1].
-
Raises:
-
ValueError – If input validation fails.
-
Return type:
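A pure-Python sketch of one row's top-k + temperature step (an illustrative reference, not the fused GPU kernel; top-p and min-p filtering are omitted, and the name is hypothetical): keep the k highest logits, apply a temperature-scaled softmax over them, then draw one token.

```python
import math
import random

def topk_sample(logits, top_k, temperature=1.0, seed=0):
    """Samples one token id from the top-k logits of a single row."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = order[:top_k]
    # Temperature-scaled, numerically stable softmax over kept logits.
    scaled = [logits[i] / temperature for i in kept]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    # Inverse-CDF sampling with a seeded generator.
    rng = random.Random(seed)
    r, acc = rng.random() * total, 0.0
    for tok, e in zip(kept, exps):
        acc += e
        if r <= acc:
            return tok
    return kept[-1]
```

With top_k=1 this degenerates to greedy argmax regardless of seed.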
tpool_patch_merger()
max.nn.kernels.tpool_patch_merger(input, grid_thws, kH, kW, max_h, max_w)
Performs temporal pooling patch merger on ragged video tokens.
For each video in the batch, averages the input across the temporal (T) dimension and rearranges the result according to the spatial merge kernel (kH, kW). Each video’s T*H*W input tokens are reduced to H*W output tokens. All videos are concatenated contiguously in the output.
-
Parameters:
-
- input (TensorValue) – Input tensor of shape [total_input_tokens, D] where total_input_tokens = sum(T_i * H_i * W_i) over all videos.
- grid_thws (TensorValue) – Grid dimensions tensor of shape [n_videos, 3] with (T, H, W) per video. Must have dtype int64.
- kH (int) – Merge kernel height.
- kW (int) – Merge kernel width.
- max_h (int | TensorValue) – Maximum H across all videos in the batch (for grid sizing). May be a Python int (baked as a graph constant) or a TensorValue computed at runtime (e.g. via ops.max).
- max_w (int | TensorValue) – Maximum W across all videos in the batch (for grid sizing). May be a Python int or a TensorValue.
-
Returns:
-
Output tensor of shape [sum(H_i * W_i), D].
-
Raises:
-
ValueError – On invalid input shapes or dtypes.
-
Return type:
-
TensorValue
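The temporal-pool-then-rearrange step above can be sketched in NumPy. This is an illustrative reference only: it assumes each H_i is divisible by kH and each W_i by kW, and the kH x kW block ordering shown here is an assumption about how the merge kernel groups tokens, not a confirmed detail of the real kernel.

```python
import numpy as np

def tpool_patch_merge_reference(x, grid_thws, kH, kW):
    """Sketch: per-video temporal mean-pool, then kH x kW block grouping.

    x: [sum(T_i * H_i * W_i), D] ragged concatenation of all videos.
    grid_thws: iterable of (T, H, W) per video.
    Returns [sum(H_i * W_i), D].
    """
    out, start = [], 0
    for t, h, w in grid_thws:
        n = t * h * w
        # Average the T copies of each spatial token (temporal pooling).
        v = x[start:start + n].reshape(t, h, w, -1).mean(axis=0)
        # Group spatial tokens into kH x kW merge blocks, block-major.
        v = v.reshape(h // kH, kH, w // kW, kW, -1)
        v = v.transpose(0, 2, 1, 3, 4).reshape(h * w, -1)
        out.append(v)
        start += n
    return np.concatenate(out, axis=0)
```

Note that the token count shrinks only along T; the H*W spatial tokens per video survive, just reordered into merge-kernel blocks.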
unfused_qkv_ragged_matmul_gguf_quantized()
max.nn.kernels.unfused_qkv_ragged_matmul_gguf_quantized(kv_params, input, input_row_offsets, n_heads, q_weight, k_weight, v_weight, quantization_encoding_q, quantization_encoding_k, quantization_encoding_v, kv_collection, layer_idx)
Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.
-
Raises:
-
ValueError – on input shapes/dtypes that are invalid for the kernel.
-
Parameters:
-
- kv_params (KVCacheParams)
- input (TensorValue)
- input_row_offsets (TensorValue)
- n_heads (int)
- q_weight (TensorValue)
- k_weight (TensorValue)
- v_weight (TensorValue)
- quantization_encoding_q (QuantizationEncoding)
- quantization_encoding_k (QuantizationEncoding)
- quantization_encoding_v (QuantizationEncoding)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- layer_idx (TensorValue)
-
Return type:
-
TensorValue
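The ragged layout that input and input_row_offsets describe can be shown with a small NumPy sketch. This illustrates only the row-offset convention with a single dense weight; the real kernel additionally dequantizes the GGUF-encoded q/k/v weights and writes the key/value projections into the KV cache, none of which is reproduced here.

```python
import numpy as np

def ragged_matmul_reference(x, row_offsets, weight):
    """Sketch of the ragged convention: rows row_offsets[i] to
    row_offsets[i+1] of the packed input x belong to sequence i.

    x: [total_tokens, D]; weight: [out_dim, D].
    """
    outs = []
    for i in range(len(row_offsets) - 1):
        seq = x[row_offsets[i]:row_offsets[i + 1]]  # one sequence's tokens
        outs.append(seq @ weight.T)                 # dense projection
    return np.concatenate(outs, axis=0)
```

Because the projection is applied row-wise, iterating per sequence and projecting the packed tensor in one shot give identical results; the offsets matter for kernels that need per-sequence state, such as the KV cache writes.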
update_frequency_data()
max.nn.kernels.update_frequency_data(frequency_data, frequency_offsets, tokens)
Updates the frequency data.
-
Parameters:
-
- frequency_data (BufferValue) – 2d tensor of shape [unique_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token.
- frequency_offsets (TensorValue) – 1d tensor of shape [batch_size + 1], indicating start of each sequence’s data.
- tokens (TensorValue) – The tokens to update the frequency data with.
-
Return type:
-
None
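The in-place update can be sketched in NumPy under two stated assumptions: each sequence contributes one new token, and an empty slot in frequency_data is marked with token id -1. Both are illustrative choices, not confirmed details of the kernel.

```python
import numpy as np

def update_frequency_reference(frequency_data, frequency_offsets, tokens):
    """Sketch of the frequency update (mutates frequency_data in place).

    frequency_data: [unique_tokens, 2]; column 0 = token id,
    column 1 = count. frequency_offsets: [batch_size + 1] slice
    boundaries per sequence. tokens: one new token per sequence
    (assumption). Empty slots assumed marked with id -1 (assumption).
    """
    for i, tok in enumerate(tokens):
        start, end = frequency_offsets[i], frequency_offsets[i + 1]
        for j in range(start, end):
            if frequency_data[j, 0] == tok:
                frequency_data[j, 1] += 1     # existing entry: increment
                break
            if frequency_data[j, 0] == -1:    # first empty slot: claim it
                frequency_data[j, 0] = tok
                frequency_data[j, 1] = 1
                break
```

This mirrors why the op takes a BufferValue rather than a TensorValue: the counts are mutated in place rather than returned.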