v3.9 update (#2213)

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Author: Yujia Zhai
Date: 2025-04-02 23:10:16 -07:00
Committed by: GitHub
parent 6f4921858b
commit 79fc51f4b8
72 changed files with 19875 additions and 459 deletions

View File

@@ -13,11 +13,16 @@
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
- [Sparse Blockscaled GEMM with MXFP8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* A new Multi-head Latent Attention (MLA) CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu) for the SM100 Blackwell architecture.
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for the SM100 Blackwell architecture.
* A set of unit tests demonstrating the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMMs.
* Enhancements and new support for block-wise and group-wise GEMM on the Hopper and Blackwell architectures:
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
- Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture.
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.

View File

@@ -50,11 +50,16 @@ architecture.
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
- [Sparse Blockscaled GEMM with MXFP8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* A new Multi-head Latent Attention (MLA) CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu) for the SM100 Blackwell architecture.
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for the SM100 Blackwell architecture.
* A set of unit tests demonstrating the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMMs.
* Enhancements and new support for block-wise and group-wise GEMM on the Hopper and Blackwell architectures:
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
- Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture.
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.

View File

@@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{
void initialize(Options const& options) {
auto shape_B = cute::make_shape(options.n, options.k, options.l);
- int const scale_k = (options.k + options.g - 1) / options.g;
+ int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
@@ -429,7 +429,7 @@ void initialize(Options const& options) {
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
- initialize_quant_tensor(block_B, seed + 2021);
+ initialize_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
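Throughout these hunks, the hand-rolled ceiling-division idiom is replaced by `cutlass::ceil_div`. The two forms are equivalent; a minimal host-side sketch (assuming `cutlass::ceil_div` is available via `cutlass/fast_math.h`):

#include <cassert>
#include "cutlass/fast_math.h" // assumed location of cutlass::ceil_div

int main() {
  // Exact division: both forms yield 16.
  assert((8192 + 512 - 1) / 512 == cutlass::ceil_div(8192, 512));
  // Non-divisible case: both forms yield ceil(100 / 7) == 15.
  assert((100 + 7 - 1) / 7 == cutlass::ceil_div(100, 7));
  return 0;
}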

View File

@@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions {
void initialize(Options const& options) {
auto shape_B = cute::make_shape(options.n, options.k, options.l);
- int const scale_k = (options.k + options.g - 1) / options.g;
+ int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
@@ -347,7 +347,7 @@ void initialize(Options const& options) {
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
- initialize_quant_tensor(block_B, seed + 2021);
+ initialize_tensor(block_B, seed + 2021);
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);

View File

@@ -288,7 +288,7 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
void initialize(MixedDtypeOptions const& options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
- int const scale_k = (options.k + options.g - 1) / options.g;
+ int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
@@ -313,7 +313,7 @@ void initialize(MixedDtypeOptions const& options) {
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
- initialize_quant_tensor(block_B, seed + 2021);
+ initialize_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);

View File

@@ -208,20 +208,6 @@ bool initialize_tensor(
return true;
}
- template <typename Element>
- bool initialize_quant_tensor(
-   cutlass::DeviceAllocation<Element>& block,
-   uint64_t seed = 2023) {
-   float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
-   float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
-   cutlass::reference::device::BlockFillRandomUniform(
-     block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
-   return true;
- }
template <class Element>
bool initialize_scale(
cutlass::DeviceAllocation<Element>& block,
@@ -232,10 +218,8 @@ bool initialize_scale(
float scope_max = 1.0f, scope_min = 1.0f;
if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
- const float max_dequant_val = 4.f;
- const float min_dequant_val = 0.5f;
- scope_max = max_dequant_val / elt_max_f;
- scope_min = min_dequant_val / elt_max_f;
+ scope_max = 2.f;
+ scope_min = 0.1f;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
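The new bounds are fixed constants rather than values derived from the quantized element type. A short sketch of the difference (hypothetical int8-like type with max 127):

// Old scheme: bounds scaled by the element type's max, so that
// scale * quantized_max stayed within [0.5, 4.0].
float elt_max_f = 127.0f;
float old_scope_max = 4.f / elt_max_f;  // ~0.0315
float old_scope_min = 0.5f / elt_max_f; // ~0.0039
// New scheme: scales drawn directly from a fixed range,
// independent of the quantized type.
float new_scope_max = 2.f;
float new_scope_min = 0.1f;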

View File

@@ -120,8 +120,7 @@
#include "helper.h"
// Distributed GEMM helpers
#include "util/benchmark.h"
#include "util/device_copy.h"
#include "dist_gemm_helpers.h"
using namespace cute;

View File

@@ -1,84 +0,0 @@
/******************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/*! \file
\brief Generic device-to-device data movement kernel for CuTe tensors.
NOTE: this kernel assigns one element copy to every thread, and is by no means
an efficient way of copying tensors. It should only be used for convenience in
reference checks.
*/
#pragma once
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/cuda_host_adapter.hpp"
namespace cutlass {
template <typename TensorSource, typename TensorDestination>
void device_copy(TensorSource tensor_source,
TensorDestination tensor_destination,
cudaStream_t stream);
template <typename TensorSource, typename TensorDestination>
__global__ void device_copy_kernel(TensorSource const tensor_source,
TensorDestination tensor_destination) {
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
using ElementSrc = typename TensorSource::value_type;
using ElementDst = typename TensorDestination::value_type;
NumericConverter<ElementDst, ElementSrc> converter;
if (linear_idx < size(tensor_source)) {
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
}
}
template <typename TensorSource, typename TensorDestination>
void device_copy(TensorSource tensor_source,
TensorDestination tensor_destination,
cudaStream_t stream) {
assert(tensor_source.size() == tensor_destination.size());
auto numel = tensor_source.size();
static constexpr int NumThreads = 128;
auto grid_size = cute::ceil_div(numel, NumThreads);
dim3 grid(grid_size);
dim3 block(NumThreads);
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
}
} //namespace cutlass
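For context, a minimal sketch of how this now-removed helper was invoked (hypothetical buffer size; the float-to-half element types are ours for illustration):

#include <cuda_runtime.h>
#include "cute/tensor.hpp"
#include "cutlass/numeric_types.h"
#include "util/device_copy.h" // the header deleted above

int main() {
  using namespace cute;
  int n = 1024;
  float* d_src = nullptr;
  cutlass::half_t* d_dst = nullptr;
  cudaMalloc(&d_src, n * sizeof(float));
  cudaMalloc(&d_dst, n * sizeof(cutlass::half_t));
  Tensor src = make_tensor(make_gmem_ptr(d_src), make_layout(make_shape(n)));
  Tensor dst = make_tensor(make_gmem_ptr(d_dst), make_layout(make_shape(n)));
  // One thread per element; NumericConverter handles float -> half_t.
  cutlass::device_copy(src, dst, cudaStreamDefault);
  cudaDeviceSynchronize();
  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}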

View File

@@ -374,7 +374,7 @@ void allocate(Options const& options) {
auto N = get<1>(problem);
auto K = get<2>(problem);
- const int scale_k = 1;
+ int const scale_k = cutlass::ceil_div(options.k, options.c);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
@@ -510,7 +510,7 @@ void initialize(Options &options) {
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
- initialize_quant_tensor(block_B, seed + 2022);
+ initialize_tensor(block_B, seed + 2022);
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
@@ -519,13 +519,13 @@ void initialize(Options &options) {
for (int32_t i = 0; i < options.groups; ++i) {
- const int scale_k = 1;
+ int const scale_k = cutlass::ceil_div(options.k, options.c);
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
- cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream);
+ cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream);
}
problem_sizes.reset(options.groups);
@@ -619,7 +619,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
- {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
+ {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
@@ -676,6 +676,7 @@ bool verify(Options const& options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
+ // we don't swap and transpose in the verify so revert the problem shape.
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
@@ -712,7 +713,7 @@ bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
}
}
return passed;

View File

@@ -341,7 +341,7 @@ void allocate(Options const& options) {
auto N = get<1>(problem);
auto K = get<2>(problem);
- const int scale_k = 1;
+ int const scale_k = cutlass::ceil_div(options.k, options.c);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
@@ -479,7 +479,7 @@ void initialize(Options& options) {
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
- initialize_quant_tensor(block_B, seed + 2022);
+ initialize_tensor(block_B, seed + 2022);
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
@@ -565,7 +565,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
- {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k},
+ {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
@@ -617,6 +617,7 @@ bool verify(Options const& options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
+ // we don't swap and transpose in the verify so revert the problem shape.
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
@@ -630,11 +631,11 @@ bool verify(Options const& options) {
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
- const int scale_k = 1;
+ int const scale_k = cutlass::ceil_div(options.k, options.c);
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
- cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
+ cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
//
// Compute reference output
@@ -659,7 +660,7 @@ bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
}
}
return passed;

View File

@@ -282,7 +282,7 @@ void allocate(Options const& options) {
auto N = get<1>(problem);
auto K = get<2>(problem);
- const int scale_k = 1;
+ int const scale_k = cutlass::ceil_div(options.k, options.c);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
@@ -418,7 +418,7 @@ void initialize(Options &options) {
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
- initialize_quant_tensor(block_B, seed + 2022);
+ initialize_tensor(block_B, seed + 2022);
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
@@ -485,7 +485,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
- {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
+ {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
@@ -542,6 +542,7 @@ bool verify(Options const& options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
+ // we don't swap and transpose in the verify so revert the problem shape.
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
@@ -555,11 +556,11 @@ bool verify(Options const& options) {
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
- const int scale_k = 1;
+ int const scale_k = cutlass::ceil_div(options.k, options.c);
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
- cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
+ cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
//
// Compute reference output
@@ -584,7 +585,7 @@ bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
}
}
return passed;

View File

@@ -50,6 +50,7 @@ set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10)
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
+ set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling
cutlass_example_add_executable(
69_hopper_mixed_dtype_grouped_gemm
@@ -69,6 +70,7 @@ cutlass_example_add_executable(
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
+ TEST_SCALE_GROUP
)
cutlass_example_add_executable(
@@ -89,6 +91,7 @@ cutlass_example_add_executable(
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
+ TEST_SCALE_GROUP
)
cutlass_example_add_executable(
@@ -109,4 +112,5 @@ cutlass_example_add_executable(
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
+ TEST_SCALE_GROUP
)

View File

@@ -7,11 +7,11 @@ This example shows how to perform Grouped GEMMs on Hopper when A and B have diff
- in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B.
- if scales and zero-points are included, also pass the array of their strides in the arguments.
- Note that in Example 55, the argument `--g` is used to determine the block scale size. It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs.
+ Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling.
## Upcoming features
- Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.
+ Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, and group-wise scaling for identical problem shapes across all groups. Please contact us if zero-points or block-wise scaling are needed.
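For instance, the group-wise scaling test added to `CMakeLists.txt` runs:

69_hopper_mixed_dtype_grouped_gemm --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0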
## Copyright

View File

@@ -58,6 +58,7 @@ public:
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
cmd.get_cmd_line_argument("c", c);
MixedDtypeOptions::parse(argc, args);
@@ -71,6 +72,7 @@ public:
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --c=<int> Sets the chunk size for scaling the quantized weights\n"
<< " --groups=<int> Sets the number of individual GEMM problems\n"
<< " --mode=<int> The mode to run the gemm\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
@@ -183,11 +185,6 @@ void grouped_mixed_dtype_profiling(
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Sizes, Alpha, Beta\n";
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n';
}
std::cout << " Groups : " << options.groups << '\n'
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
<< " GFLOPS : " << result.gflops << '\n';

View File

@@ -0,0 +1,832 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file An MLA (Multi-Head Latent Attention) inference kernel sample for the
NVIDIA Blackwell Architecture.
*/
#include <iostream>
#include <random>
#include <regex>
#include <cmath>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "reference/fmha_mla_reference.hpp"
#include "reference/reference_abs_error.hpp"
#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
using namespace cute;
using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help = false;
bool error = false;
int b = 1;
int k = 256;
int split_kv = -1; // number of splits along the KV dim.
bool is_var_split_kv = false;
int max_split_kv = 16;
int page = -1;
float spread = 0.2f;
int iterations = 3;
bool verify = false;
bool verbose = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
InitStyle init_style_c = InitStyle::kRandom;
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
std::string s;
cmd.get_cmd_line_argument(name, s, s);
if (s.empty()) {
dst = src;
}
else {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
else if (s == "d") {
dst = InitStyle::kLinearStride1;
}
else if (s == "s") {
dst = InitStyle::kLinearStride128;
}
else if (s == "n") {
dst = InitStyle::kNone;
}
else {
std::cout << "Error: " << s << " is not a valid input type.\n";
std::exit(-1);
}
}
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
Options defaults;
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("k", k, -1);
if (k == -1) k = defaults.k;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
cmd.get_cmd_line_argument("page", page, defaults.page);
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
if (page == -1) {
is_var_split_kv = false;
}
cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv);
if (is_var_split_kv == true) {
split_kv = max_split_kv;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c);
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c);
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "77_blackwell_mla\n\n"
<< " This example showcases the use of CUTLASS for fused multi-head latent\n"
<< " attention kernels targeting NVIDIA's Blackwell architecture.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --page=<int> Enables paging and sets the page size\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --spread=<float> Relative spread away from K for paging\n"
<< " --split_kv=<int> Split KV factor\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
void initialize_block(
DeviceAllocation<Element>& block,
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
switch (init_style) {
case InitStyle::kOne: {
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) 1, (Element) 1);
break;
}
case InitStyle::kRandom: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 1);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {
for (int j = 0; j < 128; j++) {
data[j + 128*i] = static_cast<Element>((double) (j % 4));
}
}
block.copy_from_host(data.data(), data.size());
break;
}
case InitStyle::kLinearStride128: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 64; i ++) {
for (int j = 0; j < 64; j++) {
data[j + 64*i] = static_cast<Element>((double) (i % 9));
}
}
block.copy_from_host(data.data(), data.size());
break;
}
case InitStyle::kNone: {
break;
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleResult {
bool passed = false;
bool verified = false;
float runtime_ms = 0;
double tflops_tc_s = 0;
double tbytes_s = 0;
size_t smem_size = 0;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
template<bool v>
struct IsPersistent {
static const bool value = v;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template<
class TileShape,
class PersistenceOption = IsPersistent<true>
>
struct Runner {
#ifdef FP8
using Element = cutlass::float_e4m3_t;
#elif FP16
using Element = cutlass::half_t;
#else
#error "Must either define FP8 or FP16"
#endif
using ElementAcc = float;
using ElementOut = cutlass::half_t;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;
// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B
using TileScheduler = std::conditional_t<
PersistenceOption::value,
Sm100MlaPersistentTileScheduler,
Sm100MlaIndividualTileScheduler
>;
using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler
>;
using Operation = cutlass::fmha::device::MLA<Kernel>;
//
// Data members
//
/// Initialization
StrideQ stride_Q_latent;
StrideK stride_C_latent;
StrideQ stride_Q_rope;
StrideK stride_K_rope;
StrideO stride_O;
StrideLSE stride_LSE;
StrideLSE stride_PT;
uint64_t seed = 0;
int page_size = -1;
int page_count = -1;
// We allocate Q and C as first latent, then rope
// This means that we offset the pointer by HeadDim_latent to get the rope
// portion
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_C;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<int> block_seq;
DeviceAllocation<int> block_PT;
DeviceAllocation<int> block_split_kv;
DeviceAllocation<int> block_accum_split_len;
DeviceAllocation<ElementAcc> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAcc> block_ref_LSE;
ElementAcc scale;
//
// Methods
//
bool verify(const ProblemShape& problem_shape) {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
int page_K = K;
int page_B = B;
if (block_PT.get() != nullptr) {
page_K = page_size;
page_B = page_count;
}
Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()),
cute::make_tuple(H, D_latent, B),
stride_Q_latent);
Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent),
cute::make_tuple(H, D_rope, B),
stride_Q_rope);
Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()),
cute::make_tuple(page_K, D_latent, page_B),
stride_C_latent);
Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent),
cute::make_tuple(page_K, D_rope, page_B),
stride_K_rope);
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
cute::make_tuple(H, D_latent, B),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
cute::make_tuple(H, B),
stride_LSE);
Tensor mSeq = make_tensor(make_gmem_ptr(static_cast<int*>(block_seq.get())), make_shape(B));
Tensor mPT = make_tensor(make_gmem_ptr(static_cast<int*>(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT);
fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
#ifdef B2B
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
#else
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
#endif
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
std::cerr << "failed O: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
bool passed_LSE = true;
#ifndef B2B
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
#endif
return passed_O && passed_LSE;
}
ProblemShape initialize(const Options& options) {
auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
// the scale is based on the non-absorbed sizes, change as appropriate
// we can't determine this parameter from the info we have, it's an input
int D_non_latent = 128;
scale = static_cast<decltype(scale)>(1.0 / sqrt(1.0 * (D_non_latent + D_rope)));
// Shape (H, D, B)
stride_Q_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
stride_Q_rope = stride_Q_latent;
stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
stride_LSE = cute::make_tuple(_1{}, 0 + H);
block_Q.reset(static_cast<size_t>(options.b) * H * (D_latent + D_rope));
block_O.reset(static_cast<size_t>(options.b) * H * D_latent);
block_LSE.reset(static_cast<size_t>(options.b) * H);
block_ref_O.reset(static_cast<size_t>(options.b) * H * D_latent);
block_ref_LSE.reset(static_cast<size_t>(options.b) * H);
if (options.page == -1) {
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(options.k) * (D_latent + D_rope));
stride_K_rope = stride_C_latent;
block_C.reset(static_cast<size_t>(options.b) * options.k * (D_latent + D_rope));
}
else {
float spread = options.spread;
int max_K = static_cast<int>((1 + spread) * K);
int min_K = static_cast<int>((1 - spread) * K);
page_size = options.page;
page_count = B * ceil_div(max_K, page_size);
stride_PT = cute::make_stride(_1{}, page_count);
std::vector<int> host_seq(B);
std::vector<int> host_PT(page_count * B);
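// Page table layout (matching stride_PT above): batch i's entries start at
// host_PT[page_count * i], and logical page j of batch i maps to physical
// page i + j * B, interleaving each batch's pages across the page pool.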
for (int i = 0; i < B; i++) {
int seq = min_K + rand() % (max_K - min_K + 1);
host_seq[i] = seq;
for (int j = 0; j < ceil_div(seq, page_size); j++) {
host_PT[page_count * i + j] = i + j * B;
}
}
block_seq.reset(host_seq.size());
block_seq.copy_from_host(host_seq.data(), host_seq.size());
block_PT.reset(host_PT.size());
block_PT.copy_from_host(host_PT.data(), host_PT.size());
get<1>(problem_shape) = max_K;
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, page_size * static_cast<int64_t>((D_latent + D_rope)));
stride_K_rope = stride_C_latent;
block_C.reset(page_count * page_size * static_cast<int64_t>((D_latent + D_rope)));
if (options.is_var_split_kv == true) {
std::vector<int> host_split_kv(B);
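// Heuristic: split = ceil_div(max_split_kv, ceil_div(max_K, len)), i.e.
// roughly max_split_kv * len / max_K, so longer sequences get more splits.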
for(int i = 0; i < B; ++i) {
auto len = host_seq[i];
int split = ceil_div(options.max_split_kv, ceil_div(max_K, len));
host_split_kv[i] = split;
}
block_split_kv.reset(B);
block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size());
}
}
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_C, seed + 2022, options.init_style_c);
return problem_shape;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShape problem_shape = initialize(options);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
typename Operation::Arguments arguments{
problem_shape,
{ scale,
block_Q.get(), stride_Q_latent,
block_Q.get() + D_latent, stride_Q_rope,
block_C.get(), stride_C_latent,
block_C.get() + D_latent, stride_K_rope,
block_seq.get(),
block_PT.get(), stride_PT,
page_count, page_size},
{ block_O.get(),
stride_O,
block_LSE.get(),
stride_LSE},
hw_info,
options.split_kv,
options.is_var_split_kv ? block_split_kv.get() : nullptr
};
if (options.split_kv < 0 && !options.is_var_split_kv) {
Operation::set_split_kv(arguments);
}
Operation op;
ExampleResult example_result;
example_result.smem_size = Operation::Kernel::SharedStorageSize;
size_t workspace_size = 0;
workspace_size = Operation::get_workspace_size(arguments);
DeviceAllocation<uint8_t> workspace(workspace_size);
cutlass::Status status = cutlass::Status::kSuccess;
status = op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
status = op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
// Run
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return example_result;
}
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result = cudaEventCreate(&event);
if (result != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
}
// Record an event at the start of a series of GEMMs
result = cudaEventRecord(events[0]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
for (int i = 0; i < options.iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result = cudaEventRecord(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
// Wait for work on the device to complete.
result = cudaEventSynchronize(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
runtime_ms /= static_cast<float>(options.iterations);
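// FLOP model: per (batch, head, key) the kernel performs one dot product of
// length D_latent + D_rope (Q*K^T) and one of length D_latent (P*V); at
// 2 FLOPs per multiply-add this gives the 2 * (2*D_latent + D_rope) factor.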
double flops = 1.0;
flops *= B;
flops *= K;
flops *= H;
flops *= 2.0;
flops *= (2.0 * D_latent + D_rope);
double bytes_q = sizeof(Element);
bytes_q *= B;
bytes_q *= H;
bytes_q *= (D_latent + D_rope);
double bytes_c = sizeof(Element);
bytes_c *= B;
bytes_c *= options.k; // K may be max_K here
bytes_c *= (D_latent + D_rope);
double bytes_o = sizeof(ElementOut);
bytes_o *= B;
bytes_o *= H;
bytes_o *= D_latent;
double bytes = bytes_q + bytes_c + bytes_o;
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
example_result.tbytes_s = tbytes_s;
example_result.runtime_ms = runtime_ms;
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return example_result;
}
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape);
if (passed) example_result.verified = true;
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
}
example_result.passed = true;
return example_result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
if (verbose) {
std::cout << " t=" << result.runtime_ms * 1e3 << " us, "
"smem=" << result.smem_size << "b" << std::endl;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, const char* name, auto... kernel_options) {
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
return;
}
Runner<decltype(shape), decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using NumHeads = _128;
using HeadDimLatent = _512;
using HeadDim = Shape<HeadDimLatent, _64>;
std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " ";
std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " ";
std::cout << "Q 1 K " << options.k << " Gen None ";
std::cout << "Split " << options.split_kv << " Gen None ";
std::cout << "#SM " << hw_info.sm_count << std::endl;
using Blocking = _128;
std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{});
std::string individual = " individual";
std::string persistent = " persistent";
#if FP8
name += " fp8";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
#elif FP16
name += " fp16";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
#endif
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_single(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
std::cout
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
if (options.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
else {
hw_info.sm_count = options.sm_count;
}
run_mla(options, hw_info);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
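// Sweep support: any argument whose value contains commas (e.g. --k=256,512)
// is expanded here, re-invoking main once per comma-separated value.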
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
auto arg = full_arguments[i];
size_t eq_pos = arg.find('=');
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
for (;;) {
size_t comma_pos = rest.find(',');
std::string current = rest.substr(0, comma_pos);
full_arguments[i] = prefix + current;
std::vector<const char*> next_args;
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
main(argc, next_args.data());
if (comma_pos == std::string::npos) break;
rest = rest.substr(comma_pos+1);
}
recursed = true;
break;
}
}
if (! recursed) {
main_single(argc, args);
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -35,6 +35,10 @@ set_property(
SOURCE 77_blackwell_fmha_gen.cu
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
+ set_property(
+ SOURCE 77_blackwell_mla.cu
+ PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
@@ -48,58 +52,69 @@ set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
77_blackwell_fmha_fp8
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8)
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
cutlass_example_add_executable(
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
cutlass_example_add_executable(
77_blackwell_fmha_fp16
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
foreach(PREC fp8 fp16)
string(TOUPPER "${PREC}" PREC_MACRO)
cutlass_example_add_executable(
77_blackwell_fmha_gen_fp16
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
endif()
cutlass_example_add_executable(
77_blackwell_fmha_${PREC}
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
cutlass_example_add_executable(
77_blackwell_fmha_gen_${PREC}
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
cutlass_example_add_executable(
77_blackwell_mla_2sm_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_mla_2sm_cpasync_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_mla_b2b_2sm_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
endforeach()
endif()

View File

@@ -22,6 +22,24 @@ The `apply_mask` function is called with the accumulator of the first GEMM and t
It is well-suited for applying masks or activations.
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.
# MLA Inference for Blackwell
This sample provides code for fused multi-head latent attention inference in
the weight-absorbed regime, i.e. latent head dim 512 and rope head dim 64.
It supports fp16, bf16, and fp8 input and output types.
To accommodate the large output accumulator due to the large latent head dimension,
the sample demonstrates how to leverage 2SM Blackwell tensor cores.
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
for support of any power-of-two page size less than or equal to 128.
With paging, the code also supports variable sequence length.
The approach of this implementation is to reuse the selection logic of the collective GEMM builder and recombine the result into an MLA kernel.
The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back GEMM (essentially turning the softmax into a no-op) for fp8 and fp16.
For detailed information on how to invoke them, consult either the tests in `CMakeLists.txt` or each binary's `--help` output.
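For example, the basic MLA test in `CMakeLists.txt` invokes each of the generated binaries as (binary name shown for the fp8 TMA variant):

77_blackwell_mla_2sm_fp8 --b=1 --k=512 --verify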
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

View File

@@ -0,0 +1,92 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cuda_runtime.h>
namespace cutlass::fmha {
struct Pow2 {
int n;
int log2_n;
explicit CUTE_DEVICE Pow2(int n) : n(n) {
#ifdef __CUDA_ARCH__
log2_n = __ffs(n) - 1;
#endif
}
template<class T>
CUTE_HOST_DEVICE T operator *(T const& b) const {
return n * b;
}
template<int N>
CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
if constexpr ((N & (N - 1)) == 0) {
return Pow2{n * N};
}
else {
return n * N;
}
}
};
template<class T>
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
return a >> b.log2_n;
}
template<class T>
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
return a & (b.n - 1);
}
template<class T>
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
return a < b.n;
}
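// Illustrative identities (device-side, where log2_n is populated): for
// Pow2 p{8}, p.log2_n == 3, so x / p == x >> 3 and x % p == x & 7. This
// replaces runtime division/modulo in paged-KV index math with shifts and
// masks, which is why only power-of-two page sizes are supported.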
CUTE_HOST_DEVICE void print(Pow2 const& a) {
printf("2^%d", a.log2_n);
}
} // end namespace cutlass::fmha
namespace cute {
template <>
struct is_integral<cutlass::fmha::Pow2> : true_type {};
} // end namespace cute

View File

@@ -0,0 +1,357 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for CUTLASS 3.x-style kernels.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
using namespace cute;
using namespace cutlass::fmha::kernel;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class Kernel_
>
class MLA {
public:
using Kernel = Kernel_;
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
typename Kernel::ElementOut,
typename Kernel::ElementAcc,
typename Kernel::ElementAcc,
Kernel::TileShapeH::value,
Kernel::TileShapeL::value,
256 /*Max split*/
>;
/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;
using Arguments = KernelArguments;
/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;
struct Params {
KernelParams fmha_params;
ReductionParams reduction_params;
};
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
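// Note: the flag is a function-local static, so the dynamic-smem opt-in in
// initialize() runs once per process (per Kernel instantiation), not once
// per MLA object.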
static ReductionArguments to_reduction_args(Arguments const& args) {
auto [H, K, D, B] = args.problem_shape;
return ReductionArguments{
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
args.ptr_split_kv, Kernel::TileShapeS::value
};
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
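// Heuristic sketch (illustrative numbers): with sm_count = 148, B = 4 batches,
// and K = 8192, max_splits = ceil(8192/128) = 64 and sms_per_batch =
// max(1, 148/4) = 37, so split_heur = min(64, 37) = 37. Rounding to whole
// K-waves gives k_waves = ceil(64/37) = 2 and split_kv = ceil(64/2) = 32,
// i.e. the largest split that still covers the K tiles in full waves.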
static void set_split_kv (KernelArguments& args) {
if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape;
int sm_count = args.hw_info.sm_count;
int max_splits = ceil_div(K, 128);
int sms_per_batch = max(1, sm_count / B);
int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (! Kernel::can_implement(args)) {
return Status::kInvalid;
}
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
return Status::kInvalid;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
return workspace_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
if (status != Status::kSuccess) {
return status;
}
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {kernel_params, reduction_params};
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
// no dynamic smem is needed for reduction kernel
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {fmha_params, reduction_params};
return Status::kSuccess;
}
/// Primary run() entry point API that is static, allowing users to create and manage their own params.
/// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments().
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params.fmha_params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess != result or Status::kSuccess != launch_result) {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
else {
return Status::kSuccess;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class ElementOut,
class ElementAcc,
class ElementScale,
size_t kNumHeads,
size_t kHeadDimLatent,
int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
static const int SharedStorageSize = 0;
static const int MaxThreadsPerBlock = 128;
static const int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm100;
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
struct Arguments {
ElementAcc* ptr_oaccum = nullptr;
ElementOut* ptr_o = nullptr;
ElementAcc* ptr_lseaccum = nullptr;
ElementAcc* ptr_lse = nullptr;
ElementScale scale = 1.f;
int num_batches = 0;
int split_kv = -1;
int dim_k = -1;
int* ptr_seq = nullptr;
int* ptr_split_kv = nullptr;
int tile_shape_s = 128;
};
using Params = Arguments;
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
args.ptr_split_kv, args.tile_shape_s};
}
static size_t get_workspace_size(Arguments const& /*args*/) {
return 0;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
return Status::kSuccess;
}
static dim3 get_grid_shape(Params const& params) {
return dim3(kNumHeads, 1, params.num_batches);
}
static dim3 get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
static bool can_implement(Arguments const& args) {
if (args.num_batches <= 0) return false;
if (args.split_kv <= 0) return false;
return true;
}
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
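// Warp 0 merges the per-split LSE values with a numerically stable
// log-sum-exp: global_lse = log(sum_i exp(lse_i - scale*max)) + scale*max.
// Each of the 32 lanes owns kNLsePerThread splits; the max and the sum are
// combined across the warp with shuffle-based butterfly reductions.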
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
ElementAcc local_lse[kNLsePerThread];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
}
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
if (split < local_split_kv) {
sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
}
__syncthreads();
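// Second phase: each thread owns `Elements` columns of the latent head dim,
// accumulates the split-KV partial outputs weighted by
// sLseScale[split] = exp(lse_split - global_lse), and stores the result to O.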
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
ElementAcc local_val[Elements] = {0};
for (int split = 0; split < local_split_kv; ++split) {
ElementAcc lse_scale = sLseScale[split];
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
}
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
}
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
}
}
};
} // namespace cutlass::fmha::kernel

File diff suppressed because it is too large

View File

@@ -0,0 +1,160 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaIndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler(Params const&) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
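// The persistent variant below launches at most sm_count CTAs and lets each
// one iterate over linearized (m_block, batch, split_kv) tile coordinates,
// decoded from the flat block index with FastDivmod; the individual scheduler
// above instead maps exactly one CTA to each tile.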
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_split_kv;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = size<0>(cluster_shape);
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
num_blocks *= split_kv; /* Maximum Split KV*/
return Params {
num_blocks,
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel

View File

@@ -0,0 +1,206 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorSeq,
class TensorPageTable,
class TensorQL,
class TensorQR,
class TensorCL,
class TensorKR,
class TensorO,
class TensorLSE,
class Scale
>
void __global__ fmha_mla_reference_kernel(
ProblemShape problem_shape,
TensorSeq mSeq, TensorPageTable mPT,
TensorQL mQL, TensorQR mQR,
TensorCL mCL, TensorKR mKR,
TensorO mO, TensorLSE mLSE,
Scale softmax_scale) {
using namespace cute;
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
using Element = typename TensorO::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ ElementAcc mS[];
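// Reference MLA: S = QL * CL^T + QR * KR^T (latent plus rope contributions),
// P = softmax(softmax_scale * S) staged in shared memory, then O = P * CL.
// The latent tensor CL doubles as both K and V; page indices are resolved
// through mPT when a page table is provided.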
for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) {
if (mSeq.data() != nullptr) {
K = mSeq(idx_B);
}
for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
ElementAcc acc = 0;
for (int idx_D = 0; idx_D < D_latent; idx_D++) {
int page_idx_K = idx_K;
int page_idx_B = idx_B;
if (mPT.data() != nullptr) {
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
page_idx_K = idx_K % size<0>(mCL);
}
ElementAcc eQ = mQL(idx_H, idx_D, idx_B);
ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B);
acc += eQ * eK;
}
for (int idx_D = 0; idx_D < D_rope; idx_D++) {
int page_idx_K = idx_K;
int page_idx_B = idx_B;
if (mPT.data() != nullptr) {
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
page_idx_K = idx_K % size<0>(mCL);
}
ElementAcc eQ = mQR(idx_H, idx_D, idx_B);
ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B);
acc += eQ * eK;
}
mS[idx_K] = acc;
}
__syncthreads();
ElementAcc maxS = -std::numeric_limits<ElementAcc>::infinity();
for (int idx_K = 0; idx_K < K; idx_K++) {
maxS = std::max<ElementAcc>(maxS, mS[idx_K]);
}
if (maxS == -std::numeric_limits<ElementAcc>::infinity()) maxS = 0;
__syncthreads();
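// With B2B defined, the exponentiation below and the 1/sum rescale are
// compiled out, reducing the softmax to an identity so the two GEMMs run
// back-to-back (matching the kernels' B2B test mode).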
#ifndef B2B
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
}
#endif
__syncthreads();
ElementAcc sum = 0;
for (int idx_K = 0; idx_K < K; idx_K++) {
sum += mS[idx_K];
}
ElementAcc o_scale = 1.0f / sum;
#ifdef B2B
o_scale = 1.0;
#endif
for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_K = 0; idx_K < K; idx_K++) {
int page_idx_K = idx_K;
int page_idx_B = idx_B;
if (mPT.data() != nullptr) {
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
page_idx_K = idx_K % size<0>(mCL);
}
ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B);
ElementAcc eS = static_cast<Element>(mS[idx_K]); // round P through the kernel's output element type
acc += eS * eV;
}
mO(idx_H, idx_D, idx_B) = static_cast<typename TensorO::value_type>(acc * o_scale);
}
if (threadIdx.x == 0) {
mLSE(idx_H, idx_B) = log(sum) + softmax_scale * maxS;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorSeq,
class TensorPageTable,
class TensorQL,
class TensorQR,
class TensorCL,
class TensorKR,
class TensorO,
class TensorLSE,
class Scale
>
void fmha_mla_reference(
ProblemShape problem_shape,
TensorSeq mSeq, TensorPageTable mPT,
TensorQL mQL, TensorQR mQR,
TensorCL mCL, TensorKR mKR,
TensorO mO, TensorLSE mLSE,
Scale scale) {
using namespace cute;
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
dim3 grid(H, B, 1);
dim3 block(256);
int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16;
cudaError_t result;
if (shared_mem >= (48 << 10)) {
result = cudaFuncSetAttribute(
&fmha_mla_reference_kernel<ProblemShape, TensorSeq, TensorPageTable, TensorQL, TensorQR, TensorCL, TensorKR, TensorO, TensorLSE, Scale>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
throw std::runtime_error("couldn't perform smem optin");
}
}
fmha_mla_reference_kernel<<<grid, block, shared_mem>>>(
problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale);
cudaDeviceSynchronize();
result = cudaGetLastError();
if (cudaSuccess != result) {
throw std::runtime_error("couldn't execute reference");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -178,3 +178,96 @@ void reference_abs_diff(
max_diff = result_host[0];
mean_diff = result_host[1] / static_cast<double>(data.size());
}
template<typename Element>
__global__ void reference_rel_diff_kernel(
Element* data, Element* data_ref, size_t count,
double* max_diff, double* sum_diff,
bool print_diff ) {
double thread_max_diff = 0;
double thread_sum_diff = 0;
__shared__ double block_max_diff;
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
if (print_diff && (diff != diff || diff > 0.01f)) {
printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
}
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}
for (int i = 0; i < blockDim.x; i++) {
if (i == threadIdx.x) {
if (i == 0) {
block_max_diff = thread_max_diff;
block_sum_diff = thread_sum_diff;
}
else {
block_max_diff = fmax(block_max_diff, thread_max_diff);
block_sum_diff += thread_sum_diff;
}
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(sum_diff, block_sum_diff);
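// Emulate an atomic max for doubles: snapshot the current value, take the
// max with this block's result, and publish it via 64-bit atomicCAS,
// retrying until no other block has intervened.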
for (;;) {
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
double prev_diff = reinterpret_cast<double const&>(prev);
double new_max_diff = fmax(block_max_diff, prev_diff);
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
if (found == prev) break;
}
}
}
template<typename Element>
void reference_rel_diff(
DeviceAllocation<Element> const& data,
DeviceAllocation<Element> const& data_ref,
double& max_diff, double& mean_diff) {
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
DeviceAllocation<double> result;
result.reset(2);
assert(data.size() == data_ref.size());
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
if (err != cudaSuccess) {
std::cerr << "Memset failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
dim3 block(256, 1, 1);
dim3 grid(1024, 1, 1);
reference_rel_diff_kernel<<<grid, block>>>(
data.get(), data_ref.get(), data.size(),
result.get(), result.get() + 1, kPrintDiff);
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
std::cerr << "Difference kernel failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
double result_host[2];
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
if (err != cudaSuccess) {
std::cerr << "Copy failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
max_diff = result_host[0];
mean_diff = result_host[1] / static_cast<double>(data.size());
}

View File

@@ -0,0 +1,554 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions:
* mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.
Please see more details at https://docs.nvidia.com/cuda/parallel-thread-execution.
The kernel leverages:
1. Warp-specialized persistent kernel design that supports the cooperative scheduler introduced in Hopper.
2. The new SW-controlled dynamic scheduler based on cluster launch control (see https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Sparse Tensor Core MMA instructions.
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 16; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// E matrix configuration. Note, E is used to represent metadata tensor.
using ElementE = uint8_t; // Element type for E matrix operand
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120; // Kernel schedule policy
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelScheduleType // Mainloop schedule policy
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutE layout_E;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating through the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A_Decompressed;
cutlass::HostTensor<ElementE, cutlass::layout::PackedVectorLayout> block_E;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "80a_blackwell_geforce_mxfp8_bf16_sparse_gemm\n\n"
<< " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize blocks related to sparse matrix A and its metadata E
bool initialize_sparse_blocks(const Options &options) {
auto workload = make_shape(options.m,
options.n,
options.k,
1);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
/// Alias SparseConfig and Compressor
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig,
cutlass::arch::Sm120>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
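/// Overall flow: fill a dense A, mask it to the structured-sparsity pattern,
/// then compress it on device into packed values (block_A) plus the metadata
/// tensor (block_E) consumed by the sparse mainloop.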
/// Declare compressor_utility to randomly zero out matrix A so it matches the required sparsity pattern
CompressorUtility compressor_utility(workload, stride_A);
// Aligned M K dimension size for A and E
int aligned_m_e = compressor_utility.get_metadata_m_physical();
int aligned_k_e = compressor_utility.get_metadata_k_physical();
int aligned_m_a = compressor_utility.get_tensorA_m_physical();
int aligned_k_a = compressor_utility.get_tensorA_k_physical();
/// Layout A and E
layout_A = SparseConfig::fill_layoutA(workload);
layout_E = SparseConfig::fill_layoutE(workload);
block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a));
block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e));
block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k));
initialize_block(block_A_Decompressed.host_view(), seed + 2020);
compressor_utility.structure_sparse_zero_mask_fill(
block_A_Decompressed.host_data(), static_cast<int>(seed + 2021));
block_A_Decompressed.sync_device();
/// Use compressor kernel to generate compressed Matrix A and E
cutlass::Status status { cutlass::Status::kSuccess };
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
{options.m, options.n, options.k, 1},
{block_A_Decompressed.device_data(),
stride_A,
block_A.device_data(),
block_E.device_data()},
{hw_info}
};
// Compress A and E
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
block_A.sync_host();
block_E.sync_host();
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
using namespace cute;
// Initialize A, E (metadata), and compressed-A blocks
if(!initialize_sparse_blocks(options)) return false;
// Define B, C and D blocks
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
// Define SFA and SFB tensors layouts
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
return true;
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), layout_A,
block_B.device_data(), stride_B,
block_E.device_data(), layout_E,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
// Initialization
if(!initialize(options))
{
std::cerr << " Initialization failed! " << std::endl;
exit(-1);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 120.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,578 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions:
* mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.
Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution.
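Conceptually, a block-scaled GEMM computes D = alpha * ((SFA x A) @ (SFB x B)) + beta * C, where each scale-factor element applies to a contiguous block of operand values along K (e.g. 16 input elements per scale factor for nvfp4).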
The kernel leverages:
1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Sparse Tensor Core MMA Instructions
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (64 4-bit elements = 32 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::ColumnMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::ColumnMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
constexpr int outputVectorSize = 32; // Vector size for D matrix
using outputScaleFactor = cutlass::float_ue4m3_t; // Scale factor type for D matrix
// E matrix configuration. Note, E is used to represent metadata tensor.
using ElementE = uint8_t; // Element type for E matrix operand
// Kernel functional config
using ElementCompute = float; // Element type for computation inside mainloop and epilogue
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120; // Kernel schedule policy
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120, // Epilogue schedule policy
cutlass::epilogue::fusion::LinCombBlockScaleFactor< // Epilogue fusion to generate nvfp4 output
outputVectorSize, ElementD, ElementAccumulator, outputScaleFactor, LayoutDTag, ElementC>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelScheduleType // Mainloop schedule policy
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<outputVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutSFD layout_SFD;
LayoutE layout_E;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A_Decompressed;
cutlass::HostTensor<ElementE, cutlass::layout::PackedVectorLayout> block_E;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_SFD;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_reference_SFD;
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
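/// Returns a cute::subbyte_iterator when T is a sub-byte type (e.g. the 4-bit float_e2m1_t),
/// since such elements cannot be dereferenced through a plain pointer; otherwise returns the raw pointer.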
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm\n\n"
<< " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
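// Note: this is the dense-equivalent flop count; the structured-sparse kernel performs roughly half the multiplies.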
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize the blocks related to sparse Matrix A and its metadata E
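/// A is first generated dense, then pruned so that elements along K follow the structured-sparsity
/// pattern required by the hardware (typically at most 2 nonzeros in every group of 4; the exact
/// pattern depends on the datatype), and finally compressed into the packed block_A plus its
/// metadata tensor block_E by the device-side compressor kernel below.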
bool initialize_sparse_blocks(const Options &options) {
auto workload = make_shape(options.m,
options.n,
options.k,
1);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
/// Alias SparseConfig and Compressor
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig,
cutlass::arch::Sm120>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
/// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs
CompressorUtility compressor_utility(workload, stride_A);
// Aligned M K dimension size for A and E
int aligned_m_e = compressor_utility.get_metadata_m_physical();
int aligned_k_e = compressor_utility.get_metadata_k_physical();
int aligned_m_a = compressor_utility.get_tensorA_m_physical();
int aligned_k_a = compressor_utility.get_tensorA_k_physical();
/// Layout A and E
layout_A = SparseConfig::fill_layoutA(workload);
layout_E = SparseConfig::fill_layoutE(workload);
block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a));
block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e));
block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k));
initialize_block(block_A_Decompressed.host_view(), seed + 2020);
compressor_utility.structure_sparse_zero_mask_fill(
block_A_Decompressed.host_data(), static_cast<int>(seed + 2021));
block_A_Decompressed.sync_device();
/// Use compressor kernel to generate compressed Matrix A and E
cutlass::Status status { cutlass::Status::kSuccess };
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
{options.m, options.n, options.k, 1},
{block_A_Decompressed.device_data(),
stride_A,
block_A.device_data(),
block_E.device_data()},
{hw_info}
};
// Compress A and E
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
block_A.sync_host();
block_E.sync_host();
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
using namespace cute;
// Initialize A, E (metadata) and A_compressed blocks
if(!initialize_sparse_blocks(options)) return false;
// Define B, C and D blocks
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
// Define SFA and SFB tensors layouts
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_Normconst.reset(cutlass::make_Coord(1));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_Normconst.at(cutlass::make_Coord(0)) = 2;
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
block_SFD.sync_device();
block_Normconst.sync_device();
return true;
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), layout_A,
block_B.device_data(), stride_B,
block_E.device_data(), layout_E,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
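// Beyond the usual linear-combination arguments, the blockscaled epilogue also writes per-block
// scale factors for the nvfp4 output D (SFD) and reads a norm constant used when quantizing
// accumulator values down to nvfp4.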
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
auto tensor_SFD = cute::make_tensor(block_reference_SFD.host_data(), layout_SFD);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D), // TensorD
decltype(tensor_SFD), // TensorSfD
cute::Int<outputVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
// Initialization
if(!initialize(options))
{
std::cerr << " Initialization failed! " << std::endl;
exit(-1);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 120.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
cutlass_example_add_executable(
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu
)
cutlass_example_add_executable(
80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm
80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu
)
endif()

View File

@@ -0,0 +1,869 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Distributed GEMM (DistGEMM) for Blackwell.
This example runs Tensor Parallel GEMMs using the (experimental) Distributed GEMM API in
CUTLASS. For more information, please refer to README.md.
Note that Distributed GEMM assumes an any-to-any NVLink network topology.
To check whether your device is compatible, run:
$ nvidia-smi topo -m
and make sure there's an any-to-any NVLink topology. It would look like this:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X
You should also additionally check if the driver enables peer to peer access:
$ nvidia-smi topo -p2p r
Output should be something like this:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X OK OK OK OK OK OK OK
GPU1 OK X OK OK OK OK OK OK
GPU2 OK OK X OK OK OK OK OK
GPU3 OK OK OK X OK OK OK OK
GPU4 OK OK OK OK X OK OK OK
GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
It is recommended to build this target with the following flag to enable
Grid Dependency Control instructions (GDC) in CUTLASS:
- CUTLASS_ENABLE_GDC_FOR_SM100
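(GDC exposes programmatic dependent launch, which lets a dependent kernel begin before the prior grid fully drains; DistGEMM uses it to overlap communication with computation.)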
Example:
$ mkdir build && cd build
$ cmake .. -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
$ cd examples/82_blackwell_distributed_gemm
$ make
$ ./82_blackwell_distributed_gemm
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
// Distributed GEMM headers
#include "cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp"
#include "cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp"
#include "cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp"
#include "helper.h"
// Distributed GEMM helpers
#include "dist_gemm_helpers.h"
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Distributed GEMM configuration
/////////////////////////////////////////////////////////////////////////////////////////////////
// TP size (= number of processors/GPUs)
using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
// Distributed GEMM tiling/sharding schedule
// Choices:
//
// * All Gather + GEMM:
// * AllGather1D_TilingCD_RotatingA
// * AllGather1D_TilingCD_RotatingB
//
// * GEMM + Reduce Scatter:
// * ReduceScatter1D_TilingA_RotatingC
// * ReduceScatter1D_TilingB_RotatingC
using DistSchedule = cutlass::distributed::schedules::AllGather1D_TilingCD_RotatingA<TP>;
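// Roughly: AllGather1D_TilingCD_RotatingA keeps each device's tile of C/D resident and rotates
// shards of A between neighboring devices each stage, so every device eventually multiplies
// against the full A operand (an all-gather fused into the GEMM).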
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutD = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_256,_256,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_1,_1>;
// Shape of the tile computed by each SM
using PerSmTileShape_MNK = Shape<_128, _256, _128>;
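// With a cluster of 2 SMs in M, the 256x256x128 MMA tile is split so each SM owns a 128x256x128
// slice, which is why PerSmTileShape_MNK halves the M extent of MmaTileShape_MNK.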
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
PerSmTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
// Build the mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100
>::CollectiveOp;
// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
// We're going to use the single-device GEMM as reference
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Instantiate Distributed GEMM kernel
using DistGemmKernel = cutlass::distributed::kernel::DistributedGemmKernelWrapper<
GemmKernel,
DistSchedule
>;
using DistGemm = cutlass::distributed::device::DistributedGemmUniversalAdapter<DistGemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
using HostTensorA = typename cutlass::HostTensor<ElementA, LayoutA>;
using HostTensorB = typename cutlass::HostTensor<ElementB, LayoutB>;
using HostTensorC = typename cutlass::HostTensor<ElementC, LayoutC>;
using HostTensorD = typename cutlass::HostTensor<ElementD, LayoutD>;
// Reference GEMM tensors
HostTensorA tensor_A;
HostTensorB tensor_B;
HostTensorC tensor_C;
HostTensorD tensor_D;
HostTensorD tensor_ref_D;
// DistGEMM tensors (multi-device)
HostTensorA tensor_A_arr[TP_];
HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && CUDA toolkit >= 12.4
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.f, beta = 0.f;
int iterations = 100;
int warmup_iterations = 10;
int m = 16384, n = 106496, k = 16384, l = 1;
float eps = 0.f;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("warmup-iterations", warmup_iterations);
cmd.get_cmd_line_argument("eps", eps);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "82_blackwell_distributed_gemm\n\n"
<< " Blackwell Distributed GEMM (DistGEMM). \n"
<< " For more details please refer to the source file.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch) of the GEMM (default: 1)\n"
<< " --alpha=<f32> Epilogue scalar alpha (default: 1.0)\n"
<< " --beta=<f32> Epilogue scalar beta (default: 0.0)\n"
<< " --iterations=<int> Number of profiling iterations to perform (default: 100)\n"
<< " --warmup-iterations=<int> Number of warmup iterations prior to profiling (default: 10)\n"
<< " --eps=<f32> Threshold for error compared to reference "
<< "GEMM (default: 0.0)\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "82_blackwell_distributed_gemm" << " --m=16384 --n=106496 --k=16384 \n\n";
return out;
}
/// Compute performance in TFLOP/s
double tflops(double runtime_s) const {
// Two flops per multiply-add
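// Each device computes 1/TP of the problem, so this reports per-device throughput.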
uint64_t flop = uint64_t(2) * m * n * k * l / TP_;
double tflop = double(flop) / double(1.0e12);
return tflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double tflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double tflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), tflops(tflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed,
bool is_device_tensor = false) {
double scope_max, scope_min;
int bits = cutlass::sizeof_bits<Element>::value;
if (bits == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits <= 16) {
scope_max = 2;
scope_min = -2;
}
else {
scope_max = 8;
scope_min = -8;
}
if (is_device_tensor) {
using Real = typename cutlass::RealType<Element>::Type;
cutlass::reference::device::TensorFillRandomUniform(
view, seed, static_cast<Real>(scope_max), static_cast<Real>(scope_min), 0);
cudaDeviceSynchronize();
} else {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
// Setup (reference) GEMM tensors
auto shape_A = cute::select<0,2,3>(problem_shape);
auto shape_B = cute::select<1,2,3>(problem_shape);
auto shape_C = cute::select<0,1,3>(problem_shape);
auto shape_D = cute::select<0,1,3>(problem_shape);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A);
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
initialize_tensor(tensor_A.device_view(), seed + 2022, /* is_device_tensor = */ true);
initialize_tensor(tensor_B.device_view(), seed + 2023, /* is_device_tensor = */ true);
initialize_tensor(tensor_C.device_view(), seed + 2024, /* is_device_tensor = */ true);
tensor_A.sync_host();
tensor_B.sync_host();
tensor_C.sync_host();
tensor_D.sync_host();
tensor_ref_D.sync_host();
// Set up DistGEMM tensors
auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape);
auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape);
auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape);
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
auto a_coord_device = cutlass::make_Coord(size(local_shape_A), 1);
auto b_coord_device = cutlass::make_Coord(size(local_shape_B), 1);
auto c_coord_device = cutlass::make_Coord(size(local_shape_C), 1);
int primary_device_idx;
CUDA_CHECK(cudaGetDevice(&primary_device_idx));
// Enable any-to-any access
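// Note: cudaDeviceEnablePeerAccess is one-directional: it grants the *current* device access to
// the peer's memory, so every device must enable access to every other device.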
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
int can_access;
CUDA_CHECK(cudaSetDevice(device_idx));
for (int peer_idx = 0; peer_idx < TP_; ++peer_idx) {
if (peer_idx != device_idx) {
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device_idx, peer_idx));
if (not can_access) {
std::cerr << "FAILURE: Device " << device_idx << " can't access device " << peer_idx << "." <<
std::endl;
exit(EXIT_FAILURE);
}
CUDA_CHECK(cudaDeviceEnablePeerAccess(peer_idx, 0));
}
}
tensor_A_arr[device_idx].resize(a_coord_device);
tensor_B_arr[device_idx].resize(b_coord_device);
tensor_C_arr[device_idx].resize(c_coord_device);
tensor_D_arr[device_idx].resize(c_coord_device);
}
CUDA_CHECK(cudaSetDevice(primary_device_idx));
}
/// Commandline options -> Gemm/DistGemm Arguments
using GemmArguments = typename Gemm::Arguments;
GemmArguments gemm_args_from_options(const Options &options) {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
{
{static_cast<ElementCompute>(options.alpha), static_cast<ElementCompute>(options.beta)},
tensor_C.device_data(), stride_C,
tensor_ref_D.device_data(), stride_D
}
};
return arguments;
}
using DistGemmArguments = typename DistGemm::Arguments;
DistGemmArguments dist_gemm_args_from_options(
const Options &options,
int device_idx,
cudaStream_t stream) {
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
auto global_A = cute::make_tensor(tensor_A.device_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto global_B = cute::make_tensor(tensor_B.device_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto global_C = cute::make_tensor(tensor_C.device_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto global_A_device_slice = DistSchedule::get_device_slice_A(global_A, device_idx);
auto global_B_device_slice = DistSchedule::get_device_slice_B(global_B, device_idx);
auto global_C_device_slice = DistSchedule::get_device_slice_C(global_C, device_idx);
auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape);
auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape);
auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape);
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
auto local_stride_A = cutlass::make_cute_packed_stride(StrideA{}, local_shape_A);
auto local_stride_B = cutlass::make_cute_packed_stride(StrideB{}, local_shape_B);
auto local_stride_C = cutlass::make_cute_packed_stride(StrideC{}, local_shape_C);
auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D);
auto local_A = cute::make_tensor(
tensor_A_arr[device_idx].device_data(),
make_layout(local_shape_A, local_stride_A));
auto local_B = cute::make_tensor(
tensor_B_arr[device_idx].device_data(),
make_layout(local_shape_B, local_stride_B));
auto local_C = cute::make_tensor(
tensor_C_arr[device_idx].device_data(),
make_layout(local_shape_C, local_stride_C));
auto local_D = cute::make_tensor(
tensor_D_arr[device_idx].device_data(),
make_layout(local_shape_D, local_stride_D));
// Copy over tensor tiles for the first iteration
cutlass::device_copy(global_A_device_slice, local_A, stream);
cutlass::device_copy(global_B_device_slice, local_B, stream);
cutlass::device_copy(global_C_device_slice, local_C, stream);
DistGemmArguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm, // mode
problem_shape, // problem shape
{
reinterpret_cast<const ElementA*>(local_A.data()),
local_A.stride(),
reinterpret_cast<const ElementB*>(local_B.data()),
local_B.stride()
}, // mainloop
{
{ // epilogue.thread
static_cast<ElementCompute>(options.alpha),
static_cast<ElementCompute>(options.beta)
},
reinterpret_cast<const ElementC*>(local_C.data()),
local_C.stride(),
reinterpret_cast<ElementD*>(local_D.data()),
local_D.stride(),
}, // epilogue
{}, // hw_info
{} // scheduler
};
return arguments;
}
// Gathers results, moves back to the original full-sized D tensor on the primary device.
void gather_results(const Options &options, int device_idx, cudaStream_t stream = nullptr) {
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
// Global dest
auto global_D = cute::make_tensor(tensor_D.device_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto global_D_device_slice = DistSchedule::get_device_slice_D(global_D, device_idx);
// Device_idx local dest
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D);
auto local_D = cute::make_tensor(
tensor_D_arr[device_idx].device_data(),
make_layout(local_shape_D, local_stride_D)
);
// Copy to global dest
cutlass::device_copy(local_D, global_D_device_slice, stream);
}
bool verify(const Options &options) {
tensor_D.sync_host();
tensor_ref_D.sync_host();
bool passed = false;
if (options.eps == 0.f) {
passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
} else {
double err = cutlass::reference::host::TensorRelativeErrorMetric(
tensor_D.host_view(),
tensor_ref_D.host_view());
passed = err < 1e-5;
}
if (options.m <= 64 && options.n <= 64) {
std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n";
std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n";
}
return passed;
}
/// Execute a given example GEMM computation
int run(Options &options) {
int primary_device_idx;
cudaError_t device_get_result = cudaGetDevice(&primary_device_idx);
if (device_get_result != cudaSuccess) {
throw std::runtime_error("cudaGetDevice() failed");
}
initialize(options);
// Reference single-GPU GEMM
Gemm reference_gemm;
cutlass::device_memory::allocation<uint8_t> reference_workspace;
auto reference_arguments = gemm_args_from_options(options);
size_t reference_workspace_size = Gemm::get_workspace_size(reference_arguments);
reference_workspace = cutlass::device_memory::allocation<uint8_t>(reference_workspace_size);
CUTLASS_CHECK(reference_gemm.can_implement(reference_arguments));
CUTLASS_CHECK(reference_gemm.initialize(reference_arguments, reference_workspace.get()));
CUTLASS_CHECK(reference_gemm.run());
using ElementBarrier = typename DistGemm::ElementBarrier;
using ElementFlag = typename DistGemmKernel::ElementFlag;
// Set up per-device streams
cudaStream_t stream_arr[TP_];
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
// Create stream
CUDA_CHECK(cudaStreamCreate(&stream_arr[device_idx]));
}
// Instantiate DistGEMM
DistGemm dist_gemm_arr[TP_]; // Distributed GEMM array for multiple devices
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace_arr[TP_];
cutlass::device_memory::allocation<uint8_t> exclusive_workspace_arr[TP_];
// Cross-device workspace pointer array for gemm.initialize()
void * workspace_ptr_arr[TP_];
void * exclusive_workspace_ptr_arr[TP_];
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
DistGemmArguments arguments_[TP_];
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
exclusive_workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(exclusive_workspace_size);
// Throw workspace pointers into arrays for gemm.initialize()
workspace_ptr_arr[device_idx] = workspace_arr[device_idx].get();
exclusive_workspace_ptr_arr[device_idx] = exclusive_workspace_arr[device_idx].get();
// Zero out exclusive workspace
cudaMemsetAsync(exclusive_workspace_ptr_arr[device_idx], 0, exclusive_workspace_size, stream_arr[device_idx]);
cudaDeviceSynchronize();
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
// Check if the problem size is supported or not
CUTLASS_CHECK(dist_gemm_arr[device_idx].can_implement(arguments_[device_idx]));
#if defined(CUTLASS_ENABLE_GDC_FOR_SM100)
bool launch_with_pdl = true;
#else
bool launch_with_pdl = false;
#endif
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(dist_gemm_arr[device_idx].initialize(
arguments_,
workspace_ptr_arr,
exclusive_workspace_ptr_arr,
device_idx,
stream_arr[device_idx],
launch_with_pdl
));
cudaDeviceSynchronize();
}
// Correctness / Warmup iteration
std::cout << std::endl << " running DistGEMM..." << std::endl;
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx]));
CUDA_CHECK(cudaGetLastError());
gather_results(options, device_idx);
}
std::cout << " running DistGEMM finished without runtime errors" << std::endl;
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << std::endl << " Disposition (eps: " << options.eps << "): " <<
(result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0) {
float elapsed_ms = 0.f;
// Warmup
std::cout << " Warming up for " << options.warmup_iterations << " iterations." << std::endl;
for (int warmup_iter = 0; warmup_iter < options.warmup_iterations; ++warmup_iter) {
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
}
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx]));
}
CUDA_CHECK(cudaSetDevice(primary_device_idx));
// Benchmark
std::cout << " Profiling for " << options.iterations << " iterations." << std::endl;
using AtomicBoolean = cuda::atomic<bool>;
AtomicBoolean* atomic_flag_ptr;
CUDA_CHECK(cudaHostAlloc(&atomic_flag_ptr, sizeof(AtomicBoolean), cudaHostAllocPortable));
atomic_flag_ptr->store(false);
cutlass::DistGpuTimer<TP_> timer;
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
cutlass::delay_kernel<<<1, 1, 0, stream_arr[device_idx]>>>(atomic_flag_ptr);
CUDA_CHECK(cudaGetLastError());
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
timer.start(device_idx, stream_arr[device_idx]);
}
atomic_flag_ptr->store(true);
for (int profile_iter = 0; profile_iter < options.iterations; ++profile_iter) {
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
}
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
timer.stop(device_idx, stream_arr[device_idx]);
}
CUDA_CHECK(cudaSetDevice(primary_device_idx));
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
elapsed_ms = max(elapsed_ms, timer.elapsed_millis(device_idx));
}
// Compute average runtime and TFLOPs.
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
result.tflops = options.tflops(avg_runtime_s);
auto [local_M, local_N, local_K, local_L] = DistSchedule::get_local_gemm_shape(
cute::make_tuple(options.m, options.n, options.k, options.l));
std::cout << std::endl;
std::cout << " TP: " << TP::value << std::endl;
std::cout << " Problem Size: " <<
options.m << " x " <<
options.n << " x " <<
options.k << " x " <<
options.l << std::endl;
std::cout << " Local GEMM Problem Size: " <<
local_M << " x " <<
local_N << " x " <<
local_K << " x " <<
local_L<< std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS: " << result.tflops << std::endl;
}
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example,
// and the device must have compute capability 100 (Blackwell SM100).
// Some necessary CUDA graph APIs were only introduced in CUDA 12.4.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
int num_devices;
CUDA_CHECK(cudaGetDeviceCount(&num_devices));
if (num_devices < TP_) {
std::cerr << "Distributed GEMM is compiled with TP = " << TP::value << ", but " <<
"found only " << num_devices << " devices." <<
std::endl;
// Returning zero so this test passes on machines with fewer devices. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
<< "(compute capability 100), "
<< "got compute capability " << props.major * 10 + props.minor << "."
<< std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
run(options);
#endif
return 0;
}

View File

@@ -0,0 +1,32 @@
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
82_blackwell_distributed_gemm
82_blackwell_distributed_gemm.cu
)

View File

@@ -0,0 +1,37 @@
# Blackwell Distributed GEMM
This example implements Tensor Parallel GEMMs for the Blackwell architecture with the experimental
[Distributed GEMM](../../include/cutlass/experimental/distributed) API in CUTLASS.
This example requires Blackwell GPUs with an any-to-any NVLink network.
Please refer to [REQUIREMENTS.md](REQUIREMENTS.md) for more information.
By default, the example assumes 8 GPUs (TP=8) and runs an All Gather + GEMM operation that rotates
operand A across devices (see the sketch below). To run with a different number of GPUs or a different
schedule, refer to [82_blackwell_distributed_gemm.cu](82_blackwell_distributed_gemm.cu).
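Conceptually, each of the TP devices starts with one shard of operand A; at every stage, each device runs a local GEMM on the shard it currently holds while the shards rotate one hop around the NVLink ring, so after TP stages every device has accumulated against all of A. The following CPU-only sketch illustrates the schedule only; names such as `local_gemm` are illustrative and are not the example's API:

```cpp
#include <array>
#include <vector>

constexpr int TP = 8;                       // tensor-parallel degree (8 GPUs)

struct Shard { std::vector<float> data; };  // stand-in for one shard of A

// Stand-in for the device-local GEMM each GPU runs on its resident shard.
void local_gemm(int /*device*/, const Shard& /*a_shard*/, int /*stage*/) {
  // ... launch the local GEMM for (device, stage) on this shard ...
}

// One full All Gather + GEMM: TP stages of compute plus ring rotation of A.
void all_gather_gemm(std::array<Shard, TP>& a_shards) {
  for (int stage = 0; stage < TP; ++stage) {
    for (int d = 0; d < TP; ++d) {
      int shard_idx = (d + stage) % TP;     // shard held by device d now
      local_gemm(d, a_shards[shard_idx], stage);
    }
    // In the real example, passing each shard to the next device is a P2P
    // copy that overlaps with compute (helped by PDL); here the rotation is
    // implicit in the shard_idx computation above.
  }
}

int main() {
  std::array<Shard, TP> a_shards{};         // dummy shards
  all_gather_gemm(a_shards);
  return 0;
}
```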
## Getting started
Command line arguments are mostly similar to other examples:
```
--m=<int> Sets the M extent of the GEMM
--n=<int> Sets the N extent of the GEMM
--k=<int> Sets the K extent of the GEMM
--l=<int> Sets the L extent (batch) of the GEMM (default: 1)
--alpha=<f32> Epilogue scalar alpha (default: 1.0)
--beta=<f32> Epilogue scalar beta (default: 0.0)
--iterations=<int> Number of profiling iterations to perform (default: 100)
--warmup-iterations=<int> Number of warmup iterations prior to profiling (default: 10)
--eps=<f32> Threshold for error compared to reference GEMM (default: 0.0)
```
Sample run command:
```bash
./82_blackwell_distributed_gemm --m=16384 --n=106496 --k=16384 --warmup-iterations=10 --iterations=100
```
This example follows the [Hopper example](../65_distributed_gemm/) very closely, and only differs in the base GEMM kernel. For
more information, refer to [that example's README](../65_distributed_gemm/README.md).

View File

@@ -0,0 +1,86 @@
# Blackwell Distributed GEMM
## Requirements
### Build
Make sure to set up CUTLASS with
support for [Programmatic Dependent Launch (PDL)](../../media/docs/dependent_kernel_launch.md),
that is, with the `CUTLASS_ENABLE_GDC_FOR_SM100` flag:
```bash
cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
```
### Minimum software
Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
CUDA graph APIs.
### Hardware / driver settings
This example requires Blackwell GPUs with an any-to-any NVLink network.
If you're not sure, first run the following command and make sure your GPU
compute capability is 10.0:
```bash
nvidia-smi --query-gpu=name,compute_cap --format=csv
```
Sample output:
```
name, compute_cap
NVIDIA B200, 10.0
NVIDIA B200, 10.0
NVIDIA B200, 10.0
NVIDIA B200, 10.0
NVIDIA B200, 10.0
NVIDIA B200, 10.0
NVIDIA B200, 10.0
NVIDIA B200, 10.0
```
Then make sure there is an NVLink network by checking the GPU network topology,
and confirming there are `NV*` links between every pair of GPUs:
```bash
nvidia-smi topo -m
```
Sample output:
```
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X
```
Finally, check that the driver enables peer-to-peer access, which should usually be the case,
but it's good to verify anyway:
```bash
nvidia-smi topo -p2p r
```
Sample output:
```
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X OK OK OK OK OK OK OK
GPU1 OK X OK OK OK OK OK OK
GPU2 OK OK X OK OK OK OK OK
GPU3 OK OK OK X OK OK OK OK
GPU4 OK OK OK OK X OK OK OK
GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
```
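To check peer access programmatically instead, here is a minimal sketch using the CUDA runtime (a generic check, not part of the example):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Queries whether each ordered pair of GPUs reports P2P access capability.
int main() {
  int n = 0;
  cudaGetDeviceCount(&n);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      if (i == j) continue;
      int ok = 0;
      cudaDeviceCanAccessPeer(&ok, i, j);
      std::printf("GPU%d -> GPU%d : %s\n", i, j, ok ? "OK" : "NO");
    }
  }
  return 0;
}
```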

View File

@@ -0,0 +1,607 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP16 sparse GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
The Blackwell SM100 CUTLASS kernel uses the following Blackwell SM100 features:
1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
Note that Hopper WGMMA Tensor Core MMA instructions are not supported on Blackwell (see https://docs.nvidia.com/cuda/parallel-thread-execution).
2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Usage:
$ ./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm --m=8192 --n=8192 --k=8192
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = half_t; // Element type for A matrix operand
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 2 * 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes); 2x because A is compressed along K
// E matrix config
using ElementE = cute::uint8_t;
// B matrix configuration
using ElementB = half_t; // Element type for B matrix operand
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = float; // Element type for D matrix operand
using ElementC = float; // Element type for C matrix operand
using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand
using LayoutTagD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassSparseTensorOp; // Operator class tag
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_256,_128,_64>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_1,_1>;
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutTagC, AlignmentC,
ElementD, LayoutTagD, AlignmentD,
cutlass::epilogue::TmaWarpSpecialized2Sm
>::CollectiveOp;
// Build the mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutTagA, AlignmentA,
ElementB, LayoutTagB, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100
>::CollectiveOp;
using ProblemShape = Shape<int,int,int,int>;
// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutTagA,
ElementB,
LayoutTagB,
ElementC,
LayoutTagC,
ElementAccumulator,
ElementAccumulator>;
// Layouts
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
//
// Compressor
//
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig,
ArchTag>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
//
// Data members
//
/// Initialization
LayoutA layout_A;
LayoutE layout_E;
StrideA stride_A;
StrideA stride_A_compressed;
StrideE stride_E;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
ProblemShape problem_shape;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A_compressed;
cutlass::DeviceAllocation<typename Gemm::CollectiveMainloop::ElementE> block_E;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k, l;
Options():
help(false),
m(8192), n(8192), k(8192), l(1),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "83_blackwell_sparse_gemm\n\n"
<< " Blackwell FP16 Sparse GEMM example.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "83_blackwell_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add, counted over all L batches
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = Element(2);
scope_min = Element(0);
}
else if constexpr (bits_input <= 8) {
scope_max = Element(2);
scope_min = Element(-2);
}
else {
scope_max = Element(8);
scope_min = Element(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Make A structured sparse by replacing elements with 0 and compress it
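/// Blackwell structured sparsity is 2:4 along K: at most two of every four
/// consecutive K-elements of A are nonzero. The compressor produces A halved
/// along K (block_A_compressed) plus metadata (block_E) recording which
/// positions were kept; the sparse mainloop consumes both tensors.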
bool sparsify_and_compress()
{
auto [M, N, K, L] = problem_shape;
CompressorUtility compressor_utility(problem_shape, stride_A);
// TensorE
// In unit of ElementE (uint8_t), after alignment requirement
// M-dim: TensorEAtom_M alignment
// K-dim: TensorEAtom_K alignment
int KAlignedE = compressor_utility.get_metadata_k_physical();
int MAlignedE = compressor_utility.get_metadata_m_physical();
// TensorA Compressed
// In unit of ElementARaw, after alignment requirement
// M-dim: TMA alignment
// K-dim: TMA alignment
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
block_A_compressed.reset(M * KAlignedAC * L);
block_E.reset(MAlignedE * KAlignedE * L);
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L));
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L));
// Random 50% zero-fill of A (structured 2:4 pattern) is performed on the host
std::vector<ElementA> block_A_host(block_A.size());
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast<int>(seed + 2024));
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments {
problem_shape,
{ block_A.get(),
stride_A,
block_A_compressed.get(),
block_E.get() },
{hw_info} };
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status {cutlass::Status::kSuccess };
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
if (status != cutlass::Status::kSuccess) {
return false;
}
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
block_A.reset(options.m * options.k);
block_B.reset(options.k * options.n);
block_C.reset(options.m * options.n);
block_D.reset(options.m * options.n);
block_ref_D.reset(options.m * options.n);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
// Compress row A and get A_compress and E
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
if (not sparsify_and_compress()) {
return false;
};
// Build the compressed/metadata layouts
layout_A = SparseConfig::fill_layoutA(problem_shape);
layout_E = SparseConfig::fill_layoutE(problem_shape);
return true;
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_shape,
{ block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E },
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
return arguments;
}
bool verify(const Options &options) {
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{options.m, options.n, options.k},
ElementAccumulator(options.alpha),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ref_C,
ref_D);
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
auto init_pass = initialize(options);
if (not init_pass) {
std::cout << "Initialization failure" << std::endl;
exit(EXIT_FAILURE);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (not result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (not (props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,38 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
83_blackwell_sparse_gemm
83_blackwell_sparse_gemm.cu
)
endif()

View File

@@ -0,0 +1,693 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 Sparse GEMM on the NVIDIA Blackwell SM100 architecture.
The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Similar to 83_blackwell_sparse_gemm, this kernel leverages:
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Usage:
$ ./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e2m1_t;
using ElementAPair = cutlass::nv_float4_t<ElementA>; // Element type for A matrix operand
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes); 2x because A is compressed along K
// E matrix config
using ElementE = cute::uint8_t;
using LayoutTagE = LayoutTagA;
// B matrix configuration
using ElementB = cutlass::float_e2m1_t;
using ElementBPair = cutlass::nv_float4_t<ElementB>; // Element type for B matrix operand
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// SF
using ElementSF = typename ElementAPair::ScaleFactorType;
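// For NVFP4, the scale factor type is an FP8 (UE4M3) value applied per
// 16-element block along K; the interleaved layout is described by
// Sm1xxBlkScaledConfig below.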
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape = Shape<_256,_128,_256>;
// Shape of the threadblocks in a cluster
using ClusterShape = Shape<_2,_1,_1>;
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutTagC, AlignmentC,
ElementD, LayoutTagD, AlignmentD,
cutlass::epilogue::TmaWarpSpecialized2SmNvf4
>::CollectiveOp;
// Build the mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementAPair, LayoutTagA, AlignmentA,
ElementBPair, LayoutTagB, AlignmentB,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100
>::CollectiveOp;
using ProblemShape = Shape<int,int,int,int>;
// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
//
// Blockscale
//
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;
using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF;
using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom;
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
//
// Compressor
//
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig,
ArchTag>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideA stride_A_compressed;
StrideE stride_E;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
LayoutA layout_A;
LayoutE layout_E;
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
typename LayoutTagA::Stride stride_factor_A;
typename LayoutTagB::Stride stride_factor_B;
typename LayoutTagE::Stride stride_factor_E;
typename LayoutTagC::Stride stride_factor_C;
typename LayoutTagD::Stride stride_factor_D;
uint64_t seed;
ProblemShape problem_shape;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating through the matrix elements
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A;
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A_compressed;
cutlass::HostTensor<ElementE, LayoutTagE> tensor_E;
cutlass::HostTensor<ElementB, LayoutTagB> tensor_B;
cutlass::HostTensor<ElementC, LayoutTagC> tensor_C;
cutlass::HostTensor<ElementSF, LayoutTagA> tensor_SFA;
cutlass::HostTensor<ElementSF, LayoutTagB> tensor_SFB;
cutlass::HostTensor<ElementD, LayoutTagD> tensor_D;
cutlass::HostTensor<ElementD, LayoutTagD> reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k, l;
Options():
help(false),
m(1024), n(1024), k(1024), l(1),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "84a_blackwell_nvfp4_bf16_sparse_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add, counted over all L batches
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
void initialize_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>){
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
// * Get A B C D size
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = SparseConfig::fill_layoutA(problem_shape);
layout_E = SparseConfig::fill_layoutE(problem_shape);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape);
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape);
// * Get ACompress & E size
CompressorUtility compressor_utility(problem_shape, stride_A);
// TensorE
// In unit of ElementE (uint8_t), after alignment requirement
// M-dim: TensorEAtom_M alignment
// K-dim: TensorEAtom_K alignment
int KAlignedE = compressor_utility.get_metadata_k_physical();
int MAlignedE = compressor_utility.get_metadata_m_physical();
// TensorA Compressed
// In unit of ElementARaw, after alignment requirement
// M-dim: TMA alignment
// K-dim: TMA alignment
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l));
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l));
// * Get SFA & SFB size
auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{})));
auto m_blks = cutlass::ceil_div(options.m, Blk_MN{});
auto n_blks = cutlass::ceil_div(options.n, Blk_MN{});
// * Allocate Tensor
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE);
auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto d_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_comp_coord, stride_factor_A));
tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagE>::layout_factory(e_coord, stride_factor_E));
tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
tensor_D.resize(d_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D));
reference_D.resize(d_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D), false);
tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(sfa_coord, stride_factor_A));
tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(sfb_coord, stride_factor_B));
// * Random init
initialize_tensor(tensor_A.host_view(), seed + 2021);
initialize_tensor(tensor_B.host_view(), seed + 2022);
initialize_tensor(tensor_C.host_view(), seed + 2023);
initialize_tensor(tensor_SFA.host_view(), seed + 2024);
initialize_tensor(tensor_SFB.host_view(), seed + 2025);
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
// * Randomly fill 50% of A with zeros (structured 2:4 pattern)
compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast<int>(seed + 2023));
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
// * Compress
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
problem_shape,
{tensor_A.device_data(),
stride_A,
tensor_A_compressed.device_data(),
tensor_E.device_data()},
{hw_info}
};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status {cutlass::Status::kSuccess };
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
if (status != cutlass::Status::kSuccess) {
return false;
}
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
tensor_E.sync_host();
tensor_A_compressed.sync_host();
return true;
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{
reinterpret_cast<ArrayElementA *>(tensor_A_compressed.device_data()), layout_A,
reinterpret_cast<ArrayElementB *>(tensor_B.device_data()), stride_B,
tensor_E.device_data(), layout_E,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB
},
{
{options.alpha, options.beta},
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A);
auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA);
auto B = make_tensor(make_iterator(tensor_B.host_data()),
make_layout(make_shape(options.n, options.k, options.l), stride_B));
auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettMainloopParams<
ElementAccumulator,
decltype(A),
decltype(B),
decltype(SFA),
decltype(SFB)> mainloop_params{A, SFA, B, SFB};
auto C = make_tensor(make_iterator(tensor_C.host_data()),
make_layout(make_shape(options.m, options.n, options.l), stride_C));
auto D = make_tensor(make_iterator(reference_D.host_data()),
make_layout(make_shape(options.m, options.n, options.l), stride_D));
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(C), // TensorC
decltype(D) // TensorD
> epilogue_params{
options.alpha,
options.beta,
C,
D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
auto init_pass = initialize(options);
if (not init_pass) {
std::cout << "Initialization failure" << std::endl;
exit(EXIT_FAILURE);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (not result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (not (props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,695 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled mixed-input (MXFP8 x MXFP4) Sparse GEMM on the NVIDIA Blackwell SM100 architecture.
The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Similar to 83_blackwell_sparse_gemm, this kernel leverages:
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
2. The extended warp-specialized kernel design introduced in Hopper, enabled by the use of TMEM,
which allows us to decouple the execution of MMA and epilogue into separate warps.
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Usage:
$ ./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t;
using ElementAPair = cutlass::mx_float8_t<ElementA>; // Element type for A matrix operand
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 64;                                               // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes); 2x because A is compressed along K
// E matrix config
using ElementE = cute::uint8_t;
using LayoutTagE = LayoutTagA;
// B matrix configuration
using ElementB = cutlass::float_e2m1_t;
using ElementBPair = cutlass::mx_float4_t<ElementB>; // Element type for B matrix operand
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// SF
using ElementSF = typename ElementAPair::ScaleFactorType;
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits<ElementD>::value;    // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA; it can span 2 SMs if the cluster shape's M extent % 2 == 0
using MmaTileShape_MNK = Shape<_256,_128,_256>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_1,_1>;
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutTagC, AlignmentC,
ElementD, LayoutTagD, AlignmentD,
cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4
>::CollectiveOp;
// Build the mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementAPair, LayoutTagA, AlignmentA,
ElementBPair, LayoutTagB, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100
>::CollectiveOp;
using ProblemShape = Shape<int,int,int,int>;
// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
//
// Blockscale
//
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;
using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF;
using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom;
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
//
// Compressor
//
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig,
ArchTag>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideA stride_A_compressed;
StrideE stride_E;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
LayoutA layout_A;
LayoutE layout_E;
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
typename LayoutTagA::Stride stride_factor_A;
typename LayoutTagB::Stride stride_factor_B;
typename LayoutTagE::Stride stride_factor_E;
typename LayoutTagC::Stride stride_factor_C;
typename LayoutTagD::Stride stride_factor_D;
uint64_t seed;
ProblemShape problem_shape;
// The HostTensors are only used for allocating memory on host and device, and for transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating through the matrix elements
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A;
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A_compressed;
cutlass::HostTensor<ElementE, LayoutTagE> tensor_E;
cutlass::HostTensor<ElementB, LayoutTagB> tensor_B;
cutlass::HostTensor<ElementC, LayoutTagC> tensor_C;
cutlass::HostTensor<ElementSF, LayoutTagA> tensor_SFA;
cutlass::HostTensor<ElementSF, LayoutTagB> tensor_SFB;
cutlass::HostTensor<ElementD, LayoutTagD> tensor_D;
cutlass::HostTensor<ElementD, LayoutTagD> reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k, l;
Options():
help(false),
m(1024), n(1024), k(1024), l(1),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "84b_blackwell_mixed_mxfp8_bf16_sparse_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
    // Two flops per multiply-add; include the batch dimension L
    uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
void initialize_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>){
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
// * Get A B C D size
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = SparseConfig::fill_layoutA(problem_shape);
layout_E = SparseConfig::fill_layoutE(problem_shape);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape);
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape);
// * Get ACompress & E size
CompressorUtility compressor_utility(problem_shape, stride_A);
// TensorE
// In unit of ElementE (uint8_t), after alignment requirement
// M-dim: TensorEAtom_M alignment
// K-dim: TensorEAtom_K alignment
int KAlignedE = compressor_utility.get_metadata_k_physical();
int MAlignedE = compressor_utility.get_metadata_m_physical();
// TensorA Compressed
// In unit of ElementARaw, after alignment requirement
// M-dim: TMA alignment
// K-dim: TMA alignment
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l));
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l));
// * Get SFA & SFB size
auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{})));
auto m_blks = cutlass::ceil_div(options.m, Blk_MN{});
auto n_blks = cutlass::ceil_div(options.n, Blk_MN{});
// * Allocate Tensor
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE);
auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto d_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_comp_coord, stride_factor_A));
tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagE>::layout_factory(e_coord, stride_factor_E));
tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
  tensor_D.resize(d_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D));
  reference_D.resize(d_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D), false);
tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(sfa_coord, stride_factor_A));
tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(sfb_coord, stride_factor_B));
// * Random init
initialize_tensor(tensor_A.host_view(), seed + 2021);
initialize_tensor(tensor_B.host_view(), seed + 2022);
initialize_tensor(tensor_C.host_view(), seed + 2023);
initialize_tensor(tensor_SFA.host_view(), seed + 2024);
initialize_tensor(tensor_SFB.host_view(), seed + 2025);
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
// * Random fill 50% A with zero
compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast<int>(seed + 2023));
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
// * Compress
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
problem_shape,
{tensor_A.device_data(),
stride_A,
tensor_A_compressed.device_data(),
tensor_E.device_data()},
{hw_info}
};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status {cutlass::Status::kSuccess };
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
if (status != cutlass::Status::kSuccess) {
return false;
}
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
tensor_E.sync_host();
tensor_A_compressed.sync_host();
return true;
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
    {options.m, options.n, options.k, options.l},
{
reinterpret_cast<ArrayElementA *>(tensor_A_compressed.device_data()), layout_A,
reinterpret_cast<ArrayElementB *>(tensor_B.device_data()), stride_B,
tensor_E.device_data(), layout_E,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB
},
{
{options.alpha, options.beta},
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A);
auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA);
auto B = make_tensor(make_iterator(tensor_B.host_data()),
make_layout(make_shape(options.n, options.k, options.l), stride_B));
auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettMainloopParams<
ElementAccumulator,
decltype(A),
decltype(B),
decltype(SFA),
decltype(SFB)> mainloop_params{A, SFA, B, SFB};
auto C = make_tensor(make_iterator(tensor_C.host_data()),
make_layout(make_shape(options.m, options.n, options.l), stride_C));
auto D = make_tensor(make_iterator(reference_D.host_data()),
make_layout(make_shape(options.m, options.n, options.l), stride_D));
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementScalingFactor
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(C), // TensorC
decltype(D) // TensorD
> epilogue_params{};
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
auto init_pass = initialize(options);
if (not init_pass) {
std::cout << "Initialization failure" << std::endl;
exit(EXIT_FAILURE);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
  CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (not result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
  // CUTLASS must be compiled with CUDA 12.8 or newer Toolkit to run this example,
  // and the device must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
    // Return zero so this test passes on older Toolkits; its action is then a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (not (props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
84a_blackwell_nvfp4_bf16_sparse_gemm
84a_blackwell_nvfp4_bf16_sparse_gemm.cu
)
cutlass_example_add_executable(
84b_blackwell_mixed_mxfp8_bf16_sparse_gemm
84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu
)
endif()

View File

@@ -158,7 +158,11 @@ foreach(EXAMPLE
77_blackwell_fmha
78_blackwell_emulated_bf16x9_gemm
79_blackwell_geforce_gemm
80_blackwell_geforce_sparse_gemm
81_blackwell_gemm_blockwise
82_blackwell_distributed_gemm
83_blackwell_sparse_gemm
84_blackwell_narrow_precision_sparse_gemm
)
add_subdirectory(${EXAMPLE})

View File

@@ -286,6 +286,18 @@
Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores
* [80_blackwell_geforce_sparse_gemm](80_blackwell_geforce_sparse_gemm/)
Blackwell SM120 sparse MMA kernel targeting GeForce RTX 50 series CUDA Cores
* [83_blackwell_sparse_gemm](83_blackwell_sparse_gemm)
Blackwell SM100 Sparse GEMM kernel
* [84_blackwell_narrow_precision_sparse_gemm](84_blackwell_narrow_precision_sparse_gemm)
Blackwell SM100 Block Scaled Sparse GEMM kernel
# CuTe - Programming Examples
Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).

View File

@@ -44,6 +44,11 @@
#include <cuda/atomic>
#include <cuda/std/atomic>
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/cuda_host_adapter.hpp"
namespace cutlass {
@@ -115,4 +120,46 @@ struct DistGpuTimer {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic device-to-device data movement kernel based for CuTe tensors.
///
/// NOTE: this kernel assigns one element copy to every thread, and is by no means
/// an efficient way of copying tensors. It should only be used for convenience in
/// reference checks.
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TensorSource, typename TensorDestination>
void device_copy(TensorSource tensor_source,
TensorDestination tensor_destination,
cudaStream_t stream);
template <typename TensorSource, typename TensorDestination>
__global__ void device_copy_kernel(TensorSource const tensor_source,
TensorDestination tensor_destination) {
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
using ElementSrc = typename TensorSource::value_type;
using ElementDst = typename TensorDestination::value_type;
NumericConverter<ElementDst, ElementSrc> converter;
if (linear_idx < size(tensor_source)) {
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
}
}
template <typename TensorSource, typename TensorDestination>
void device_copy(TensorSource tensor_source,
TensorDestination tensor_destination,
cudaStream_t stream) {
assert(tensor_source.size() == tensor_destination.size());
auto numel = tensor_source.size();
static constexpr int NumThreads = 128;
auto grid_size = cute::ceil_div(numel, NumThreads);
dim3 grid(grid_size);
dim3 block(NumThreads);
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
}
} //namespace cutlass
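// Hypothetical usage sketch (not part of this commit; d_src, d_dst, and N are
// illustrative placeholders, not names from this file):
//
//   float* d_src;            // N floats, already resident on the device
//   cutlass::half_t* d_dst;  // N halves on the device
//   auto src = cute::make_tensor(cute::make_gmem_ptr(d_src),
//                                cute::make_layout(cute::make_shape(N)));
//   auto dst = cute::make_tensor(cute::make_gmem_ptr(d_dst),
//                                cute::make_layout(cute::make_shape(N)));
//   cutlass::device_copy(src, dst, /*stream=*/0);
//
// The launch uses ceil_div(N, 128) blocks of 128 threads, with each thread
// copying (and converting, via NumericConverter) exactly one element.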

View File

@@ -340,7 +340,7 @@ public:
base_args.epilogue.thread,
reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
tensor_c_iter.stride(),
reinterpret_cast<const ElementD*>(tensor_d_iter.data()),
reinterpret_cast<ElementD*>(tensor_d_iter.data()),
tensor_d_iter.stride()
};

View File

@@ -82,7 +82,7 @@ struct DistributedGemmKernelWrapper<
using BaseArguments = typename BaseKernel::Arguments;
using BaseParams = typename BaseKernel::Params;
static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now.");
//static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now.");
static_assert(not cute::is_same_v<typename BaseKernel::ElementC, void>, "DistributedGEMM epilogues must have a source.");
using ElementFlag = uint32_t;

View File

@@ -189,6 +189,100 @@ template<
bool Is2sm = false
>
constexpr bool sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(){
// * 1SM Dense
// * A_K(t) : TileShape_K % 128 == 0
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 128 == 0
// * B_K(n) : TileSize_K % 128 == 0
//
// * 2SM Dense
// * A_K(t) : TileShape_K % 128 == 0
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 256 == 0
  //     each SM loads half the data along tile_n (split vertically); each SM's half needs to be 128 elts aligned,
  //     so the full tile_n needs to be 256 elts aligned
// * B_K(n) : TileShape_K % 128 == 0
//
// * 1SM Sparse
// * A_K(t) : TileShape_K % 256 == 0
// num of physical elems needs to be 128 elts aligned
// num of logical elems needs to be 256 elts aligned
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 128 == 0
// * B_K(n) : TileSize_K % 128 == 0
//
// * 2SM Sparse
// * A_K(t) : TileShape_K % 256 == 0
// num of physical elems needs to be 128 elts aligned
// num of logical elems needs to be 256 elts aligned
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 256 == 0
  //     each SM loads half the data along tile_n (split vertically); each SM's half needs to be 128 elts aligned,
  //     so the full tile_n needs to be 256 elts aligned
// * B_K(n) : TileShape_K % 128 == 0
//
// * Valid TileShape_MNK Dense
// * Notation:
// mma_instruction_tile_shape-cta_tile_shape
// * s128x128x64
// s128x128x32_128x128x128_nn YES
// s128x128x32_128x128x128_nt YES
// s128x128x32_128x128x128_tn YES
// s128x128x32_128x128x128_tt YES
// * s128x256x64
// s128x256x32_128x256x128_nn YES
// s128x256x32_128x256x128_nt YES
// s128x256x32_128x256x128_tn YES
// s128x256x32_128x256x128_tt YES
// * s256x128x64
// s256x128x32_256x128x128_nn YES
// s256x128x32_256x128x128_nt NO (2SM B_N TileSize_N % 256 != 0)
// s256x128x32_256x128x128_tn YES
// s256x128x32_256x128x128_tt NO (2SM B_N TileSize_N % 256 != 0)
// * s256x256x64
// s256x256x32_256x256x128_nn YES
// s256x256x32_256x256x128_nt YES
// s256x256x32_256x256x128_tn YES
// s256x256x32_256x256x128_tt YES
//
// * Valid TileShape_MNK Sparse
// * s128x128x64
// s128x128x64_128x128x128_nn YES
// s128x128x64_128x128x128_nt YES
// s128x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0)
// s128x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0)
// s128x128x64_128x128x256_nn YES
// s128x128x64_128x128x256_nt YES
// s128x128x64_128x128x256_tn YES
// s128x128x64_128x128x256_tt YES
// * s128x256x64
// s128x256x64_128x256x128_nn YES
// s128x256x64_128x256x128_nt YES
// s128x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0)
// s128x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0)
// s128x256x64_128x256x256_nn YES
// s128x256x64_128x256x256_nt YES
// s128x256x64_128x256x256_tn YES
// s128x256x64_128x256x256_tt YES
// * s256x128x64
// s256x128x64_128x128x128_nn YES
// s256x128x64_128x128x128_nt NO (2SM B_N TileSize_N % 256 != 0)
// s256x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0)
// s256x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0)
// s256x128x64_128x128x256_nn YES
// s256x128x64_128x128x256_nt NO (2SM B_N TileSize_N % 256 != 0)
// s256x128x64_128x128x256_tn YES
// s256x128x64_128x128x256_tt NO (2SM B_N TileSize_N % 256 != 0)
// * s256x256x64
// s256x256x64_128x256x128_nn YES
// s256x256x64_128x256x128_nt YES
// s256x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0)
// s256x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0)
// s256x256x64_128x256x256_nn YES
// s256x256x64_128x256x256_nt YES
// s256x256x64_128x256x256_tn YES
// s256x256x64_128x256x256_tt YES
[[maybe_unused]] constexpr int TileShape_M = Is2sm ? size<0>(TileShape_MNK{}) / 2 : size<0>(TileShape_MNK{});
[[maybe_unused]] constexpr int TileShape_N = size<1>(TileShape_MNK{});
[[maybe_unused]] constexpr int TileShape_K = size<2>(TileShape_MNK{});
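  // Illustrative sketch only (not part of this commit): with hypothetical flags
  // IsSparse (sparse mainloop) and IsAMajorK / IsBMajorN (operand majorness),
  // the requirements tabulated above could be encoded along these lines:
  //
  //   constexpr bool a_ok = IsAMajorK ? (TileShape_K % (IsSparse ? 256 : 128) == 0)
  //                                   : (TileShape_M % 128 == 0);
  //   constexpr bool b_ok = IsBMajorN ? (TileShape_N % (Is2sm ? 256 : 128) == 0)
  //                                   : (TileShape_K % 128 == 0);
  //   return a_ok && b_ok;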

View File

@@ -432,6 +432,10 @@ public:
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);
if constexpr (SwapAB) {
init_M = get<1>(problem_shape_MNK);
init_N = get<0>(problem_shape_MNK);
}
if constexpr (not SwapAB) {
dA = args.dA;
@@ -491,7 +495,7 @@ public:
: args_setup(args.ptr_A, args.ptr_B);
}
else if constexpr (ModeHasScales) {
auto scale_k = 1;
auto scale_k = ceil_div(init_K, args.chunk_size);
ElementScale const* ptr_S = reinterpret_cast<ElementScale const*>(args.ptr_S);
StrideScale dS{};
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS));
@@ -595,7 +599,7 @@ public:
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.chunk_size - 1) / args.chunk_size;
const int scale_k = ceil_div(K, args.chunk_size);
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0));
@@ -659,14 +663,15 @@ public:
return cute::make_tuple(gA_mkl, gB_nkl);
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
auto scale_k = mainloop_params.scale_k;
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l)
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(scale_mn,scale_k,L));
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l)
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(scale_mn,scale_k,L));
Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl);
}
@@ -1217,8 +1222,8 @@ public:
Params const& mainloop_params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
const uint32_t M = (SwapAB? get<1>(problem_shape_mnkl) : get<0>(problem_shape_mnkl));
const uint32_t N = (SwapAB? get<0>(problem_shape_mnkl) : get<1>(problem_shape_mnkl));
const uint32_t K = get<2>(problem_shape_mnkl);
// Replace all dims for consistency
@@ -1245,14 +1250,14 @@ public:
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
NonVoidElementScale const* ptr_S = nullptr;
auto scale_k = 1;
auto scale_k = ceil_div(K, mainloop_params.chunk_size);
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale,
prob_shape_scale, prob_stride_scale);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
ElementZero const* ptr_Z = nullptr;
auto scale_k = 1;
auto scale_k = ceil_div(K, mainloop_params.chunk_size);
Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero,
prob_shape_zero, prob_stride_zero);

View File

@@ -426,7 +426,7 @@ public:
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
}
else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;
auto scale_k = ceil_div(K, args.group_size);
ElementScale const* ptr_S = args.ptr_S;
StrideScale dS = args.dS;
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS));
@@ -483,7 +483,7 @@ public:
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.group_size - 1) / args.group_size;
const int scale_k = ceil_div(K, args.group_size);
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
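      // Worked example (illustrative): for K = 4096 and group_size = 128,
      // scale_k = ceil_div(4096, 128) = 32, i.e. one scale (and zero) entry per
      // 128-element chunk of K for each of the scale_mn rows, so the scale
      // tensor has shape (scale_mn, 32, L).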

View File

@@ -622,6 +622,11 @@ public:
impl_.producer_acquire(state, barrier_token);
}
CUTLASS_DEVICE
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
impl_.producer_expect_transaction(state, transaction_bytes);
}
// NOP for TMA based mainloop
CUTLASS_DEVICE
void producer_commit(PipelineState state, uint32_t bytes) {

View File

@@ -452,6 +452,11 @@ public:
return producer_get_barrier(state.index());
}
CUTLASS_DEVICE
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
producer_expect_transaction(state.index(), transaction_bytes);
}
////////////////////
// Consumer APIs
////////////////////
@@ -519,6 +524,14 @@ private:
#endif
}
CUTLASS_DEVICE
void producer_expect_transaction(uint32_t stage, uint32_t transaction_bytes) {
detail::pipeline_check_is_producer(params_.role);
if (params_.is_leader) {
full_barrier_ptr_[stage].expect_transaction(transaction_bytes);
}
}
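  // Illustrative calling pattern (not from this commit): a producer that wants
  // consumers to additionally wait on extra payload bytes for an
  // already-acquired stage might do:
  //
  //   pipeline.producer_acquire(smem_pipe_write);
  //   pipeline.producer_expect_transaction(smem_pipe_write, extra_transaction_bytes);
  //   // ... issue the additional TMA load(s) covering extra_transaction_bytes ...
  //
  // Only the leader (params_.is_leader) bumps the barrier's expected
  // transaction count.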
// NOP for TMA based mainloop
CUTLASS_DEVICE
void producer_commit(uint32_t stage, uint32_t bytes) {

View File

@@ -9,15 +9,15 @@ efficient SM100 GEMM kernels targeting these new mma instructions.
Blackwell SM100 has 7 new `tcgen05.mma` instructions. These instructions are 2x to 4x faster than Hopper Architecture's WGMMA instructions.
| Ptx Instruction | Throughput | Notes |
|----------------------------------------------------------------------------------|----------------------------|-------|
|tcgen05.mma.cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts |
|tcgen05.mma.cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts |
|tcgen05.mma.cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts |
|tcgen05.mma.cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts |
|tcgen05.mma.cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts |
|tcgen05.mma.cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts |
|tcgen05.mma.cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts |
| Ptx Instruction | Throughput | Notes |
|---------------------------------------------------------------------------------------|----------------------------|-------|
|tcgen05.mma(.sp).cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts |
|tcgen05.mma(.sp).cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts |
|tcgen05.mma(.sp).cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts |
|tcgen05.mma(.sp).cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts |
|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts |
|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts |
|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts |
For more detailed information see [`tcgen05.mma` PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensorcore-5th-generation-family-instructions).
@@ -27,7 +27,7 @@ For more detailed information see [`tcgen05.mma` PTX documentation](https://docs
Instructions with `kind` modifiers `mxf8f6f4`, `mxf4`, and `mxf4nvf4` perform matrix multiplication operations with scale
factors of the form $D = C + (A \times SFA) \times (B \times SFB)$. Scale factors are applied along the GEMM-K dimension such that
every 16 or 32 elements of $A$ and $B$ matrices in K dimension have an associated scale factor. For example, an $M\times K$,
every 16 or 32 elements of $A$ and $B$ matrices in the K dimension have an associated scale factor (32 or 64 logical elements for sparse GEMM, since sparse GEMM compresses 2x along the K dimension). For example, an $M\times K$,
$A$ matrix has an associated $M \times \lceil K/32 \rceil$ SFA matrix; and an $N\times K$ $B$, matrix has an associated
$N \times \lceil K/32 \rceil$ SFB matrix. For block scaled GEMMs, an entry of output D matrix is
$D_{ij} = C_{ij} + \sum_{k} (A_{i,k} \times SFA_{i,k/SV}) \times (B_{j,k}\times SFB_{j,k/SV})$ in index notation, where $SV$ is the scale factor vector size (16 or 32).
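As a concrete illustration, a minimal (unoptimized) host reference of this formula might look as follows; the scalar types, row-major indexing, and `SV` value here are assumptions for illustration rather than CUTLASS API:

```c++
#include <vector>

constexpr int SV = 32;  // scale factor vector size (16 for nvf4)

// Reference sketch: D = C + (A x SFA) * (B x SFB), with one scale factor
// per SV-element chunk of K for each row of A and each row of B.
void blockscaled_gemm_ref(int M, int N, int K,
                          std::vector<float> const& A,   // M x K
                          std::vector<float> const& SFA, // M x ceil(K/SV)
                          std::vector<float> const& B,   // N x K
                          std::vector<float> const& SFB, // N x ceil(K/SV)
                          std::vector<float> const& C,   // M x N
                          std::vector<float>&      D) {  // M x N
  int sfk = (K + SV - 1) / SV;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        // Every SV consecutive K elements share one scale factor per operand.
        acc += (A[i * K + k] * SFA[i * sfk + k / SV]) *
               (B[j * K + k] * SFB[j * sfk + k / SV]);
      }
      D[i * N + j] = C[i * N + j] + acc;
    }
  }
}
```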
@@ -57,12 +57,12 @@ See [PTX documentation for narrow precision data types](https://docs.nvidia.com/
Block scaled MMAs use `mx` and `nv` types, each of which pairs one of the float8_t, float6_t, or float4_t data types with a scale factor data type and a predetermined scale factor vector size. `mx` types follow the OCP specification (see [OCP Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)). The following types provided by CUTLASS can be used as inputs to collective builders to generate the block scaled kernels:
**Blackwell Block Scaled Narrow Precision Data Types**
| Mx/Nv Data Type |Scale Factor Type | SF Vector Size | OCP Compliant |
|----------------------------|------------------|----------------|---------------|
| mx_float8_t\<Any F8type\> |float_ue8m0_t |32 | Yes |
| mx_float6_t\<Any F6Type\> |float_ue8m0_t |32 | Yes |
| mx_float4_t |float_ue8m0_t |32 | Yes |
| nv_float4_t |float_ue4m3_t |16 | No |
| Mx/Nv Data Type |Scale Factor Type | SF Vector Size (Dense) | SF Vector Size (Sparse)| OCP Compliant |
|----------------------------|------------------|------------------------|------------------------|---------------|
| mx_float8_t\<Any F8type\> |float_ue8m0_t |32 |64 | Yes |
| mx_float6_t\<Any F6Type\> |float_ue8m0_t |32 |64 | Yes |
| mx_float4_t |float_ue8m0_t |32 |64 | Yes |
| nv_float4_t |float_ue4m3_t |16 |32 | No |
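For instance, example 84b added in this commit carries these pair types through the collective builder; a condensed excerpt is below (it assumes `CollectiveEpilogue` has been built as in that example):

```c++
// Condensed from example 84b in this commit: the mx_* pair types carry both the
// narrow element type and its ue8m0 scale factor type into the builder.
using ElementAPair = cutlass::mx_float8_t<cutlass::float_e4m3_t>;   // MXFP8 A, SF vector size 64 (sparse)
using ElementBPair = cutlass::mx_float4_t<cutlass::float_e2m1_t>;   // MXFP4 B, SF vector size 64 (sparse)
using ElementSF    = typename ElementAPair::ScaleFactorType;        // cutlass::float_ue8m0_t

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledSparseTensorOp,
    ElementAPair, cutlass::layout::RowMajor,    64,   // A alignment: 2x for sparse compression along K
    ElementBPair, cutlass::layout::ColumnMajor, 128,
    float,                                            // accumulator
    cute::Shape<cute::_256, cute::_128, cute::_256>,  // MMA tile (MNK)
    cute::Shape<cute::_2, cute::_1, cute::_1>,        // cluster shape
    cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
    cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100
>::CollectiveOp;
```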
## Layouts, Tensor Alignment Requirements to Target `tcgen05.mma` Instructions
@@ -74,13 +74,18 @@ For legacy types (`tf32`, `f16`, `bf16`, `i8` and `u8`) alignment requirements f
All four layouts (TN, NN, NT, TT) are supported for all legacy data types.
**Table 1: Valid Data Type, Alignment, and Layout Combinations For MMAs with Legacy Types** <a id="legacy_gemm_table" name="legacy_gemm_table"></a>
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|-------------------------------|------------|------------|----------------|-------------|-------------|-------------------------|-----------|
|1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|-------------------------------|----------------|------------|------------|----------------|------------------|-------------|-------------------------|---------- |
|[1](#legacy_rows) | Dense | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
|[2](#legacy_rows) | Dense | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu) |
|[3](#legacy_rows) | Dense | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|[4](#legacy_rows) | Dense | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu) |
|[5](#legacy_rows) | Dense | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu) |
|[6](#legacy_rows) | Sparse | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 (N) / 8 (T) | 4 | tf32 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f32_f32_f32_f32_f32_tfmma.cu) |
|[7](#legacy_rows) | Sparse | half_t | half_t | TN, NN, NT, TT | 8 (N) / 16 (T) | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f16_f16_f32_f16_f16_hmma.cu) |
|[8](#legacy_rows) | Sparse | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 (N) / 16 (T) | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f16_f16_f32_f16_f16_hmma.cu) |
|[9](#legacy_rows) | Sparse | int8_t | int8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_s8_s8_s32_s8_s8_imma.cu) |
|[10](#legacy_rows) | Sparse | uint8_t | uint8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_s8_s8_s32_s8_s8_imma.cu) |
For narrow precision MMAs, not all A/B type and A/B layout combinations are supported by every `tcgen05.mma` instruction.
Furthermore, tensor copy instructions for subbyte types impose additional alignment requirements while loading narrow-precision
@@ -91,203 +96,298 @@ Below tables list valid layout, and alignment values for each A and B data type
instructions supported by CUTLASS.
**Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling** <a id="non_bs_gemm_table" name="non_bs_gemm_table"></a>
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|-------------------------------|----------|----------|----------------|-------------|-------------|-------------------------|-----------|
|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)|
| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|-------------------------------|----------------|----------|----------|----------------|-------------------|-------------|-------------------------|-----------|
|[1](#nonbs_rows_1_2_3_6) | Dense | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[2](#nonbs_rows_1_2_3_6) | Dense | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[3](#nonbs_rows_1_2_3_6) | Dense | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[4](#nonbs_rows_4_7) | Dense | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|[5](#nonbs_rows_5_8) | Dense | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|[6](#nonbs_rows_1_2_3_6) | Dense | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|[7](#nonbs_rows_4_7) | Dense | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|[8](#nonbs_rows_5_8) | Dense | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|[9](#nonbs_rows_9) | Dense | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)|
|[10](#nonbs_rows_1_2_3_6) | Sparse | float4_t | float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu) |
|[11](#nonbs_rows_1_2_3_6) | Sparse | float4_t | float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu) |
|[12](#nonbs_rows_1_2_3_6) | Sparse | float6_t | float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu) |
|[13](#nonbs_rows_4_7) | Sparse | float4_t | float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu) |
|[14](#nonbs_rows_5_8) | Sparse | float8_t | float4_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu) |
|[15](#nonbs_rows_1_2_3_6) | Sparse | float6_t | float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu) |
|[16](#nonbs_rows_4_7) | Sparse | float6_t | float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu) |
|[17](#nonbs_rows_5_8) | Sparse | float8_t | float6_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu) |
|[18](#nonbs_rows_9) | Sparse | float8_t | float8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f8_f8_f32_f16_f16_qmma.cu) |
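To make the alignment columns concrete, the sketch below spells out the type aliases that would correspond to row 2 of Table 2 (dense float4_t x float6_t, TN). The choice of `float_e2m1_t` and `float_e3m2_t` is illustrative, not mandated by the table; any 4-bit/6-bit CUTLASS float type in the f8f6f4 family follows the same alignment rules.

```c++
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"

// Illustrative aliases for row 2 of Table 2: dense float4_t x float6_t, TN layout.
using ElementA = cutlass::float_e2m1_t;         // a float4_t type (assumed pick)
using ElementB = cutlass::float_e3m2_t;         // a float6_t type (assumed pick)
using LayoutA  = cutlass::layout::RowMajor;     // T
using LayoutB  = cutlass::layout::ColumnMajor;  // N
constexpr int AlignmentA = 128;                 // in elements, per Table 2
constexpr int AlignmentB = 128;
```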
**Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs** <a id="bs_gemm_table" name="bs_gemm_table"></a>
| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test|
|--------------------------|----------------|-------------|-------------|----------------|-------------------|-------------|-------------------------|---------|
|[1](#bs_rows_1) | Dense | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)|
|[2](#bs_rows_2) | Dense | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)|
|[3](#bs_rows_3) | Dense | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)|
|[4](#bs_rows_4_5_7_8_10) | Dense | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)|
|[5](#bs_rows_4_5_7_8_10) | Dense | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)|
|[6](#bs_rows_6_9_11) | Dense | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)|
|[7](#bs_rows_4_5_7_8_10) | Dense | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)|
|[8](#bs_rows_4_5_7_8_10) | Dense | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)|
|[9](#bs_rows_6_9_11) | Dense | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)|
|[10](#bs_rows_4_5_7_8_10) | Dense | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)|
|[11](#bs_rows_6_9_11)     | Dense          | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16                | 16          | mxf8f6f4                |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)|
|[12](#bs_rows_1) | Sparse | nv_float4_t | nv_float4_t | TN | 32 (N) / 64 (T) | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_nvf4_nvf4_f32_void_f16_o_tnn.cu) |
|[13](#bs_rows_2) | Sparse | mx_float4_t | mx_float4_t | TN | 32 (N) / 64 (T) | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf4_f32_f16_f16_o_tnn.cu) |
|[14](#bs_rows_3) | Sparse | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf4_f32_f16_f16_q_tnt.cu) |
|[15](#bs_rows_4_5_7_8_10) | Sparse | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu) |
|[16](#bs_rows_4_5_7_8_10) | Sparse | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu) |
|[17](#bs_rows_6_9_11) | Sparse | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu) |
|[18](#bs_rows_4_5_7_8_10) | Sparse | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu) |
|[19](#bs_rows_4_5_7_8_10) | Sparse | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu) |
|[20](#bs_rows_6_9_11) | Sparse | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu) |
|[21](#bs_rows_4_5_7_8_10) | Sparse | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu) |
|[22](#bs_rows_6_9_11) | Sparse | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf8_f32_f16_f16_q_tnn.cu) |
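Block scaled rows read the same way, except the element types are the scale-factor-carrying wrappers. Below is a minimal sketch for row 1 of Table 3 (dense nv_float4_t x nv_float4_t, TN, 32-element alignment); the wrapped `float_e2m1_t` value type is an assumed pick.

```c++
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"

// Illustrative aliases for row 1 of Table 3. The nv_float4_t / mx_float*_t wrappers
// pair the narrow value type with its per-block scale factor type.
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutA  = cutlass::layout::RowMajor;     // T
using LayoutB  = cutlass::layout::ColumnMajor;  // N
constexpr int AlignmentA = 32;                  // in elements, per Table 3
constexpr int AlignmentB = 32;
```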
## MMA tile shapes supported
The alignment restrictions also limit the options for MMA tile shapes. The tables below list the valid `MmaTileShape`,
Layout, and Dispatch Policy combinations for each row of [Table 1](#legacy_gemm_table), [Table 2](#non_bs_gemm_table), and [Table 3](#bs_gemm_table).
**Table 4: Valid Tile Shapes and Dispatch Policies for legacy types (All rows of Table 1)** <a id="legacy_rows" name="legacy_rows"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|--------------------|----|----|----|----|------------------------------------------|
| Dense | 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Sparse | 1SM | 128x64x(2/4*MMA-K) | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x128x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x192x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 2SM | 256x64x(2/4*MMA-K) | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x128x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x192x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1, 2, 3, 6, 10, 11, 12, and 15 of Table 2)** <a id="nonbs_rows_1_2_3_6" name="nonbs_rows_1_2_3_6"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
| Dense | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Sparse | 1SM | 128x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 2SM | 256x128x128 | N | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5, 8, 14, and 17 of Table 2)** <a id="nonbs_rows_5_8" name="nonbs_rows_5_8"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
| Dense | 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Sparse | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 2SM | 256x128x128 | Y | Y | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4, 7, 13, and 16 of Table 2)** <a id="nonbs_rows_4_7" name="nonbs_rows_4_7"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
| Dense | 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Sparse | 1SM | 128x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 2SM | 256x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Rows 9 and 18 of Table 2)** <a id="nonbs_rows_9" name="nonbs_rows_9"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
| Dense | 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
| Dense | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
| Sparse | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
| Sparse | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
| Sparse | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
**Table 9: Valid Tile Shapes and Dispatch Policies for nv_float4_t x nv_float4_t (Rows 1 and 12 of Table 3)** <a id="bs_rows_1" name="bs_rows_1"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|----------------------------------------------|
| Dense | 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
| Dense | 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
| Dense | 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
| Dense | 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
| Dense | 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
| Dense | 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
| Sparse | 1SM | 128x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
| Sparse | 1SM | 128x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
| Sparse | 2SM | 256x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
| Sparse | 2SM | 256x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Rows 2 and 13 of Table 3)** <a id="bs_rows_2" name="bs_rows_2"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|----------------------------------------------|
| Dense | 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
| Dense | 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
| Dense | 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
| Dense | 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
| Dense | 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
| Dense | 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
| Sparse | 1SM | 128x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
| Sparse | 1SM | 128x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
| Sparse | 2SM | 256x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
| Sparse | 2SM | 256x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Rows 3 and 14 of Table 3)** <a id="bs_rows_3" name="bs_rows_3"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x192x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10, 15, 16, 18, 19, and 21 of Table 3)** <a id="bs_rows_4_5_7_8_10" name="bs_rows_4_5_7_8_10"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x192x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11, 17, 20, and 22 of Table 3)** <a id="bs_rows_6_9_11" name="bs_rows_6_9_11"></a>
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
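Reading these tables left to right yields the kernel's template parameters directly. For example, a dense 2SM TN kernel from Table 12 might be configured as below; the cluster shape is an assumed placeholder (a 2SM MMA wants an even cluster M dimension), not something the table prescribes.

```c++
#include "cute/tensor.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"

// Illustrative pick from Table 12: dense, 2SM, MmaTileShape 256x128x128, TN layout.
using MmaTileShape_MNK = cute::Shape<cute::_256, cute::_128, cute::_128>;
using ClusterShape_MNK = cute::Shape<cute::_2, cute::_1, cute::_1>;  // assumed
using KernelSchedule   = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;
```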
## Epilogue config supported
**Table 14: Epilogue Dispatch Policy** <a id="epi_dispatch" name="epi_dispatch"></a>
| Dense / Sparse | Legacy / Narrow Precision | 1/2 SM | Epilogue Dispatch Policy |
|----------------|-----------------------------|--------|----------------------------------------------------|
| Dense | Legacy & Narrow Precision | 1SM | `cutlass::epilogue::TmaWarpSpecialized1Sm` |
| Dense | Legacy & Narrow Precision | 1SM | `cutlass::epilogue::NoSmemWarpSpecialized1Sm` |
| Dense | Legacy & Narrow Precision | 2SM | `cutlass::epilogue::TmaWarpSpecialized2Sm` |
| Dense | Legacy & Narrow Precision | 2SM | `cutlass::epilogue::NoSmemWarpSpecialized2Sm` |
| Sparse | Legacy | 1SM | `cutlass::epilogue::TmaWarpSpecialized1Sm` |
| Sparse | Legacy | 1SM | `cutlass::epilogue::NoSmemWarpSpecialized1Sm` |
| Sparse | Legacy | 2SM | `cutlass::epilogue::TmaWarpSpecialized2Sm` |
| Sparse | Legacy | 2SM | `cutlass::epilogue::NoSmemWarpSpecialized2Sm` |
| Sparse | Narrow Precision (nvf4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmNvf4` |
| Sparse | Narrow Precision (nvf4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmNvf4` |
| Sparse | Narrow Precision (mxf4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmMxf4` |
| Sparse | Narrow Precision (mxf4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmMxf4` |
| Sparse | Narrow Precision (mxf8f6f4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4` |
| Sparse | Narrow Precision (mxf8f6f4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4` |
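The policy from Table 14 is passed as the last argument of the epilogue collective builder. Below is a minimal sketch pairing the dense 2SM TMA policy with placeholder output types; everything except the dispatch policy name is an assumption, and the per-SM tile shape convention is covered by Table 15 below.

```c++
#include "cute/tensor.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"

// Assumed placeholder parameters for illustration only.
using ElementAccumulator = float;
using ElementCompute     = float;
using ElementC           = cutlass::half_t;
using ElementD           = cutlass::half_t;
using LayoutC            = cutlass::layout::RowMajor;
using LayoutD            = cutlass::layout::RowMajor;
using PerSmTileShape_MNK = cute::Shape<cute::_128, cute::_128, cute::_128>;  // 2SM 256x128 MMA -> 128x128 per SM
using ClusterShape_MNK   = cute::Shape<cute::_2, cute::_1, cute::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    PerSmTileShape_MNK, ClusterShape_MNK,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementCompute,
    ElementC, LayoutC, 8,   // 8-element (16B) alignment for half_t, assumed
    ElementD, LayoutD, 8,
    cutlass::epilogue::TmaWarpSpecialized2Sm     // policy from Table 14
  >::CollectiveOp;
```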
**Table 15: Epilogue PerSmTileShape_MNK** <a id="epi_persmtileshape" name="epi_persmtileshape"></a>
| 1/2 SM | MMA tile Shape | PerSmTileShape_MNK |
@@ -314,14 +414,16 @@ MMA_TileShape_K is generally 4 * MMA-Instruction-K. It depends on the config
### Auto Kernel Dispatch Policies
In addition to the direct dispatch policies listed above, the user can also use auto policies for both non-block scaled narrow-precision
GEMMs (both sparse and dense) and block scaled narrow-precision GEMMs (dense only).
CUTLASS will do its best to find the most efficient kernel for the given parameters; however, the preferred method for building
these kernels is to use the direct kernel dispatch policies shown in the tables above.
* `cutlass::gemm::collective::KernelScheduleAuto`: For a given MMA tile size, data type, and layout combination, choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) and the 1/2 SM `tcgen05.mma(.sp)` variant automatically.
* `KernelTmaWarpSpecialized1SmBlockScaledSm100`: Use the 1 SM `tcgen05.mma` instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
* `KernelTmaWarpSpecialized2SmBlockScaledSm100`: Use the 2 SM `tcgen05.mma` instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
* `KernelSparseTmaWarpSpecialized1SmBlockScaledSm100`: Use the 1 SM `tcgen05.mma.sp` instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
* `KernelSparseTmaWarpSpecialized2SmBlockScaledSm100`: Use the 2 SM `tcgen05.mma.sp` instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueScheduleAuto`.
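As a sketch of how an auto policy plugs into the mainloop collective builder, the following uses an assumed NVF4 block scaled configuration; all element, layout, and tile-shape choices are illustrations, following the pattern of the narrow-precision examples referenced later in this section.

```c++
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

// Assumed NVF4 block scaled configuration; adjust to your row of Table 3.
using ElementA         = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB         = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using MmaTileShape_MNK = cute::Shape<cute::_128, cute::_128, cute::_256>;
using ClusterShape_MNK = cute::Shape<cute::_1, cute::_1, cute::_1>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
    ElementA, cutlass::layout::RowMajor,    32,  // A: T layout, 32-element alignment
    ElementB, cutlass::layout::ColumnMajor, 32,  // B: N layout, 32-element alignment
    float,                                       // accumulator
    MmaTileShape_MNK, ClusterShape_MNK,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto  // CUTLASS picks instr kind and 1/2 SM
  >::CollectiveOp;
```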
@@ -330,16 +432,23 @@ Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueSche
For non-blockscaled dense GEMMs, refer to the [quick start page](quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel). An example dense GEMM can be found here:
1. [Blackwell FP16 GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/).
An example sparse GEMM can be found here:
1. [Blackwell FP16 Sparse GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/83_blackwell_sparse_gemm/).
Narrow precision and block scaled narrow precision kernels can be built using the CUTLASS 3.x collective builder interface
(as described in [CUTLASS 3.0 GEMM API](gemm_api_3x.md#cutlass-30-gemm-api)). However, special attention needs to be given to
the A and B matrix layouts, alignment requirements, and dispatch policies listed above to obtain a functionally correct and
performant kernel.
Several examples of block scaled dense GEMM kernels can be found in the [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory:
1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
2. [NVF4 Gemm with block scaling and NVF4 output matrix](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
Several examples of block scaled sparse GEMM kernels can be found in the [examples/84_blackwell_narrow_precision_sparse_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm) directory:
1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
2. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
The collective builder interface expects the same arguments as any other CUTLASS 3.x kernel, as described
[here](gemm_api_3x.md#collective-builder-for-collectivemmas), with a small difference in the collective MMA builder interface.
As in all Blackwell kernels, the `TileShape_MNK` argument expects the `MmaTileShape_MNK`, which is the tile shape needed

View File

@@ -28,7 +28,6 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cuda_runtime_api.h>
#include "cutlass_unit_test.h"
@@ -59,7 +58,10 @@ std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &deviceProperti
int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
if (deviceMajorMinor) {
int32_t clock_MHz;
int32_t clock_KHz;
// Query the clock rate through the attribute API (cudaDeviceProp::clockRate is
// deprecated in recent CUDA releases).
cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
clock_MHz = clock_KHz / 1000;
out << "GPU(compute_"
<< deviceMajorMinor << ", "
<< deviceProperties.multiProcessorCount << " SMs @ " << clock_MHz << " MHz)";

View File

@@ -29,22 +29,25 @@
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
add_custom_target(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse
DEPENDS
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -57,7 +60,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -70,7 +73,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -83,7 +86,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -96,7 +99,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -109,7 +112,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -127,7 +130,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -140,7 +143,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -148,10 +151,32 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
sm100_bssp_gemm_mxf8_mxf4_f32_f16_mxf8_q_tnt.cu
sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu
sm100_bssp_gemm_mxf8_mxf4_f32_f32_f32_q_tnt.cu
sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu
sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu
sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
@@ -162,7 +187,16 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_streamk
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
BATCH_SOURCES ON
BATCH_SIZE 1

View File

@@ -26,18 +26,19 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
add_subdirectory(narrow_precision)
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
add_custom_target(
cutlass_test_unit_gemm_device_sm100_sp
cutlass_test_unit_gemm_device_sm100_sparse
DEPENDS
cutlass_test_unit_gemm_device_sm100_sp_general
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
cutlass_test_unit_gemm_device_sm100_sp_streamk
cutlass_test_unit_gemm_device_sm100_sparse_general
cutlass_test_unit_gemm_device_sm100_sparse_streamk
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sp_general
cutlass_test_unit_gemm_device_sm100_sparse_general
# No batching of source to control compiler memory usage
BATCH_SOURCES ON
@@ -52,23 +53,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
# No batching of source to control compiler memory usage
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu
sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu
sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu
sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu
sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu
sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sp_streamk
cutlass_test_unit_gemm_device_sm100_sparse_streamk
# No batching of source to control compiler memory usage
BATCH_SOURCES ON

View File

@@ -0,0 +1,77 @@
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
add_custom_target(
cutlass_test_unit_gemm_device_sm100_sparse_narrow_precision
DEPENDS
cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu
sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu
sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu
sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu
sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu
sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu
sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu
sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu
)
endif()

View File

@@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
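The alignment constants in these namespaces follow one rule: kAlignmentD is a full 128-bit access (128 / sizeof_bits<ElementD>), and kAlignmentC falls back to kAlignmentD whenever ElementC is void, i.e. when there is no source tensor to read. A minimal illustrative check of that arithmetic (not part of the commit; cutlass/numeric_types.h is already included above):
// Illustrative only. half_t is 16 bits, so a 128-bit access covers 8 elements,
// which is the kAlignmentC/kAlignmentD value computed above.
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8, "f16 C/D alignment");
// float_e2m1_t is 4 bits, so 32 elements fit in one 128-bit access; the sparse
// mainloop above nevertheless sets kAlignmentA = 256, a stricter requirement it
// imposes on the compressed A operand.
static_assert(128 / cutlass::sizeof_bits<cutlass::float_e2m1_t>::value == 32, "e2m1 vector width");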
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
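Namespaces 3 and 4 pair the _2,_1,_1 cluster with the TmaWarpSpecialized2Sm epilogue and the KernelSparseTmaWarpSpecialized2SmSm100 mainloop schedule, so two CTAs in a cluster cooperate on one 256-row MMA tile. A small compile-time sketch of that pairing, using only the aliases defined above (illustrative, not part of the commit):
// Illustrative invariant check for the 2sm configurations above.
namespace check_2sm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
static_assert(cute::size<0>(check_2sm::ClusterShape{}) == 2,
              "2sm schedules run with a two-CTA cluster along M");
static_assert(cute::size<0>(check_2sm::MmaTileShape{}) == 256,
              "each 2sm MMA covers a 256-row tile, 128 rows per CTA");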
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
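Each functional test above delegates to test::gemm::device::TestSmall, which builds the operands (including the compressed A tensor and its sparse metadata), runs the kernel over a sweep of problem sizes bounded by the {256, 2560} argument with alpha = beta = 1 here, and verifies the output against a host reference using relative-error checking. For quick inspection, a small helper like the following (illustrative, not part of the commit) can dump a configuration's static parameters:
// Illustrative helper: prints the compile-time configuration of the first
// kernel above. Relies only on headers and aliases already in this file.
inline void print_sparse_gemm_config() {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
  std::cout << "MMA tile MxNxK: "
            << int(cute::size<0>(gemm::MmaTileShape{})) << "x"
            << int(cute::size<1>(gemm::MmaTileShape{})) << "x"
            << int(cute::size<2>(gemm::MmaTileShape{})) << "\n"
            << "A/B alignment (elements): "
            << gemm::kAlignmentA << "/" << gemm::kAlignmentB << "\n";
}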
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
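Outside the testbed, each Gemm alias defined in this file is driven through the standard CUTLASS 3.x device-adapter lifecycle. The sketch below is schematic only: make_testbed_arguments() is a hypothetical stand-in for the argument construction that TestSmall performs internally (sparse kernels additionally require the compressed A tensor and its E metadata), while can_implement, get_workspace_size, initialize, and run are the adapter's public API.
// Schematic usage sketch, assuming a hypothetical make_testbed_arguments().
#include "cutlass/util/device_memory.h"  // cutlass::device_memory::allocation
inline void run_one_sparse_gemm() {
  using Gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm::Gemm;
  Gemm::Arguments args = make_testbed_arguments();  // hypothetical helper
  Gemm gemm_op;
  if (gemm_op.can_implement(args) == cutlass::Status::kSuccess) {
    size_t workspace_bytes = Gemm::get_workspace_size(args);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);
    gemm_op.initialize(args, workspace.get());  // copies kernel params, initializes workspace
    gemm_op.run();                              // enqueues the kernel on the default stream
  }
}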
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
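This file swaps the B operand to float_e4m3_t, an 8-bit type, so the same 128-bit-access rule yields kAlignmentB = 128 / 8 = 16, matching the constant above. An illustrative check (not part of the commit):
static_assert(128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value == 16,
              "e4m3 B operand: 16 elements per 128-bit access");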
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
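
A note on the kAlignment constants that recur in every namespace: they are element counts, not bytes. The C/D expressions divide one 128-bit access by the element width, so the narrow FP formats pack more elements per access; the larger A/B values (e.g. 256 for the 6-bit e3m2 operand) appear to reflect sparse-mainloop requirements rather than a single 128-bit access. A quick compile-time restatement of the 128-bit arithmetic, assuming only cutlass/numeric_types.h:

#include "cutlass/numeric_types.h"

// Elements per 128-bit access for the element types used above.
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
              "half_t (16-bit): 8 elements");
static_assert(128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value == 16,
              "e4m3 (8-bit): 16 elements");
static_assert(128 / cutlass::sizeof_bits<cutlass::float_e2m1_t>::value == 32,
              "e2m1 (4-bit): 32 elements");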

View File

@@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
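
Across all of these files the pairing is consistent: the _1sm configurations use a 1x1x1 cluster with 128-row MMA tiles and the TmaWarpSpecialized1Sm schedules, while the _2sm configurations use a 2x1x1 cluster with 256-row MMA tiles and the 2Sm schedules, so each CTA of the pair still owns a 128-row slab. An illustrative compile-time restatement of that pairing, reusing the cute shapes from the namespaces above:

#include "cute/layout.hpp"

using Cluster2Sm = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTile2Sm = cute::Shape<cute::_256, cute::_128, cute::_256>;

// Two CTAs along M cooperate on one 256-row MMA tile: 128 rows per CTA.
static_assert(cute::size<0>(MmaTile2Sm{}) / cute::size<0>(Cluster2Sm{}) == 128,
              "per-CTA M extent is 128 rows");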

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
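// The functional tests below drive each configuration through the shared
// TestSmall harness from gemm_testbed_3x.hpp. The first two arguments look
// like the epilogue scalars alpha and beta; RELATIVE (rather than bitwise)
// equality is checked since the inputs are low-precision FP8/FP4, the
// epilogue scalars are placed on device, and {256, 2560} appears to be an
// override of the problem sizes swept along K.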
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
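// The next four configuration namespaces repeat the same tile/cluster sweep
// with ElementC = void, i.e. an epilogue with no C source tensor; their
// functional tests accordingly pass 0 for the second (beta) argument.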
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
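// This suite mirrors the e4m3 x e2m1 file above, swapping the B operand to
// FP6 (float_e3m2_t) while keeping the same four tile/cluster configurations
// and 1-SM/2-SM schedules.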
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 0,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@@ -31,7 +31,8 @@
/* \file
\brief Command line options for performance test program
*/
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <fstream>
#include <set>
@@ -165,9 +166,11 @@ void Options::Device::print_usage(std::ostream &out) const {
break;
}
else {
+ int32_t clock_KHz;
+ cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
out << " [" << idx << "] - "
<< prop.name << " - SM " << prop.major << "." << prop.minor << ", "
- << prop.multiProcessorCount << " SMs @ " << (prop.clockRate / 1000.0) << " MHz, "
+ << prop.multiProcessorCount << " SMs @ " << (clock_KHz / 1000.0) << " MHz, "
<< "L2 cache: " << (prop.l2CacheSize >> 20) << " MB, Global Memory: " << (prop.totalGlobalMem >> 30) << " GB"
<< std::endl;
}
@@ -216,9 +219,11 @@ void Options::Device::print_options(std::ostream &out, int indent) const {
for (int device : devices) {
out << device << ',';
}
+ int32_t clock_KHz;
+ cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
out
<< "\n"
- << indent_str(indent) << "clock: " << int(double(properties[0].clockRate) / 1000.0) << "\n"
+ << indent_str(indent) << "clock: " << int(double(clock_KHz) / 1000.0) << "\n"
<< indent_str(indent) << "compute-capability: " << compute_capability(0) << "\n";
}
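The two hunks above replace reads of the cudaDeviceProp::clockRate field (deprecated in recent CUDA toolkits) with an explicit attribute query. A minimal standalone sketch of the replacement pattern, using only the CUDA runtime API and not taken from this commit:

#include <cuda_runtime_api.h>
#include <iostream>

int main() {
  int clock_KHz = 0;
  // cudaDevAttrClockRate reports the peak SM clock of the device in kHz.
  if (cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, /*device=*/0) != cudaSuccess) {
    std::cerr << "cudaDeviceGetAttribute failed" << std::endl;
    return 1;
  }
  std::cout << "SM clock: " << (clock_KHz / 1000.0) << " MHz" << std::endl;
  return 0;
}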

View File

@@ -109,7 +109,8 @@ bool BlockCompareEqual(
Element const *ptr_B,
size_t capacity,
int grid_size = 0,
- int block_size = 0) {
+ int block_size = 0,
+ cudaStream_t stream = nullptr) {
int equal_flag = 1;
int *device_equal_flag = nullptr;
@@ -146,7 +147,9 @@ bool BlockCompareEqual(
dim3 grid(grid_size, 1, 1);
dim3 block(block_size, 1, 1);
- kernel::BlockCompareEqual<Element><<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity);
+ kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
+ cudaStreamSynchronize(stream);
if (cudaMemcpy(
&equal_flag,
@@ -175,7 +178,8 @@ bool BlockCompareRelativelyEqual(
Element epsilon,
Element nonzero_floor,
int grid_size = 0,
- int block_size = 0) {
+ int block_size = 0,
+ cudaStream_t stream = nullptr) {
int equal_flag = 1;
int *device_equal_flag = nullptr;
@@ -212,7 +216,7 @@ bool BlockCompareRelativelyEqual(
dim3 grid(grid_size, 1, 1);
dim3 block(block_size, 1, 1);
- kernel::BlockCompareRelativelyEqual<Element><<< grid, block >>>(
+ kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
device_equal_flag,
ptr_A,
ptr_B,
@@ -221,6 +225,8 @@ bool BlockCompareRelativelyEqual(
nonzero_floor
);
+ cudaStreamSynchronize(stream);
if (cudaMemcpy(
&equal_flag,
device_equal_flag,
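The hunks above thread an optional cudaStream_t through the device-side comparison helpers, so the comparison kernels launch on a caller-supplied stream and are synchronized before the flag is copied back. A hypothetical call site under the extended signature; d_a, d_b, and buffers_match are illustrative assumptions, not part of the commit:

#include <cuda_runtime_api.h>
#include "cutlass/util/reference/device/tensor_compare.h"

bool buffers_match(float const* d_a, float const* d_b, size_t n, cudaStream_t stream) {
  // Passing 0 for grid_size/block_size keeps the helper's default launch
  // configuration; the kernel now runs on `stream` instead of the default stream.
  return cutlass::reference::device::BlockCompareEqual<float>(
      d_a, d_b, n, /*grid_size=*/0, /*block_size=*/0, stream);
}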

View File

@@ -232,6 +232,8 @@ ComputeType TensorTransformReduce(
workspace, identity, workspace_size, reduce
);
+ cudaStreamSynchronize(stream);
if (copy_out) {
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
if (result != cudaSuccess) {
@@ -285,6 +287,8 @@ ComputeType TensorTransformReduce(
workspace, identity, workspace_size, reduce
);
+ cudaStreamSynchronize(stream);
if (copy_out) {
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
if (result != cudaSuccess) {
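In both reduction paths the inserted cudaStreamSynchronize(stream) protects the blocking cudaMemcpy that follows: the reduction kernel is enqueued on `stream`, and a legacy default-stream cudaMemcpy does not order against work on a non-blocking stream, so the host could otherwise read the workspace before the kernel finishes. A self-contained illustration of the pattern; fill_kernel is a stand-in, not code from this commit:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill_kernel(int* out) { *out = 42; }

int main() {
  int* workspace = nullptr;
  cudaMalloc(&workspace, sizeof(int));
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
  fill_kernel<<<1, 1, 0, stream>>>(workspace);
  // Ensure the kernel on `stream` has finished before the host copies the
  // result; the memcpy alone would not wait on a non-blocking stream.
  cudaStreamSynchronize(stream);
  int result = 0;
  cudaMemcpy(&result, workspace, sizeof(result), cudaMemcpyDeviceToHost);
  std::printf("result = %d\n", result);
  cudaStreamDestroy(stream);
  cudaFree(workspace);
  return 0;
}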