mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
@@ -13,11 +13,16 @@
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu).
|
||||
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
|
||||
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture.
|
||||
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
|
||||
@@ -50,11 +50,16 @@ architecture.
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu).
|
||||
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
|
||||
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture.
|
||||
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
|
||||
@@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@@ -429,7 +429,7 @@ void initialize(Options const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
@@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions {
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@@ -347,7 +347,7 @@ void initialize(Options const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
|
||||
@@ -288,7 +288,7 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
|
||||
void initialize(MixedDtypeOptions const& options) {
|
||||
|
||||
auto shape_b = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@@ -313,7 +313,7 @@ void initialize(MixedDtypeOptions const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
@@ -208,20 +208,6 @@ bool initialize_tensor(
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_scale(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
@@ -232,10 +218,8 @@ bool initialize_scale(
|
||||
float scope_max = 1.0f, scope_min = 1.0f;
|
||||
if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
|
||||
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
const float max_dequant_val = 4.f;
|
||||
const float min_dequant_val = 0.5f;
|
||||
scope_max = max_dequant_val / elt_max_f;
|
||||
scope_min = min_dequant_val / elt_max_f;
|
||||
scope_max = 2.f;
|
||||
scope_min = 0.1f;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
@@ -120,8 +120,7 @@
|
||||
#include "helper.h"
|
||||
|
||||
// Distributed GEMM helpers
|
||||
#include "util/benchmark.h"
|
||||
#include "util/device_copy.h"
|
||||
#include "dist_gemm_helpers.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief generic device-to-device data movement kernel based for CuTe tensors.
|
||||
|
||||
NOTE: this kernel assigns one element copy to every thread, and is by no means
|
||||
an efficient way of copying tensors. It should only be used for convenience in
|
||||
reference checks.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream);
|
||||
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
__global__ void device_copy_kernel(TensorSource const tensor_source,
|
||||
TensorDestination tensor_destination) {
|
||||
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
using ElementSrc = typename TensorSource::value_type;
|
||||
using ElementDst = typename TensorDestination::value_type;
|
||||
NumericConverter<ElementDst, ElementSrc> converter;
|
||||
if (linear_idx < size(tensor_source)) {
|
||||
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream) {
|
||||
|
||||
assert(tensor_source.size() == tensor_destination.size());
|
||||
|
||||
auto numel = tensor_source.size();
|
||||
static constexpr int NumThreads = 128;
|
||||
auto grid_size = cute::ceil_div(numel, NumThreads);
|
||||
|
||||
dim3 grid(grid_size);
|
||||
dim3 block(NumThreads);
|
||||
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
|
||||
}
|
||||
|
||||
} //namespace cutlass
|
||||
@@ -374,7 +374,7 @@ void allocate(Options const& options) {
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
const int scale_k = 1;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
@@ -510,7 +510,7 @@ void initialize(Options &options) {
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
@@ -519,13 +519,13 @@ void initialize(Options &options) {
|
||||
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
const int scale_k = 1;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
|
||||
auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
|
||||
auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream);
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream);
|
||||
}
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
@@ -619,7 +619,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
@@ -676,6 +676,7 @@ bool verify(Options const& options) {
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
@@ -712,7 +713,7 @@ bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group: " << i << " Status: " << passed << std::endl;
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
|
||||
@@ -341,7 +341,7 @@ void allocate(Options const& options) {
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
const int scale_k = 1;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
@@ -479,7 +479,7 @@ void initialize(Options& options) {
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
@@ -565,7 +565,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
@@ -617,6 +617,7 @@ bool verify(Options const& options) {
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
@@ -630,11 +631,11 @@ bool verify(Options const& options) {
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
const int scale_k = 1;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
@@ -659,7 +660,7 @@ bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group: " << i << " Status: " << passed << std::endl;
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
|
||||
@@ -282,7 +282,7 @@ void allocate(Options const& options) {
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
const int scale_k = 1;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
@@ -418,7 +418,7 @@ void initialize(Options &options) {
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
@@ -485,7 +485,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
@@ -542,6 +542,7 @@ bool verify(Options const& options) {
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
@@ -555,11 +556,11 @@ bool verify(Options const& options) {
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
const int scale_k = 1;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
@@ -584,7 +585,7 @@ bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group: " << i << " Status: " << passed << std::endl;
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
|
||||
@@ -50,6 +50,7 @@ set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10)
|
||||
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
|
||||
|
||||
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
|
||||
set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_mixed_dtype_grouped_gemm
|
||||
@@ -69,6 +70,7 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
@@ -89,6 +91,7 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
@@ -109,4 +112,5 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
@@ -7,11 +7,11 @@ This example shows how to perform Grouped GEMMs on Hopper when A and B have diff
|
||||
- in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B.
|
||||
- if scales and zero-points are included, also pass the array of their strides in the arguments.
|
||||
|
||||
Note that in Example 55, the argument `--g` is used to determine the block scale size. It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs.
|
||||
Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling.
|
||||
|
||||
## Upcoming features
|
||||
|
||||
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.
|
||||
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, and group-wise scaling for identical problem shapes across all groups. Please contact us if zero-points or block-wise scaling are needed.
|
||||
|
||||
## Copyright
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ public:
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
MixedDtypeOptions::parse(argc, args);
|
||||
|
||||
@@ -71,6 +72,7 @@ public:
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --c=<int> Sets the chunk size for scaling the quantized weights\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems\n"
|
||||
<< " --mode=<int> The mode to run the gemm\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
@@ -183,11 +185,6 @@ void grouped_mixed_dtype_profiling(
|
||||
|
||||
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta\n";
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n';
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << '\n'
|
||||
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
|
||||
<< " GFLOPS : " << result.gflops << '\n';
|
||||
|
||||
832
examples/77_blackwell_fmha/77_blackwell_mla.cu
Normal file
832
examples/77_blackwell_fmha/77_blackwell_mla.cu
Normal file
@@ -0,0 +1,832 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file A MLA (Multi-Head Latent Attention) inference kernel sample for the
|
||||
NVIDIA Blackwell Architecture.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <cmath>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_mla_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int k = 256;
|
||||
int split_kv = -1; // number of split along k dim.
|
||||
bool is_var_split_kv = false;
|
||||
int max_split_kv = 16;
|
||||
int page = -1;
|
||||
float spread = 0.2f;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_c = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (k == -1) k = defaults.k;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
|
||||
cmd.get_cmd_line_argument("page", page, defaults.page);
|
||||
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
|
||||
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
|
||||
if (page == -1) {
|
||||
is_var_split_kv = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv);
|
||||
if (is_var_split_kv == true) {
|
||||
split_kv = max_split_kv;
|
||||
}
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_mla\n\n"
|
||||
<< " This example showcases the use of CUTLASS for fused multi-head latent\n"
|
||||
<< " attention kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --page=<int> Enables paging and sets the page size\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --spread=<float> Relative spread away from K for paging\n"
|
||||
<< " --split_kv=<int> Split KV factor\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) -1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 64; i ++) {
|
||||
for (int j = 0; j < 64; j++) {
|
||||
data[j + 64*i] = static_cast<Element>((double) (i % 9));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class TileShape,
|
||||
class PersistenceOption = IsPersistent<true>
|
||||
>
|
||||
struct Runner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#elif FP16
|
||||
using Element = cutlass::half_t;
|
||||
#else
|
||||
#error "Must either define FP8 or FP16"
|
||||
#endif
|
||||
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler = std::conditional_t<
|
||||
PersistenceOption::value,
|
||||
Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler
|
||||
>;
|
||||
|
||||
using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler
|
||||
>;
|
||||
using Operation = cutlass::fmha::device::MLA<Kernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q_latent;
|
||||
StrideK stride_C_latent;
|
||||
StrideQ stride_Q_rope;
|
||||
StrideK stride_K_rope;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
StrideLSE stride_PT;
|
||||
|
||||
uint64_t seed = 0;
|
||||
|
||||
int page_size = -1;
|
||||
int page_count = -1;
|
||||
|
||||
// We allocate Q and C as first latent, then rope
|
||||
// This means that we offset the pointer by HeadDim_latent to get the rope
|
||||
// portion
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_C;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<int> block_seq;
|
||||
DeviceAllocation<int> block_PT;
|
||||
DeviceAllocation<int> block_split_kv;
|
||||
DeviceAllocation<int> block_accum_split_len;
|
||||
DeviceAllocation<ElementAcc> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAcc> block_ref_LSE;
|
||||
|
||||
ElementAcc scale;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
int page_K = K;
|
||||
int page_B = B;
|
||||
if (block_PT.get() != nullptr) {
|
||||
page_K = page_size;
|
||||
page_B = page_count;
|
||||
}
|
||||
|
||||
Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
cute::make_tuple(H, D_latent, B),
|
||||
stride_Q_latent);
|
||||
|
||||
Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent),
|
||||
cute::make_tuple(H, D_rope, B),
|
||||
stride_Q_rope);
|
||||
|
||||
Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()),
|
||||
cute::make_tuple(page_K, D_latent, page_B),
|
||||
stride_C_latent);
|
||||
|
||||
Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent),
|
||||
cute::make_tuple(page_K, D_rope, page_B),
|
||||
stride_K_rope);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||
cute::make_tuple(H, D_latent, B),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||
cute::make_tuple(H, B),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mSeq = make_tensor(make_gmem_ptr(static_cast<int*>(block_seq.get())), make_shape(B));
|
||||
Tensor mPT = make_tensor(make_gmem_ptr(static_cast<int*>(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT);
|
||||
|
||||
fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
#ifdef B2B
|
||||
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#else
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#endif
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
bool passed_LSE = true;
|
||||
#ifndef B2B
|
||||
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_LSE) {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
|
||||
ProblemShape initialize(const Options& options) {
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
// the scale is based on the non-absorbed sizes, change as appropriate
|
||||
// we can't determine this parameter from the info we have, it's an input
|
||||
int D_non_latent = 128;
|
||||
scale = static_cast<decltype(scale)>(1.0 / sqrt(1.0 * (D_non_latent + D_rope)));
|
||||
// Shape (H, D, B)
|
||||
stride_Q_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
|
||||
stride_Q_rope = stride_Q_latent;
|
||||
stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
|
||||
block_Q.reset(static_cast<size_t>(options.b) * H * (D_latent + D_rope));
|
||||
block_O.reset(static_cast<size_t>(options.b) * H * D_latent);
|
||||
block_LSE.reset(static_cast<size_t>(options.b) * H);
|
||||
block_ref_O.reset(static_cast<size_t>(options.b) * H * D_latent);
|
||||
block_ref_LSE.reset(static_cast<size_t>(options.b) * H);
|
||||
|
||||
if (options.page == -1) {
|
||||
|
||||
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(options.k) * (D_latent + D_rope));
|
||||
stride_K_rope = stride_C_latent;
|
||||
|
||||
block_C.reset(static_cast<size_t>(options.b) * options.k * (D_latent + D_rope));
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
float spread = options.spread;
|
||||
int max_K = static_cast<int>((1 + spread) * K);
|
||||
int min_K = static_cast<int>((1 - spread) * K);
|
||||
page_size = options.page;
|
||||
page_count = B * ceil_div(max_K, page_size);
|
||||
stride_PT = cute::make_stride(_1{}, page_count);
|
||||
|
||||
std::vector<int> host_seq(B);
|
||||
std::vector<int> host_PT(page_count * B);
|
||||
|
||||
for (int i = 0; i < B; i++) {
|
||||
int seq = min_K + rand() % (max_K - min_K + 1);
|
||||
host_seq[i] = seq;
|
||||
for (int j = 0; j < ceil_div(seq, page_size); j++) {
|
||||
host_PT[page_count * i + j] = i + j * B;
|
||||
}
|
||||
}
|
||||
|
||||
block_seq.reset(host_seq.size());
|
||||
block_seq.copy_from_host(host_seq.data(), host_seq.size());
|
||||
block_PT.reset(host_PT.size());
|
||||
block_PT.copy_from_host(host_PT.data(), host_PT.size());
|
||||
|
||||
get<1>(problem_shape) = max_K;
|
||||
|
||||
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, page_size * static_cast<int64_t>((D_latent + D_rope)));
|
||||
stride_K_rope = stride_C_latent;
|
||||
|
||||
block_C.reset(page_count * page_size * static_cast<int64_t>((D_latent + D_rope)));
|
||||
|
||||
if (options.is_var_split_kv == true) {
|
||||
std::vector<int> host_split_kv(B);
|
||||
for(int i = 0; i < B; ++i) {
|
||||
auto len = host_seq[i];
|
||||
int split = ceil_div(options.max_split_kv, ceil_div(max_K, len));
|
||||
host_split_kv[i] = split;
|
||||
}
|
||||
block_split_kv.reset(B);
|
||||
block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size());
|
||||
}
|
||||
}
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_C, seed + 2022, options.init_style_c);
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
ProblemShape problem_shape = initialize(options);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
{ scale,
|
||||
block_Q.get(), stride_Q_latent,
|
||||
block_Q.get() + D_latent, stride_Q_rope,
|
||||
block_C.get(), stride_C_latent,
|
||||
block_C.get() + D_latent, stride_K_rope,
|
||||
block_seq.get(),
|
||||
block_PT.get(), stride_PT,
|
||||
page_count, page_size},
|
||||
{ block_O.get(),
|
||||
stride_O,
|
||||
block_LSE.get(),
|
||||
stride_LSE},
|
||||
hw_info,
|
||||
options.split_kv,
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr
|
||||
};
|
||||
if (options.split_kv < 0 && !options.is_var_split_kv) {
|
||||
Operation::set_split_kv(arguments);
|
||||
}
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 1.0;
|
||||
flops *= B;
|
||||
flops *= K;
|
||||
flops *= H;
|
||||
flops *= 2.0;
|
||||
flops *= (2.0 * D_latent + D_rope);
|
||||
|
||||
double bytes_q = sizeof(Element);
|
||||
bytes_q *= B;
|
||||
bytes_q *= H;
|
||||
bytes_q *= (D_latent + D_rope);
|
||||
double bytes_c = sizeof(Element);
|
||||
bytes_c *= B;
|
||||
bytes_c *= options.k; // K may be max_K here
|
||||
bytes_c *= (D_latent + D_rope);
|
||||
double bytes_o = sizeof(ElementOut);
|
||||
bytes_o *= B;
|
||||
bytes_o *= H;
|
||||
bytes_o *= D_latent;
|
||||
double bytes = bytes_q + bytes_c + bytes_o;
|
||||
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.tbytes_s = tbytes_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms * 1e3 << " us, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
Runner<decltype(shape), decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using NumHeads = _128;
|
||||
using HeadDimLatent = _512;
|
||||
using HeadDim = Shape<HeadDimLatent, _64>;
|
||||
|
||||
std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " ";
|
||||
std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " ";
|
||||
std::cout << "Q 1 K " << options.k << " Gen None ";
|
||||
std::cout << "Split " << options.split_kv << " Gen None ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
using Blocking = _128;
|
||||
std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{});
|
||||
std::string individual = " individual";
|
||||
std::string persistent = " persistent";
|
||||
#if FP8
|
||||
name += " fp8";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
#elif FP16
|
||||
name += " fp16";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
run_mla(options, hw_info);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -35,6 +35,10 @@ set_property(
|
||||
SOURCE 77_blackwell_fmha_gen.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
|
||||
|
||||
set_property(
|
||||
SOURCE 77_blackwell_mla.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
|
||||
@@ -48,58 +52,69 @@ set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_fp8
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8)
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_fp8
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8)
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_fp16
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
foreach(PREC fp8 fp16)
|
||||
string(TOUPPER "${PREC}" PREC_MACRO)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_fp16
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
endif()
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_${PREC}
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_${PREC}
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_2sm_cpasync_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_b2b_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
|
||||
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -22,6 +22,24 @@ The `apply_mask` function is called with the accumulator of the first GEMM and t
|
||||
It is well-suited for applying masks or activations.
|
||||
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.
|
||||
|
||||
# MLA Inference for Blackwell
|
||||
|
||||
This sample provides code for fused multi-head latent attention inference in
|
||||
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
|
||||
It supports fp16, bf16, and fp8 input and output types.
|
||||
|
||||
To accomodate the large output accumulator due to the large latent head dimension,
|
||||
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
|
||||
|
||||
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
|
||||
for support of any power-of-two page size less than or equal to 128.
|
||||
With paging, the code also supports variable sequence length.
|
||||
|
||||
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an MLA kernel.
|
||||
|
||||
The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back gemm (essentially turning the softmax into a no-op) for fp8 and fp16.
|
||||
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
|
||||
92
examples/77_blackwell_fmha/common/pow_2.hpp
Normal file
92
examples/77_blackwell_fmha/common/pow_2.hpp
Normal file
@@ -0,0 +1,92 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace cutlass::fmha {
|
||||
|
||||
struct Pow2 {
|
||||
int n;
|
||||
int log2_n;
|
||||
|
||||
explicit CUTE_DEVICE Pow2(int n) : n(n) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
log2_n = __ffs(n) - 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE T operator *(T const& b) const {
|
||||
return n * b;
|
||||
}
|
||||
|
||||
template<int N>
|
||||
CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
|
||||
if constexpr (N & (N - 1) == 0) {
|
||||
return Pow2{n * N};
|
||||
}
|
||||
return n * N;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
|
||||
return a >> b.log2_n;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
|
||||
return a & (b.n - 1);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
|
||||
return a < b.n;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void print(Pow2 const& a) {
|
||||
printf("2^%d", a.log2_n);
|
||||
}
|
||||
|
||||
} // end namespace cutlass::fmha
|
||||
|
||||
namespace cute {
|
||||
|
||||
template <>
|
||||
struct is_integral<cutlass::fmha::Pow2> : true_type {};
|
||||
|
||||
} // end namespace cute
|
||||
357
examples/77_blackwell_fmha/device/sm100_mla.hpp
Normal file
357
examples/77_blackwell_fmha/device/sm100_mla.hpp
Normal file
@@ -0,0 +1,357 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
int max_splits = ceil_div(K, 128);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
197
examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp
Normal file
197
examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp
Normal file
@@ -0,0 +1,197 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
File diff suppressed because it is too large
Load Diff
160
examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp
Normal file
160
examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp
Normal file
@@ -0,0 +1,160 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
|
||||
206
examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp
Normal file
206
examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp
Normal file
@@ -0,0 +1,206 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorSeq,
|
||||
class TensorPageTable,
|
||||
class TensorQL,
|
||||
class TensorQR,
|
||||
class TensorCL,
|
||||
class TensorKR,
|
||||
class TensorO,
|
||||
class TensorLSE,
|
||||
class Scale
|
||||
>
|
||||
void __global__ fmha_mla_reference_kernel(
|
||||
ProblemShape problem_shape,
|
||||
TensorSeq mSeq, TensorPageTable mPT,
|
||||
TensorQL mQL, TensorQR mQR,
|
||||
TensorCL mCL, TensorKR mKR,
|
||||
TensorO mO, TensorLSE mLSE,
|
||||
Scale softmax_scale) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAcc = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ ElementAcc mS[];
|
||||
// ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
|
||||
|
||||
for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) {
|
||||
if (mSeq.data() != nullptr) {
|
||||
K = mSeq(idx_B);
|
||||
}
|
||||
|
||||
for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) {
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
|
||||
ElementAcc acc = 0;
|
||||
|
||||
for (int idx_D = 0; idx_D < D_latent; idx_D++) {
|
||||
int page_idx_K = idx_K;
|
||||
int page_idx_B = idx_B;
|
||||
if (mPT.data() != nullptr) {
|
||||
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
|
||||
page_idx_K = idx_K % size<0>(mCL);
|
||||
}
|
||||
ElementAcc eQ = mQL(idx_H, idx_D, idx_B);
|
||||
ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B);
|
||||
acc += eQ * eK;
|
||||
}
|
||||
|
||||
for (int idx_D = 0; idx_D < D_rope; idx_D++) {
|
||||
int page_idx_K = idx_K;
|
||||
int page_idx_B = idx_B;
|
||||
if (mPT.data() != nullptr) {
|
||||
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
|
||||
page_idx_K = idx_K % size<0>(mCL);
|
||||
}
|
||||
ElementAcc eQ = mQR(idx_H, idx_D, idx_B);
|
||||
ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B);
|
||||
acc += eQ * eK;
|
||||
}
|
||||
mS[idx_K] = acc;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
ElementAcc maxS = -std::numeric_limits<ElementAcc>::infinity();
|
||||
for (int idx_K = 0; idx_K < K; idx_K++) {
|
||||
maxS = std::max<ElementAcc>(maxS, mS[idx_K]);
|
||||
}
|
||||
if (maxS == -std::numeric_limits<ElementAcc>::infinity()) maxS = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifndef B2B
|
||||
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
|
||||
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
|
||||
}
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
ElementAcc sum = 0;
|
||||
for (int idx_K = 0; idx_K < K; idx_K++) {
|
||||
sum += mS[idx_K];
|
||||
}
|
||||
|
||||
ElementAcc o_scale = 1.0f / sum;
|
||||
#ifdef B2B
|
||||
o_scale = 1.0;
|
||||
#endif
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) {
|
||||
ElementAcc acc = 0;
|
||||
for (int idx_K = 0; idx_K < K; idx_K++) {
|
||||
int page_idx_K = idx_K;
|
||||
int page_idx_B = idx_B;
|
||||
if (mPT.data() != nullptr) {
|
||||
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
|
||||
page_idx_K = idx_K % size<0>(mCL);
|
||||
}
|
||||
ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B);
|
||||
ElementAcc eK = static_cast<Element>(mS[idx_K]);
|
||||
acc += eK * eV;
|
||||
}
|
||||
mO(idx_H, idx_D, idx_B) = static_cast<typename TensorO::value_type>(acc * o_scale);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mLSE(idx_H, idx_B) = log(sum) + softmax_scale * maxS;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorSeq,
|
||||
class TensorPageTable,
|
||||
class TensorQL,
|
||||
class TensorQR,
|
||||
class TensorCL,
|
||||
class TensorKR,
|
||||
class TensorO,
|
||||
class TensorLSE,
|
||||
class Scale
|
||||
>
|
||||
void fmha_mla_reference(
|
||||
ProblemShape problem_shape,
|
||||
TensorSeq mSeq, TensorPageTable mPT,
|
||||
TensorQL mQL, TensorQR mQR,
|
||||
TensorCL mCL, TensorKR mKR,
|
||||
TensorO mO, TensorLSE mLSE,
|
||||
Scale scale) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
dim3 grid(H, B, 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16;
|
||||
cudaError_t result;
|
||||
if (shared_mem >= (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(
|
||||
&fmha_mla_reference_kernel<ProblemShape, TensorSeq, TensorPageTable, TensorQL, TensorQR, TensorCL, TensorKR, TensorO, TensorLSE, Scale>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
shared_mem);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
throw std::runtime_error("couldn't perform smem optin");
|
||||
}
|
||||
}
|
||||
fmha_mla_reference_kernel<<<grid, block, shared_mem>>>(
|
||||
problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale);
|
||||
cudaDeviceSynchronize();
|
||||
result = cudaGetLastError();
|
||||
if (cudaSuccess != result) {
|
||||
throw std::runtime_error("couldn't execute reference");
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -178,3 +178,96 @@ void reference_abs_diff(
|
||||
max_diff = result_host[0];
|
||||
mean_diff = result_host[1] / static_cast<double>(data.size());
|
||||
}
|
||||
|
||||
template<typename Element>
|
||||
__global__ void reference_rel_diff_kernel(
|
||||
Element* data, Element* data_ref, size_t count,
|
||||
double* max_diff, double* sum_diff,
|
||||
bool print_diff ) {
|
||||
|
||||
double thread_max_diff = 0;
|
||||
double thread_sum_diff = 0;
|
||||
|
||||
__shared__ double block_max_diff;
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
|
||||
for (int i = 0; i < blockDim.x; i++) {
|
||||
if (i == threadIdx.x) {
|
||||
if (i == 0) {
|
||||
block_max_diff = thread_max_diff;
|
||||
block_sum_diff = thread_sum_diff;
|
||||
}
|
||||
else {
|
||||
block_max_diff = fmax(block_max_diff, thread_max_diff);
|
||||
block_sum_diff += thread_sum_diff;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
atomicAdd(sum_diff, block_sum_diff);
|
||||
|
||||
for (;;) {
|
||||
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
|
||||
double prev_diff = reinterpret_cast<double const&>(prev);
|
||||
double new_max_diff = fmax(block_max_diff, prev_diff);
|
||||
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
|
||||
if (found == prev) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Element>
|
||||
void reference_rel_diff(
|
||||
DeviceAllocation<Element> const& data,
|
||||
DeviceAllocation<Element> const& data_ref,
|
||||
double& max_diff, double& mean_diff) {
|
||||
|
||||
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
|
||||
|
||||
DeviceAllocation<double> result;
|
||||
result.reset(2);
|
||||
assert(data.size() == data_ref.size());
|
||||
|
||||
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Memset failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 block(256, 1, 1);
|
||||
dim3 grid(1024, 1, 1);
|
||||
reference_rel_diff_kernel<<<block, grid>>>(
|
||||
data.get(), data_ref.get(), data.size(),
|
||||
result.get(), result.get() + 1, kPrintDiff);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Difference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
double result_host[2];
|
||||
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Copy failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
max_diff = result_host[0];
|
||||
mean_diff = result_host[1] / static_cast<double>(data.size());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,554 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture.
|
||||
This kernel is optimized for the GeForce RTX 50 series GPUs.
|
||||
|
||||
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions:
|
||||
* mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.
|
||||
Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution.
|
||||
|
||||
The kernel leverages:
|
||||
1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper.
|
||||
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
3. Block Scaled Sparse Tensor Core MMA Instructions
|
||||
|
||||
Note that GeForce RTX 50 series GPUs do not support:
|
||||
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
|
||||
2. Dynamic datatypes.
|
||||
|
||||
Usage:
|
||||
$ ./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for B matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 16; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// E matrix configuration. Note, E is used to represent metadata tensor.
|
||||
using ElementE = uint8_t; // Element type for E matrix operand
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120; // Kernel schedule policy
|
||||
// Kernel Perf config
|
||||
using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelScheduleType // Mainloop schedule policy
|
||||
>::CollectiveOp;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutE layout_E;
|
||||
uint64_t seed;
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A_Decompressed;
|
||||
cutlass::HostTensor<ElementE, cutlass::layout::PackedVectorLayout> block_E;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
bool help;
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
out << "80a_blackwell_geforce_mxfp8_bf16_sparse_gemm\n\n"
|
||||
<< " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
return out;
|
||||
}
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
};
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
/// Initialize blocks that released to sparse Matrix A and its metadata E
|
||||
bool initialize_sparse_blocks(const Options &options) {
|
||||
auto workload = make_shape(options.m,
|
||||
options.n,
|
||||
options.k,
|
||||
1);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
/// Alias SparseConfig and Compressor
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
cute::Shape<int, int, int, int>,
|
||||
ElementA::DataType,
|
||||
LayoutATag,
|
||||
SparseConfig>;
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
cute::Shape<int, int, int, int>,
|
||||
ElementA::DataType,
|
||||
LayoutATag,
|
||||
SparseConfig,
|
||||
cutlass::arch::Sm120>;
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
/// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs
|
||||
CompressorUtility compressor_utility(workload, stride_A);
|
||||
// Aligned M K dimension size for A and E
|
||||
int aligned_m_e = compressor_utility.get_metadata_m_physical();
|
||||
int aligned_k_e = compressor_utility.get_metadata_k_physical();
|
||||
int aligned_m_a = compressor_utility.get_tensorA_m_physical();
|
||||
int aligned_k_a = compressor_utility.get_tensorA_k_physical();
|
||||
/// Layout A and E
|
||||
layout_A = SparseConfig::fill_layoutA(workload);
|
||||
layout_E = SparseConfig::fill_layoutE(workload);
|
||||
|
||||
block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a));
|
||||
block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e));
|
||||
block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k));
|
||||
initialize_block(block_A_Decompressed.host_view(), seed + 2020);
|
||||
compressor_utility.structure_sparse_zero_mask_fill(
|
||||
block_A_Decompressed.host_data(), static_cast<int>(seed + 2021));
|
||||
block_A_Decompressed.sync_device();
|
||||
|
||||
/// Use compressor kernel to generate compressed Matrix A and E
|
||||
cutlass::Status status { cutlass::Status::kSuccess };
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A_Decompressed.device_data(),
|
||||
stride_A,
|
||||
block_A.device_data(),
|
||||
block_E.device_data()},
|
||||
{hw_info}
|
||||
};
|
||||
|
||||
// Compress A and E
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
status = compressor_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.run();
|
||||
auto result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
block_A.sync_host();
|
||||
block_E.sync_host();
|
||||
return true;
|
||||
}
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
// Initial A, E(metadata) and A_compressed blocks
|
||||
if(!initialize_sparse_blocks(options)) return false;
|
||||
|
||||
// Define B, C and D blocks
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
// Define SFA and SFB tensors layouts
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
return true;
|
||||
}
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), layout_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_E.device_data(), layout_E,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
return arguments;
|
||||
}
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_reference_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
return passed;
|
||||
}
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
// Initialization
|
||||
if(!initialize(options))
|
||||
{
|
||||
std::cerr << " Initialization failed! " << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaDeviceSynchronize();
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 120.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
Options options;
|
||||
options.parse(argc, args);
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
return 0;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,578 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture.
|
||||
This kernel is optimized for the GeForce RTX 50 series GPUs.
|
||||
|
||||
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions:
|
||||
* mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.
|
||||
Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution.
|
||||
|
||||
The kernel leverages:
|
||||
1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper.
|
||||
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
3. Block Scaled Sparse Tensor Core MMA Instructions
|
||||
|
||||
Note that GeForce RTX 50 series GPUs do not support:
|
||||
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
|
||||
2. Dynamic datatypes.
|
||||
|
||||
Usage:
|
||||
$ ./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::ColumnMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::ColumnMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int outputVectorSize = 32; // Vector size for D matrix
|
||||
using outputScaleFactor = cutlass::float_ue4m3_t; // Scale factor type for D matrix
|
||||
// E matrix configuration. Note, E is used to represent metadata tensor.
|
||||
using ElementE = uint8_t; // Element type for E matrix operand
|
||||
// Kernel functional config
|
||||
using ElementCompute = float; // Element type for computation inside mainloop and epilogue
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120; // Kernel schedule policy
|
||||
// Kernel Perf config
|
||||
using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120, // Epilogue schedule policy
|
||||
cutlass::epilogue::fusion::LinCombBlockScaleFactor< // Epilogue fusion to generate nvfp4 output
|
||||
outputVectorSize, ElementD, ElementAccumulator, outputScaleFactor, LayoutDTag, ElementC>
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelScheduleType // Mainloop schedule policy
|
||||
>::CollectiveOp;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<outputVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutSFD layout_SFD;
|
||||
LayoutE layout_E;
|
||||
uint64_t seed;
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A_Decompressed;
|
||||
cutlass::HostTensor<ElementE, cutlass::layout::PackedVectorLayout> block_E;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_SFD;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
bool help;
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
out << "80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm\n\n"
|
||||
<< " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
return out;
|
||||
}
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
};
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
/// Initialize blocks that released to sparse Matrix A and its metadata E
|
||||
bool initialize_sparse_blocks(const Options &options) {
|
||||
auto workload = make_shape(options.m,
|
||||
options.n,
|
||||
options.k,
|
||||
1);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
/// Alias SparseConfig and Compressor
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
cute::Shape<int, int, int, int>,
|
||||
ElementA::DataType,
|
||||
LayoutATag,
|
||||
SparseConfig>;
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
cute::Shape<int, int, int, int>,
|
||||
ElementA::DataType,
|
||||
LayoutATag,
|
||||
SparseConfig,
|
||||
cutlass::arch::Sm120>;
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
/// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs
|
||||
CompressorUtility compressor_utility(workload, stride_A);
|
||||
// Aligned M K dimension size for A and E
|
||||
int aligned_m_e = compressor_utility.get_metadata_m_physical();
|
||||
int aligned_k_e = compressor_utility.get_metadata_k_physical();
|
||||
int aligned_m_a = compressor_utility.get_tensorA_m_physical();
|
||||
int aligned_k_a = compressor_utility.get_tensorA_k_physical();
|
||||
/// Layout A and E
|
||||
layout_A = SparseConfig::fill_layoutA(workload);
|
||||
layout_E = SparseConfig::fill_layoutE(workload);
|
||||
|
||||
block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a));
|
||||
block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e));
|
||||
block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k));
|
||||
initialize_block(block_A_Decompressed.host_view(), seed + 2020);
|
||||
compressor_utility.structure_sparse_zero_mask_fill(
|
||||
block_A_Decompressed.host_data(), static_cast<int>(seed + 2021));
|
||||
block_A_Decompressed.sync_device();
|
||||
|
||||
/// Use compressor kernel to generate compressed Matrix A and E
|
||||
cutlass::Status status { cutlass::Status::kSuccess };
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A_Decompressed.device_data(),
|
||||
stride_A,
|
||||
block_A.device_data(),
|
||||
block_E.device_data()},
|
||||
{hw_info}
|
||||
};
|
||||
|
||||
// Compress A and E
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
status = compressor_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.run();
|
||||
auto result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
block_A.sync_host();
|
||||
block_E.sync_host();
|
||||
return true;
|
||||
}
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
// Initial A, E(metadata) and A_compressed blocks
|
||||
if(!initialize_sparse_blocks(options)) return false;
|
||||
|
||||
// Define B, C and D blocks
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
// Define SFA and SFB tensors layouts
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_Normconst.reset(cutlass::make_Coord(1));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_Normconst.at(cutlass::make_Coord(0)) = 2;
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
block_SFD.sync_device();
|
||||
block_Normconst.sync_device();
|
||||
return true;
|
||||
}
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), layout_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_E.device_data(), layout_E,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
return arguments;
|
||||
}
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
auto tensor_SFD = cute::make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D), // TensorD
|
||||
decltype(tensor_SFD), // TensorSfD
|
||||
cute::Int<outputVectorSize>,
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_reference_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
return passed;
|
||||
}
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
// Initialization
|
||||
if(!initialize(options))
|
||||
{
|
||||
std::cerr << " Initialization failed! " << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaDeviceSynchronize();
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 120.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
Options options;
|
||||
options.parse(argc, args);
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
return 0;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
41
examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt
Normal file
41
examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
|
||||
cutlass_example_add_executable(
|
||||
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
|
||||
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm
|
||||
80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
@@ -0,0 +1,869 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Distributed GEMM (DistGEMM) for Blackwell.
|
||||
|
||||
This example runs Tensor Parallel GEMMs using the (experimental) Distributed GEMM API in
|
||||
CUTLASS. For more information, please refer to README.md.
|
||||
|
||||
Note that Distributed GEMM assumes an any-to-any NVLink network topology.
|
||||
To check whether your device is compatible, run:
|
||||
|
||||
$ nvidia-smi topo -m
|
||||
|
||||
and make sure there's an any-to-any NVLink topology. It would look like this:
|
||||
|
||||
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
|
||||
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18
|
||||
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18
|
||||
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18
|
||||
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18
|
||||
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18
|
||||
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18
|
||||
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18
|
||||
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X
|
||||
|
||||
You should also additionally check if the driver enables peer to peer access:
|
||||
|
||||
$ nvidia-smi topo -p2p r
|
||||
|
||||
Output should be something like this:
|
||||
|
||||
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
|
||||
GPU0 X OK OK OK OK OK OK OK
|
||||
GPU1 OK X OK OK OK OK OK OK
|
||||
GPU2 OK OK X OK OK OK OK OK
|
||||
GPU3 OK OK OK X OK OK OK OK
|
||||
GPU4 OK OK OK OK X OK OK OK
|
||||
GPU5 OK OK OK OK OK X OK OK
|
||||
GPU6 OK OK OK OK OK OK X OK
|
||||
GPU7 OK OK OK OK OK OK OK X
|
||||
|
||||
It is recommended to build this target with the following flag to enable
|
||||
Grid Dependency Control instructions (GDC) in CUTLASS:
|
||||
- CUTLASS_ENABLE_GDC_FOR_SM100
|
||||
|
||||
Example:
|
||||
|
||||
$ mkdir build && cd build
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
|
||||
|
||||
$ cd examples/82_blackwell_distributed_gemm
|
||||
|
||||
$ make
|
||||
|
||||
$ ./82_blackwell_distributed_gemm
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
// Distributed GEMM headers
|
||||
#include "cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp"
|
||||
#include "cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp"
|
||||
#include "cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// Distributed GEMM helpers
|
||||
#include "dist_gemm_helpers.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Distributed GEMM configuration
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// TP size (= number of processors/GPUs)
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
// Choices:
|
||||
//
|
||||
// * All Gather + GEMM:
|
||||
// * AllGather1D_TilingCD_RotatingA
|
||||
// * AllGather1D_TilingCD_RotatingB
|
||||
//
|
||||
// * GEMM + Reduce Scatter:
|
||||
// * ReduceScatter1D_TilingA_RotatingC
|
||||
// * ReduceScatter1D_TilingB_RotatingC
|
||||
|
||||
using DistSchedule = cutlass::distributed::schedules::AllGather1D_TilingCD_RotatingA<TP>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutD = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_256,_128>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_1,_1>;
|
||||
// Shape of the tile computed by each SM
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _128>;
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
// We're going to use the single-device GEMM as reference
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Instantiate Distributed GEMM kernel
|
||||
using DistGemmKernel = cutlass::distributed::kernel::DistributedGemmKernelWrapper<
|
||||
GemmKernel,
|
||||
DistSchedule
|
||||
>;
|
||||
using DistGemm = cutlass::distributed::device::DistributedGemmUniversalAdapter<DistGemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
using HostTensorA = typename cutlass::HostTensor<ElementA, LayoutA>;
|
||||
using HostTensorB = typename cutlass::HostTensor<ElementB, LayoutB>;
|
||||
using HostTensorC = typename cutlass::HostTensor<ElementC, LayoutC>;
|
||||
using HostTensorD = typename cutlass::HostTensor<ElementD, LayoutD>;
|
||||
|
||||
// Reference GEMM tensors
|
||||
HostTensorA tensor_A;
|
||||
HostTensorB tensor_B;
|
||||
HostTensorC tensor_C;
|
||||
HostTensorD tensor_D;
|
||||
HostTensorD tensor_ref_D;
|
||||
|
||||
// DistGEMM tensors (multi-device)
|
||||
HostTensorA tensor_A_arr[TP_];
|
||||
HostTensorB tensor_B_arr[TP_];
|
||||
HostTensorD tensor_C_arr[TP_];
|
||||
HostTensorD tensor_D_arr[TP_];
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 100;
|
||||
int warmup_iterations = 10;
|
||||
int m = 16384, n = 106496, k = 16384, l = 1;
|
||||
float eps = 0.f;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("warmup-iterations", warmup_iterations);
|
||||
cmd.get_cmd_line_argument("eps", eps);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "82_blackwell_distributed_gemm\n\n"
|
||||
<< " Blackwell Distributed GEMM (DistGEMM). \n"
|
||||
<< " For more details please refer to the source file.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch) of the GEMM (default: 1)\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha (default: 1.0)\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta (default: 0.0)\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform (default: 100)\n"
|
||||
<< " --warmup-iterations=<int> Number of warmup iterations prior to profiling (default: 10)\n"
|
||||
<< " --eps=<f32> Threshold for error compared to reference "
|
||||
<< "GEMM (default: 0.0)\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "82_blackwell_distributed_gemm" << " --m=16384 --n=106496 --k=16384 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in TFLOP/s
|
||||
double tflops(double runtime_s) const {
|
||||
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l / TP_;
|
||||
double tflop = double(flop) / double(1.0e12);
|
||||
return tflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
double avg_runtime_ms;
|
||||
double tflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double tflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), tflops(tflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed,
|
||||
bool is_device_tensor = false) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits <= 16) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
if (is_device_tensor) {
|
||||
using Real = typename cutlass::RealType<Element>::Type;
|
||||
cutlass::reference::device::TensorFillRandomUniform(
|
||||
view, seed, static_cast<Real>(scope_max), static_cast<Real>(scope_min), 0);
|
||||
cudaDeviceSynchronize();
|
||||
} else {
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
|
||||
|
||||
// Setup (reference) GEMM tensors
|
||||
auto shape_A = cute::select<0,2,3>(problem_shape);
|
||||
auto shape_B = cute::select<1,2,3>(problem_shape);
|
||||
auto shape_C = cute::select<0,1,3>(problem_shape);
|
||||
auto shape_D = cute::select<0,1,3>(problem_shape);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A);
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
|
||||
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
|
||||
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.device_view(), seed + 2022, /* is_device_tensor = */ true);
|
||||
initialize_tensor(tensor_B.device_view(), seed + 2023, /* is_device_tensor = */ true);
|
||||
initialize_tensor(tensor_C.device_view(), seed + 2024, /* is_device_tensor = */ true);
|
||||
|
||||
tensor_A.sync_host();
|
||||
tensor_B.sync_host();
|
||||
tensor_C.sync_host();
|
||||
tensor_D.sync_host();
|
||||
tensor_ref_D.sync_host();
|
||||
|
||||
// Set up DistGEMM tensors
|
||||
auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape);
|
||||
auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape);
|
||||
auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape);
|
||||
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
|
||||
|
||||
auto a_coord_device = cutlass::make_Coord(size(local_shape_A), 1);
|
||||
auto b_coord_device = cutlass::make_Coord(size(local_shape_B), 1);
|
||||
auto c_coord_device = cutlass::make_Coord(size(local_shape_C), 1);
|
||||
|
||||
int primary_device_idx;
|
||||
CUDA_CHECK(cudaGetDevice(&primary_device_idx));
|
||||
|
||||
// Enable any-to-any access
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
int can_access;
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
for (int peer_idx = 0; peer_idx < TP_; ++peer_idx) {
|
||||
if (peer_idx != device_idx) {
|
||||
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device_idx, peer_idx));
|
||||
if (not can_access) {
|
||||
std::cerr << "FAILURE: Device " << device_idx << " can't access device " << peer_idx << "." <<
|
||||
std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
CUDA_CHECK(cudaDeviceEnablePeerAccess(peer_idx, 0));
|
||||
}
|
||||
}
|
||||
|
||||
tensor_A_arr[device_idx].resize(a_coord_device);
|
||||
tensor_B_arr[device_idx].resize(b_coord_device);
|
||||
tensor_C_arr[device_idx].resize(c_coord_device);
|
||||
tensor_D_arr[device_idx].resize(c_coord_device);
|
||||
}
|
||||
CUDA_CHECK(cudaSetDevice(primary_device_idx));
|
||||
}
|
||||
|
||||
/// Commandline options -> Gemm/DistGemm Arguments
|
||||
using GemmArguments = typename Gemm::Arguments;
|
||||
GemmArguments gemm_args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{static_cast<ElementCompute>(options.alpha), static_cast<ElementCompute>(options.beta)},
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_ref_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
using DistGemmArguments = typename DistGemm::Arguments;
|
||||
DistGemmArguments dist_gemm_args_from_options(
|
||||
const Options &options,
|
||||
int device_idx,
|
||||
cudaStream_t stream) {
|
||||
|
||||
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
|
||||
|
||||
auto global_A = cute::make_tensor(tensor_A.device_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto global_B = cute::make_tensor(tensor_B.device_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto global_C = cute::make_tensor(tensor_C.device_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
|
||||
auto global_A_device_slice = DistSchedule::get_device_slice_A(global_A, device_idx);
|
||||
auto global_B_device_slice = DistSchedule::get_device_slice_B(global_B, device_idx);
|
||||
auto global_C_device_slice = DistSchedule::get_device_slice_C(global_C, device_idx);
|
||||
|
||||
auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape);
|
||||
auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape);
|
||||
auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape);
|
||||
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
|
||||
|
||||
auto local_stride_A = cutlass::make_cute_packed_stride(StrideA{}, local_shape_A);
|
||||
auto local_stride_B = cutlass::make_cute_packed_stride(StrideB{}, local_shape_B);
|
||||
auto local_stride_C = cutlass::make_cute_packed_stride(StrideC{}, local_shape_C);
|
||||
auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D);
|
||||
|
||||
auto local_A = cute::make_tensor(
|
||||
tensor_A_arr[device_idx].device_data(),
|
||||
make_layout(local_shape_A, local_stride_A));
|
||||
auto local_B = cute::make_tensor(
|
||||
tensor_B_arr[device_idx].device_data(),
|
||||
make_layout(local_shape_B, local_stride_B));
|
||||
auto local_C = cute::make_tensor(
|
||||
tensor_C_arr[device_idx].device_data(),
|
||||
make_layout(local_shape_C, local_stride_C));
|
||||
auto local_D = cute::make_tensor(
|
||||
tensor_D_arr[device_idx].device_data(),
|
||||
make_layout(local_shape_D, local_stride_D));
|
||||
|
||||
// Copy over tensor tiles for the first iteration
|
||||
cutlass::device_copy(global_A_device_slice, local_A, stream);
|
||||
cutlass::device_copy(global_B_device_slice, local_B, stream);
|
||||
cutlass::device_copy(global_C_device_slice, local_C, stream);
|
||||
|
||||
DistGemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, // mode
|
||||
problem_shape, // problem shape
|
||||
{
|
||||
reinterpret_cast<const ElementA*>(local_A.data()),
|
||||
local_A.stride(),
|
||||
reinterpret_cast<const ElementB*>(local_B.data()),
|
||||
local_B.stride()
|
||||
}, // mainloop
|
||||
{
|
||||
{ // epilogue.thread
|
||||
static_cast<ElementCompute>(options.alpha),
|
||||
static_cast<ElementCompute>(options.beta)
|
||||
},
|
||||
reinterpret_cast<const ElementC*>(local_C.data()),
|
||||
local_C.stride(),
|
||||
reinterpret_cast<ElementD*>(local_D.data()),
|
||||
local_D.stride(),
|
||||
}, // epilogue
|
||||
{}, // hw_info
|
||||
{} // scheduler
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
// Gathers results, moves back to the original full-sized D tensor on the primary device.
|
||||
void gather_results(const Options &options, int device_idx, cudaStream_t stream = nullptr) {
|
||||
|
||||
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
|
||||
|
||||
// Global dest
|
||||
auto global_D = cute::make_tensor(tensor_D.device_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto global_D_device_slice = DistSchedule::get_device_slice_D(global_D, device_idx);
|
||||
|
||||
// Device_idx local dest
|
||||
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
|
||||
auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D);
|
||||
auto local_D = cute::make_tensor(
|
||||
tensor_D_arr[device_idx].device_data(),
|
||||
make_layout(local_shape_D, local_stride_D)
|
||||
);
|
||||
|
||||
// Copy to global dest
|
||||
cutlass::device_copy(local_D, global_D_device_slice, stream);
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
tensor_D.sync_host();
|
||||
tensor_ref_D.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
if (options.eps == 0.f) {
|
||||
passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
} else {
|
||||
double err = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_D.host_view(),
|
||||
tensor_ref_D.host_view());
|
||||
passed = err < 1e-5;
|
||||
}
|
||||
|
||||
if (options.m <= 64 && options.n <= 64) {
|
||||
std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n";
|
||||
std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
int run(Options &options) {
|
||||
|
||||
int primary_device_idx;
|
||||
cudaError_t device_get_result = cudaGetDevice(&primary_device_idx);
|
||||
if (device_get_result != cudaSuccess) {
|
||||
throw std::runtime_error("cudaGetDevice() failed");
|
||||
}
|
||||
|
||||
initialize(options);
|
||||
|
||||
// Reference single-GPU GEMM
|
||||
Gemm reference_gemm;
|
||||
cutlass::device_memory::allocation<uint8_t> reference_workspace;
|
||||
|
||||
auto reference_arguments = gemm_args_from_options(options);
|
||||
size_t reference_workspace_size = Gemm::get_workspace_size(reference_arguments);
|
||||
reference_workspace = cutlass::device_memory::allocation<uint8_t>(reference_workspace_size);
|
||||
|
||||
CUTLASS_CHECK(reference_gemm.can_implement(reference_arguments));
|
||||
CUTLASS_CHECK(reference_gemm.initialize(reference_arguments, reference_workspace.get()));
|
||||
CUTLASS_CHECK(reference_gemm.run());
|
||||
|
||||
using ElementBarrier = typename DistGemm::ElementBarrier;
|
||||
using ElementFlag = typename DistGemmKernel::ElementFlag;
|
||||
|
||||
// Set up per-device streams
|
||||
cudaStream_t stream_arr[TP_];
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
|
||||
// Create stream
|
||||
CUDA_CHECK(cudaStreamCreate(&stream_arr[device_idx]));
|
||||
}
|
||||
|
||||
// Instantiate DistGEMM
|
||||
DistGemm dist_gemm_arr[TP_]; // Distributed GEMM array for multiple devices
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_arr[TP_];
|
||||
cutlass::device_memory::allocation<uint8_t> exclusive_workspace_arr[TP_];
|
||||
|
||||
// Cross-device workspace pointer array for gemm.initialize()
|
||||
void * workspace_ptr_arr[TP_];
|
||||
void * exclusive_workspace_ptr_arr[TP_];
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
DistGemmArguments arguments_[TP_];
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
|
||||
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
|
||||
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
|
||||
|
||||
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
|
||||
exclusive_workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(exclusive_workspace_size);
|
||||
|
||||
// Throw workspace pointers into arrays for gemm.initialize()
|
||||
workspace_ptr_arr[device_idx] = workspace_arr[device_idx].get();
|
||||
exclusive_workspace_ptr_arr[device_idx] = exclusive_workspace_arr[device_idx].get();
|
||||
|
||||
// Zero out exclusive workspace
|
||||
cudaMemsetAsync(exclusive_workspace_ptr_arr[device_idx], 0, exclusive_workspace_size, stream_arr[device_idx]);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(dist_gemm_arr[device_idx].can_implement(arguments_[device_idx]));
|
||||
|
||||
#if defined(CUTLASS_ENABLE_GDC_FOR_SM100)
|
||||
bool launch_with_pdl = true;
|
||||
#else
|
||||
bool launch_with_pdl = false;
|
||||
#endif
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(dist_gemm_arr[device_idx].initialize(
|
||||
arguments_,
|
||||
workspace_ptr_arr,
|
||||
exclusive_workspace_ptr_arr,
|
||||
device_idx,
|
||||
stream_arr[device_idx],
|
||||
launch_with_pdl
|
||||
));
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
std::cout << std::endl << " running DistGEMM..." << std::endl;
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
|
||||
}
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx]));
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
gather_results(options, device_idx);
|
||||
}
|
||||
|
||||
std::cout << " running DistGEMM finished without runtime errors" << std::endl;
|
||||
|
||||
//// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << std::endl << " Disposition (eps: " << options.eps << "): " <<
|
||||
(result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
float elapsed_ms = 0.f;
|
||||
|
||||
// Warmup
|
||||
std::cout << " Warming up for " << options.warmup_iterations << " iterations." << std::endl;
|
||||
for (int warmup_iter = 0; warmup_iter < options.warmup_iterations; ++warmup_iter) {
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
|
||||
}
|
||||
}
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx]));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(primary_device_idx));
|
||||
|
||||
// Benchmark
|
||||
std::cout << " Profiling for " << options.iterations << " iterations." << std::endl;
|
||||
using AtomicBoolean = cuda::atomic<bool>;
|
||||
AtomicBoolean* atomic_flag_ptr;
|
||||
CUDA_CHECK(cudaHostAlloc(&atomic_flag_ptr, sizeof(AtomicBoolean), cudaHostAllocPortable));
|
||||
atomic_flag_ptr->store(false);
|
||||
|
||||
cutlass::DistGpuTimer<TP_> timer;
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
cutlass::delay_kernel<<<1, 1, 0, stream_arr[device_idx]>>>(atomic_flag_ptr);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
timer.start(device_idx, stream_arr[device_idx]);
|
||||
}
|
||||
|
||||
atomic_flag_ptr->store(true);
|
||||
|
||||
for (int profile_iter = 0; profile_iter < options.iterations; ++profile_iter) {
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
|
||||
}
|
||||
}
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
CUDA_CHECK(cudaSetDevice(device_idx));
|
||||
timer.stop(device_idx, stream_arr[device_idx]);
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(primary_device_idx));
|
||||
|
||||
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
|
||||
elapsed_ms = max(elapsed_ms, timer.elapsed_millis(device_idx));
|
||||
}
|
||||
|
||||
// Compute average runtime and TFLOPs.
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
|
||||
result.tflops = options.tflops(avg_runtime_s);
|
||||
|
||||
auto [local_M, local_N, local_K, local_L] = DistSchedule::get_local_gemm_shape(
|
||||
cute::make_tuple(options.m, options.n, options.k, options.l));
|
||||
|
||||
std::cout << std::endl;
|
||||
std::cout << " TP: " << TP::value << std::endl;
|
||||
std::cout << " Problem Size: " <<
|
||||
options.m << " x " <<
|
||||
options.n << " x " <<
|
||||
options.k << " x " <<
|
||||
options.l << std::endl;
|
||||
std::cout << " Local GEMM Problem Size: " <<
|
||||
local_M << " x " <<
|
||||
local_N << " x " <<
|
||||
local_K << " x " <<
|
||||
local_L<< std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " TFLOPS: " << result.tflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
|
||||
// and must have compute capability at least 90.
|
||||
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
|
||||
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
int num_devices;
|
||||
CUDA_CHECK(cudaGetDeviceCount(&num_devices));
|
||||
if (num_devices < TP_) {
|
||||
std::cerr << "Distributed GEMM is compiled with TP = " << TP::value << ", but " <<
|
||||
"found only " << num_devices << " devices." <<
|
||||
std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability 100), "
|
||||
<< "got compute capability " << props.major * 10 + props.minor << "."
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
run(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
32
examples/82_blackwell_distributed_gemm/CMakeLists.txt
Normal file
32
examples/82_blackwell_distributed_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
82_blackwell_distributed_gemm
|
||||
82_blackwell_distributed_gemm.cu
|
||||
)
|
||||
37
examples/82_blackwell_distributed_gemm/README.md
Normal file
37
examples/82_blackwell_distributed_gemm/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Blackwell Distributed GEMM
|
||||
|
||||
This example implements Tensor Parallel GEMMs for the Hopper architecture with the experimental
|
||||
[Distributed GEMM](../../include/cutlass/experimental/distributed) API in CUTLASS.
|
||||
|
||||
This example requires Blackwell GPUs with an any-to-any NVLink network.
|
||||
Please refer to [REQUIREMENTS.md](REQUIREMENTS.md) for more information.
|
||||
|
||||
By default, the example assumes 8 GPUs (TP=8) and runs an All Gather + GEMM operation, which rotates
|
||||
operand A. To run with a different number of GPUs or schedule, please refer to
|
||||
[82_blackwell_distributed_gemm.cu](82_blackwell_distributed_gemm.cu).
|
||||
|
||||
|
||||
## Getting started
|
||||
|
||||
Command line arguments are mostly similar to other examples:
|
||||
|
||||
```
|
||||
--m=<int> Sets the M extent of the GEMM
|
||||
--n=<int> Sets the N extent of the GEMM
|
||||
--k=<int> Sets the K extent of the GEMM
|
||||
--l=<int> Sets the L extent (batch) of the GEMM (default: 1)
|
||||
--alpha=<f32> Epilogue scalar alpha (default: 1.0)
|
||||
--beta=<f32> Epilogue scalar beta (default: 0.0)
|
||||
--iterations=<int> Number of profiling iterations to perform (default: 100)
|
||||
--warmup-iterations=<int> Number of warmup iterations prior to profiling (default: 10)
|
||||
--eps=<f32> Threshold for error compared to reference GEMM (default: 0.0)
|
||||
```
|
||||
|
||||
Sample run command:
|
||||
|
||||
```bash
|
||||
./82_blackwell_distributed_gemm --m=16384 --n=106496 --k=16384 --warmup-iterations=10 --iterations=100
|
||||
```
|
||||
|
||||
This example follows the [Hopper example](../65_distributed_gemm/) very closely, and only differs in the base GEMM kernel. For
|
||||
more information you can refer to [that example](../65_distributed_gemm/README.md).
|
||||
86
examples/82_blackwell_distributed_gemm/REQUIREMENTS.md
Normal file
86
examples/82_blackwell_distributed_gemm/REQUIREMENTS.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Blackwell Distributed GEMM
|
||||
|
||||
## Requirements
|
||||
|
||||
### Build
|
||||
Make sure to set up CUTLASS with
|
||||
support for [Programmatic Dependent Launch (PDL)](../../media/docs/dependent_kernel_launch.md),
|
||||
that is with the `CUTLASS_ENABLE_GDC_FOR_SM100` flag.
|
||||
|
||||
```bash
|
||||
cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
|
||||
```
|
||||
|
||||
### Minimum software
|
||||
|
||||
Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
|
||||
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
|
||||
CUDA graph APIs.
|
||||
|
||||
### Hardware / driver settings
|
||||
|
||||
This example requires Blackwell GPUs with NVLink network.
|
||||
|
||||
If you're not sure, first run the following command and make sure your GPU
|
||||
compute capability is 10.0:
|
||||
|
||||
```bash
|
||||
nvidia-smi --query-gpu=name,compute_cap --format=csv
|
||||
```
|
||||
|
||||
Sample output:
|
||||
|
||||
```
|
||||
name, compute_cap
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
NVIDIA B200, 10.0
|
||||
```
|
||||
|
||||
|
||||
Then you should make sure there is an NVLink network by checking the GPU network topology,
|
||||
and making sure there's `NV*` links between every pair of GPUs:
|
||||
|
||||
```bash
|
||||
nvidia-smi topo -m
|
||||
```
|
||||
|
||||
Sample output:
|
||||
|
||||
```
|
||||
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
|
||||
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18
|
||||
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18
|
||||
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18
|
||||
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18
|
||||
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18
|
||||
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18
|
||||
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18
|
||||
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X
|
||||
```
|
||||
|
||||
Finally, check if the driver enables peer to peer access, which should usually be the case,
|
||||
but it's good to check anyway:
|
||||
|
||||
```bash
|
||||
nvidia-smi topo -p2p r
|
||||
```
|
||||
|
||||
Sample output:
|
||||
|
||||
```
|
||||
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
|
||||
GPU0 X OK OK OK OK OK OK OK
|
||||
GPU1 OK X OK OK OK OK OK OK
|
||||
GPU2 OK OK X OK OK OK OK OK
|
||||
GPU3 OK OK OK X OK OK OK OK
|
||||
GPU4 OK OK OK OK X OK OK OK
|
||||
GPU5 OK OK OK OK OK X OK OK
|
||||
GPU6 OK OK OK OK OK OK X OK
|
||||
GPU7 OK OK OK OK OK OK OK X
|
||||
```
|
||||
607
examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu
Normal file
607
examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu
Normal file
@@ -0,0 +1,607 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A FP16 sparse GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features:
|
||||
|
||||
1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
|
||||
which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
|
||||
|
||||
Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
|
||||
Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
|
||||
Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm --m=8192 --n=8192 --k=8192
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 2 * 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k
|
||||
|
||||
// E matrix config
|
||||
using ElementE = cute::uint8_t;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = float; // Element type for D matrix operand
|
||||
using ElementC = float; // Element type for C matrix operand
|
||||
using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand
|
||||
using LayoutTagD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassSparseTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_1,_1>;
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementD, LayoutTagD, AlignmentD,
|
||||
cutlass::epilogue::TmaWarpSpecialized2Sm
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutTagA, AlignmentA,
|
||||
ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
|
||||
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100
|
||||
>::CollectiveOp;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
|
||||
// Compose into a kernel
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
ElementB,
|
||||
LayoutTagB,
|
||||
ElementC,
|
||||
LayoutTagC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
// Layouts
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Compressor
|
||||
//
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig,
|
||||
ArchTag>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
LayoutA layout_A;
|
||||
LayoutE layout_E;
|
||||
StrideA stride_A;
|
||||
StrideA stride_A_compressed;
|
||||
StrideE stride_E;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
ProblemShape problem_shape;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A_compressed;
|
||||
cutlass::DeviceAllocation<typename Gemm::CollectiveMainloop::ElementE> block_E;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k, l;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(8192), n(8192), k(8192), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "83_blackwell_sparse_gemm\n\n"
|
||||
<< " Blackwell FP16 Sparse GEMM example.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "83_blackwell_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
}
|
||||
else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
bool sparsify_and_compress()
|
||||
{
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
|
||||
// TensorE
|
||||
// In unit of ElementE (uint8_t), after alignment requirement
|
||||
// M-dim: TensorEAtom_M alignment
|
||||
// K-dim: TensorEAtom_K alignment
|
||||
int KAlignedE = compressor_utility.get_metadata_k_physical();
|
||||
int MAlignedE = compressor_utility.get_metadata_m_physical();
|
||||
|
||||
// TensorA Compressed
|
||||
// In unit of ElementARaw, after alignment requirement
|
||||
// M-dim: TMA alignment
|
||||
// K-dim: TMA alignment
|
||||
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
|
||||
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
|
||||
|
||||
block_A_compressed.reset(M * KAlignedAC * L);
|
||||
block_E.reset(MAlignedE * KAlignedE * L);
|
||||
|
||||
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L));
|
||||
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L));
|
||||
|
||||
// Random 50% fill zero is performed on host
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
|
||||
compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast<int>(seed + 2024));
|
||||
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments {
|
||||
problem_shape,
|
||||
{ block_A.get(),
|
||||
stride_A,
|
||||
block_A_compressed.get(),
|
||||
block_E.get() },
|
||||
{hw_info} };
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status {cutlass::Status::kSuccess };
|
||||
status = compressor_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
|
||||
// Compress row A and get A_compress and E
|
||||
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
|
||||
if (not sparsify_and_compress()) {
|
||||
return false;
|
||||
};
|
||||
|
||||
// Build the compressed/metadata layouts
|
||||
layout_A = SparseConfig::fill_layoutA(problem_shape);
|
||||
layout_E = SparseConfig::fill_layoutE(problem_shape);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_shape,
|
||||
{ block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E },
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
auto init_pass = initialize(options);
|
||||
if (not init_pass) {
|
||||
std::cout << "Initialization failure" << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (not result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (not (props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
38
examples/83_blackwell_sparse_gemm/CMakeLists.txt
Normal file
38
examples/83_blackwell_sparse_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,38 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
83_blackwell_sparse_gemm
|
||||
83_blackwell_sparse_gemm.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
@@ -0,0 +1,693 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 Sparse GEMM on the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
|
||||
on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
|
||||
and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 83_blackwell_sparse_gemm, this kernel leverages:
|
||||
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementAPair = cutlass::nv_float4_t<ElementA>; // Element type for A matrix operand
|
||||
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k
|
||||
|
||||
// E matrix config
|
||||
using ElementE = cute::uint8_t;
|
||||
using LayoutTagE = LayoutTagA;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementBPair = cutlass::nv_float4_t<ElementB>; // Element type for B matrix operand
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// SF
|
||||
using ElementSF = typename ElementAPair::ScaleFactorType;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape = Shape<_256,_128,_256>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape = Shape<_2,_1,_1>;
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementD, LayoutTagD, AlignmentD,
|
||||
cutlass::epilogue::TmaWarpSpecialized2SmNvf4
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementAPair, LayoutTagA, AlignmentA,
|
||||
ElementBPair, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
|
||||
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100
|
||||
>::CollectiveOp;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
|
||||
// Compose into a kernel
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
//
|
||||
// Blockscale
|
||||
//
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;
|
||||
using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF;
|
||||
using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom;
|
||||
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Compressor
|
||||
//
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig,
|
||||
ArchTag>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideA stride_A_compressed;
|
||||
StrideE stride_E;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
|
||||
LayoutA layout_A;
|
||||
LayoutE layout_E;
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
|
||||
typename LayoutTagA::Stride stride_factor_A;
|
||||
typename LayoutTagB::Stride stride_factor_B;
|
||||
typename LayoutTagE::Stride stride_factor_E;
|
||||
typename LayoutTagC::Stride stride_factor_C;
|
||||
typename LayoutTagD::Stride stride_factor_D;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
ProblemShape problem_shape;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A;
|
||||
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A_compressed;
|
||||
cutlass::HostTensor<ElementE, LayoutTagE> tensor_E;
|
||||
cutlass::HostTensor<ElementB, LayoutTagB> tensor_B;
|
||||
cutlass::HostTensor<ElementC, LayoutTagC> tensor_C;
|
||||
cutlass::HostTensor<ElementSF, LayoutTagA> tensor_SFA;
|
||||
cutlass::HostTensor<ElementSF, LayoutTagB> tensor_SFB;
|
||||
cutlass::HostTensor<ElementD, LayoutTagD> tensor_D;
|
||||
cutlass::HostTensor<ElementD, LayoutTagD> reference_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k, l;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "84a_blackwell_nvfp4_bf16_sparse_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
void initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>){
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(const Options &options) {
|
||||
|
||||
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
|
||||
|
||||
// * Get A B C D size
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
layout_A = SparseConfig::fill_layoutA(problem_shape);
|
||||
layout_E = SparseConfig::fill_layoutE(problem_shape);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape);
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape);
|
||||
|
||||
// * Get ACompress & E size
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
|
||||
// TensorE
|
||||
// In unit of ElementE (uint8_t), after alignment requirement
|
||||
// M-dim: TensorEAtom_M alignment
|
||||
// K-dim: TensorEAtom_K alignment
|
||||
int KAlignedE = compressor_utility.get_metadata_k_physical();
|
||||
int MAlignedE = compressor_utility.get_metadata_m_physical();
|
||||
|
||||
// TensorA Compressed
|
||||
// In unit of ElementARaw, after alignment requirement
|
||||
// M-dim: TMA alignment
|
||||
// K-dim: TMA alignment
|
||||
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
|
||||
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
|
||||
|
||||
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l));
|
||||
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l));
|
||||
|
||||
// * Get SFA & SFB size
|
||||
auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{})));
|
||||
auto m_blks = cutlass::ceil_div(options.m, Blk_MN{});
|
||||
auto n_blks = cutlass::ceil_div(options.n, Blk_MN{});
|
||||
|
||||
// * Allocate Tensor
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE);
|
||||
auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto d_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
|
||||
auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
|
||||
|
||||
tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
|
||||
tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_comp_coord, stride_factor_A));
|
||||
tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
|
||||
tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagE>::layout_factory(e_coord, stride_factor_E));
|
||||
tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
|
||||
tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D));
|
||||
reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D), false);
|
||||
tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(sfa_coord, stride_factor_A));
|
||||
tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(sfb_coord, stride_factor_B));
|
||||
|
||||
// * Random init
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2021);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2023);
|
||||
initialize_tensor(tensor_SFA.host_view(), seed + 2024);
|
||||
initialize_tensor(tensor_SFB.host_view(), seed + 2025);
|
||||
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
||||
|
||||
// * Random fill 50% A with zero
|
||||
compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast<int>(seed + 2023));
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_SFA.sync_device();
|
||||
tensor_SFB.sync_device();
|
||||
|
||||
// * Compress
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
problem_shape,
|
||||
{tensor_A.device_data(),
|
||||
stride_A,
|
||||
tensor_A_compressed.device_data(),
|
||||
tensor_E.device_data()},
|
||||
{hw_info}
|
||||
};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status {cutlass::Status::kSuccess };
|
||||
status = compressor_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_E.sync_host();
|
||||
tensor_A_compressed.sync_host();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{
|
||||
reinterpret_cast<ArrayElementA *>(tensor_A_compressed.device_data()), layout_A,
|
||||
reinterpret_cast<ArrayElementB *>(tensor_B.device_data()), stride_B,
|
||||
tensor_E.device_data(), layout_E,
|
||||
tensor_SFA.device_data(), layout_SFA,
|
||||
tensor_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{
|
||||
{options.alpha, options.beta},
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
// Create the arguments for host reference implementation
|
||||
auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A);
|
||||
auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA);
|
||||
auto B = make_tensor(make_iterator(tensor_B.host_data()),
|
||||
make_layout(make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(B),
|
||||
decltype(SFA),
|
||||
decltype(SFB)> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
auto C = make_tensor(make_iterator(tensor_C.host_data()),
|
||||
make_layout(make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = make_tensor(make_iterator(reference_D.host_data()),
|
||||
make_layout(make_shape(options.m, options.n, options.l), stride_D));
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(C), // TensorC
|
||||
decltype(D) // TensorD
|
||||
> epilogue_params{
|
||||
options.alpha,
|
||||
options.beta,
|
||||
C,
|
||||
D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
auto init_pass = initialize(options);
|
||||
if (not init_pass) {
|
||||
std::cout << "Initialization failure" << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (not result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (not (props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,695 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled MXFP8 Sparse GEMM on the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
|
||||
on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
|
||||
and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 83_blackwell_sparse_gemm, this kernel leverages:
|
||||
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementAPair = cutlass::mx_float8_t<ElementA>; // Element type for A matrix operand
|
||||
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k
|
||||
|
||||
// E matrix config
|
||||
using ElementE = cute::uint8_t;
|
||||
using LayoutTagE = LayoutTagA;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementBPair = cutlass::mx_float4_t<ElementB>; // Element type for B matrix operand
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// SF
|
||||
using ElementSF = typename ElementAPair::ScaleFactorType;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_256>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_1,_1>;
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementD, LayoutTagD, AlignmentD,
|
||||
cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementAPair, LayoutTagA, AlignmentA,
|
||||
ElementBPair, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
|
||||
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100
|
||||
>::CollectiveOp;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
|
||||
// Compose into a kernel
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
//
|
||||
// Blockscale
|
||||
//
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;
|
||||
using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF;
|
||||
using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom;
|
||||
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Compressor
|
||||
//
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig,
|
||||
ArchTag>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideA stride_A_compressed;
|
||||
StrideE stride_E;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
|
||||
LayoutA layout_A;
|
||||
LayoutE layout_E;
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
|
||||
typename LayoutTagA::Stride stride_factor_A;
|
||||
typename LayoutTagB::Stride stride_factor_B;
|
||||
typename LayoutTagE::Stride stride_factor_E;
|
||||
typename LayoutTagC::Stride stride_factor_C;
|
||||
typename LayoutTagD::Stride stride_factor_D;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
ProblemShape problem_shape;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A;
|
||||
cutlass::HostTensor<ElementA, LayoutTagA> tensor_A_compressed;
|
||||
cutlass::HostTensor<ElementE, LayoutTagE> tensor_E;
|
||||
cutlass::HostTensor<ElementB, LayoutTagB> tensor_B;
|
||||
cutlass::HostTensor<ElementC, LayoutTagC> tensor_C;
|
||||
cutlass::HostTensor<ElementSF, LayoutTagA> tensor_SFA;
|
||||
cutlass::HostTensor<ElementSF, LayoutTagB> tensor_SFB;
|
||||
cutlass::HostTensor<ElementD, LayoutTagD> tensor_D;
|
||||
cutlass::HostTensor<ElementD, LayoutTagD> reference_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k, l;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "84b_blackwell_mixed_mxfp8_bf16_sparse_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
void initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>){
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(const Options &options) {
|
||||
|
||||
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
|
||||
|
||||
// * Get A B C D size
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
layout_A = SparseConfig::fill_layoutA(problem_shape);
|
||||
layout_E = SparseConfig::fill_layoutE(problem_shape);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape);
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape);
|
||||
|
||||
// * Get ACompress & E size
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
|
||||
// TensorE
|
||||
// In unit of ElementE (uint8_t), after alignment requirement
|
||||
// M-dim: TensorEAtom_M alignment
|
||||
// K-dim: TensorEAtom_K alignment
|
||||
int KAlignedE = compressor_utility.get_metadata_k_physical();
|
||||
int MAlignedE = compressor_utility.get_metadata_m_physical();
|
||||
|
||||
// TensorA Compressed
|
||||
// In unit of ElementARaw, after alignment requirement
|
||||
// M-dim: TMA alignment
|
||||
// K-dim: TMA alignment
|
||||
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
|
||||
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
|
||||
|
||||
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l));
|
||||
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l));
|
||||
|
||||
// * Get SFA & SFB size
|
||||
auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{})));
|
||||
auto m_blks = cutlass::ceil_div(options.m, Blk_MN{});
|
||||
auto n_blks = cutlass::ceil_div(options.n, Blk_MN{});
|
||||
|
||||
// * Allocate Tensor
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE);
|
||||
auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto d_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
|
||||
auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{});
|
||||
|
||||
tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
|
||||
tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_comp_coord, stride_factor_A));
|
||||
tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
|
||||
tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagE>::layout_factory(e_coord, stride_factor_E));
|
||||
tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
|
||||
tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D));
|
||||
reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(d_coord, stride_factor_D), false);
|
||||
tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(sfa_coord, stride_factor_A));
|
||||
tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(sfb_coord, stride_factor_B));
|
||||
|
||||
// * Random init
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2021);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2023);
|
||||
initialize_tensor(tensor_SFA.host_view(), seed + 2024);
|
||||
initialize_tensor(tensor_SFB.host_view(), seed + 2025);
|
||||
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
||||
|
||||
// * Random fill 50% A with zero
|
||||
compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast<int>(seed + 2023));
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_SFA.sync_device();
|
||||
tensor_SFB.sync_device();
|
||||
|
||||
// * Compress
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
problem_shape,
|
||||
{tensor_A.device_data(),
|
||||
stride_A,
|
||||
tensor_A_compressed.device_data(),
|
||||
tensor_E.device_data()},
|
||||
{hw_info}
|
||||
};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status {cutlass::Status::kSuccess };
|
||||
status = compressor_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
status = compressor_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_E.sync_host();
|
||||
tensor_A_compressed.sync_host();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{
|
||||
reinterpret_cast<ArrayElementA *>(tensor_A_compressed.device_data()), layout_A,
|
||||
reinterpret_cast<ArrayElementB *>(tensor_B.device_data()), stride_B,
|
||||
tensor_E.device_data(), layout_E,
|
||||
tensor_SFA.device_data(), layout_SFA,
|
||||
tensor_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{
|
||||
{options.alpha, options.beta},
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
// Create the arguments for host reference implementation
|
||||
auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A);
|
||||
auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA);
|
||||
auto B = make_tensor(make_iterator(tensor_B.host_data()),
|
||||
make_layout(make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(B),
|
||||
decltype(SFA),
|
||||
decltype(SFB)> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
auto C = make_tensor(make_iterator(tensor_C.host_data()),
|
||||
make_layout(make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = make_tensor(make_iterator(reference_D.host_data()),
|
||||
make_layout(make_shape(options.m, options.n, options.l), stride_D));
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementScalingFactor
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(C), // TensorC
|
||||
decltype(D) // TensorD
|
||||
> epilogue_params{};
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
auto init_pass = initialize(options);
|
||||
if (not init_pass) {
|
||||
std::cout << "Initialization failure" << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (not result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (not (props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,41 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
84a_blackwell_nvfp4_bf16_sparse_gemm
|
||||
84a_blackwell_nvfp4_bf16_sparse_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
84b_blackwell_mixed_mxfp8_bf16_sparse_gemm
|
||||
84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu
|
||||
)
|
||||
endif()
|
||||
@@ -158,7 +158,11 @@ foreach(EXAMPLE
|
||||
77_blackwell_fmha
|
||||
78_blackwell_emulated_bf16x9_gemm
|
||||
79_blackwell_geforce_gemm
|
||||
80_blackwell_geforce_sparse_gemm
|
||||
81_blackwell_gemm_blockwise
|
||||
82_blackwell_distributed_gemm
|
||||
83_blackwell_sparse_gemm
|
||||
84_blackwell_narrow_precision_sparse_gemm
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
@@ -286,6 +286,18 @@
|
||||
|
||||
Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores
|
||||
|
||||
* [80_blackwell_geforce_sparse_gemm](80_blackwell_geforce_sparse_gemm/)
|
||||
|
||||
Blackwell SM120 sparse MMA kernel targeting GeForce RTX 50 series CUDA Cores
|
||||
|
||||
* [83_blackwell_sparse_gemm](83_blackwell_sparse_gemm)
|
||||
|
||||
Blackwell SM100 Sparse Gemm kernel
|
||||
|
||||
* [84_blackwell_narrow_precision_sparse_gemm](84_blackwell_narrow_precision_sparse_gemm)
|
||||
|
||||
Blackwell Block Scaled SM100 Sparse Gemm kernel
|
||||
|
||||
# CuTe - Programming Examples
|
||||
|
||||
Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).
|
||||
|
||||
@@ -44,6 +44,11 @@
|
||||
#include <cuda/atomic>
|
||||
#include <cuda/std/atomic>
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@@ -115,4 +120,46 @@ struct DistGpuTimer {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Generic device-to-device data movement kernel based for CuTe tensors.
|
||||
///
|
||||
/// NOTE: this kernel assigns one element copy to every thread, and is by no means
|
||||
/// an efficient way of copying tensors. It should only be used for convenience in
|
||||
/// reference checks.
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream);
|
||||
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
__global__ void device_copy_kernel(TensorSource const tensor_source,
|
||||
TensorDestination tensor_destination) {
|
||||
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
using ElementSrc = typename TensorSource::value_type;
|
||||
using ElementDst = typename TensorDestination::value_type;
|
||||
NumericConverter<ElementDst, ElementSrc> converter;
|
||||
if (linear_idx < size(tensor_source)) {
|
||||
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream) {
|
||||
|
||||
assert(tensor_source.size() == tensor_destination.size());
|
||||
|
||||
auto numel = tensor_source.size();
|
||||
static constexpr int NumThreads = 128;
|
||||
auto grid_size = cute::ceil_div(numel, NumThreads);
|
||||
|
||||
dim3 grid(grid_size);
|
||||
dim3 block(NumThreads);
|
||||
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
|
||||
}
|
||||
|
||||
} //namespace cutlass
|
||||
@@ -340,7 +340,7 @@ public:
|
||||
base_args.epilogue.thread,
|
||||
reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
|
||||
tensor_c_iter.stride(),
|
||||
reinterpret_cast<const ElementD*>(tensor_d_iter.data()),
|
||||
reinterpret_cast<ElementD*>(tensor_d_iter.data()),
|
||||
tensor_d_iter.stride()
|
||||
};
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ struct DistributedGemmKernelWrapper<
|
||||
using BaseArguments = typename BaseKernel::Arguments;
|
||||
using BaseParams = typename BaseKernel::Params;
|
||||
|
||||
static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now.");
|
||||
//static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now.");
|
||||
static_assert(not cute::is_same_v<typename BaseKernel::ElementC, void>, "DistributedGEMM epilogues must have a source.");
|
||||
|
||||
using ElementFlag = uint32_t;
|
||||
|
||||
@@ -189,6 +189,100 @@ template<
|
||||
bool Is2sm = false
|
||||
>
|
||||
constexpr bool sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(){
|
||||
// * 1SM Dense
|
||||
// * A_K(t) : TileShape_K % 128 == 0
|
||||
// * A_M(n) : TileShape_M % 128 == 0
|
||||
// * B_N(t) : TileSize_N % 128 == 0
|
||||
// * B_K(n) : TileSize_K % 128 == 0
|
||||
//
|
||||
// * 2SM Dense
|
||||
// * A_K(t) : TileShape_K % 128 == 0
|
||||
// * A_M(n) : TileShape_M % 128 == 0
|
||||
// * B_N(t) : TileSize_N % 256 == 0
|
||||
// each sm load half the data along tile_n (split vertically), each sm needs to be 128 elts aligned.
|
||||
// full tile_n needs to be 256 elts aligned
|
||||
// * B_K(n) : TileShape_K % 128 == 0
|
||||
//
|
||||
// * 1SM Sparse
|
||||
// * A_K(t) : TileShape_K % 256 == 0
|
||||
// num of physical elems needs to be 128 elts aligned
|
||||
// num of logical elems needs to be 256 elts aligned
|
||||
// * A_M(n) : TileShape_M % 128 == 0
|
||||
// * B_N(t) : TileSize_N % 128 == 0
|
||||
// * B_K(n) : TileSize_K % 128 == 0
|
||||
//
|
||||
// * 2SM Sparse
|
||||
// * A_K(t) : TileShape_K % 256 == 0
|
||||
// num of physical elems needs to be 128 elts aligned
|
||||
// num of logical elems needs to be 256 elts aligned
|
||||
// * A_M(n) : TileShape_M % 128 == 0
|
||||
// * B_N(t) : TileSize_N % 256 == 0
|
||||
// each sm load half the data along tile_n (split vertically), each sm needs to be 128 elts aligned.
|
||||
// full tile_n needs to be 256 elts aligned
|
||||
// * B_K(n) : TileShape_K % 128 == 0
|
||||
//
|
||||
// * Valid TileShape_MNK Dense
|
||||
// * Notation:
|
||||
// mma_instruction_tile_shape-cta_tile_shape
|
||||
// * s128x128x64
|
||||
// s128x128x32_128x128x128_nn YES
|
||||
// s128x128x32_128x128x128_nt YES
|
||||
// s128x128x32_128x128x128_tn YES
|
||||
// s128x128x32_128x128x128_tt YES
|
||||
// * s128x256x64
|
||||
// s128x256x32_128x256x128_nn YES
|
||||
// s128x256x32_128x256x128_nt YES
|
||||
// s128x256x32_128x256x128_tn YES
|
||||
// s128x256x32_128x256x128_tt YES
|
||||
// * s256x128x64
|
||||
// s256x128x32_256x128x128_nn YES
|
||||
// s256x128x32_256x128x128_nt NO (2SM B_N TileSize_N % 256 != 0)
|
||||
// s256x128x32_256x128x128_tn YES
|
||||
// s256x128x32_256x128x128_tt NO (2SM B_N TileSize_N % 256 != 0)
|
||||
// * s256x256x64
|
||||
// s256x256x32_256x256x128_nn YES
|
||||
// s256x256x32_256x256x128_nt YES
|
||||
// s256x256x32_256x256x128_tn YES
|
||||
// s256x256x32_256x256x128_tt YES
|
||||
//
|
||||
// * Valid TileShape_MNK Sparse
|
||||
// * s128x128x64
|
||||
// s128x128x64_128x128x128_nn YES
|
||||
// s128x128x64_128x128x128_nt YES
|
||||
// s128x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0)
|
||||
// s128x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0)
|
||||
// s128x128x64_128x128x256_nn YES
|
||||
// s128x128x64_128x128x256_nt YES
|
||||
// s128x128x64_128x128x256_tn YES
|
||||
// s128x128x64_128x128x256_tt YES
|
||||
// * s128x256x64
|
||||
// s128x256x64_128x256x128_nn YES
|
||||
// s128x256x64_128x256x128_nt YES
|
||||
// s128x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0)
|
||||
// s128x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0)
|
||||
// s128x256x64_128x256x256_nn YES
|
||||
// s128x256x64_128x256x256_nt YES
|
||||
// s128x256x64_128x256x256_tn YES
|
||||
// s128x256x64_128x256x256_tt YES
|
||||
// * s256x128x64
|
||||
// s256x128x64_128x128x128_nn YES
|
||||
// s256x128x64_128x128x128_nt NO (2SM B_N TileSize_N % 256 != 0)
|
||||
// s256x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0)
|
||||
// s256x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0)
|
||||
// s256x128x64_128x128x256_nn YES
|
||||
// s256x128x64_128x128x256_nt NO (2SM B_N TileSize_N % 256 != 0)
|
||||
// s256x128x64_128x128x256_tn YES
|
||||
// s256x128x64_128x128x256_tt NO (2SM B_N TileSize_N % 256 != 0)
|
||||
// * s256x256x64
|
||||
// s256x256x64_128x256x128_nn YES
|
||||
// s256x256x64_128x256x128_nt YES
|
||||
// s256x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0)
|
||||
// s256x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0)
|
||||
// s256x256x64_128x256x256_nn YES
|
||||
// s256x256x64_128x256x256_nt YES
|
||||
// s256x256x64_128x256x256_tn YES
|
||||
// s256x256x64_128x256x256_tt YES
|
||||
|
||||
[[maybe_unused]] constexpr int TileShape_M = Is2sm ? size<0>(TileShape_MNK{}) / 2 : size<0>(TileShape_MNK{});
|
||||
[[maybe_unused]] constexpr int TileShape_N = size<1>(TileShape_MNK{});
|
||||
[[maybe_unused]] constexpr int TileShape_K = size<2>(TileShape_MNK{});
|
||||
|
||||
@@ -432,6 +432,10 @@ public:
|
||||
init_M = get<0>(problem_shape_MNK);
|
||||
init_N = get<1>(problem_shape_MNK);
|
||||
init_K = get<2>(problem_shape_MNK);
|
||||
if constexpr (SwapAB) {
|
||||
init_M = get<1>(problem_shape_MNK);
|
||||
init_N = get<0>(problem_shape_MNK);
|
||||
}
|
||||
|
||||
if constexpr (not SwapAB) {
|
||||
dA = args.dA;
|
||||
@@ -491,7 +495,7 @@ public:
|
||||
: args_setup(args.ptr_A, args.ptr_B);
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
auto scale_k = 1;
|
||||
auto scale_k = ceil_div(init_K, args.chunk_size);
|
||||
ElementScale const* ptr_S = reinterpret_cast<ElementScale const*>(args.ptr_S);
|
||||
StrideScale dS{};
|
||||
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS));
|
||||
@@ -595,7 +599,7 @@ public:
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
const int scale_mn = SwapAB ? N : M;
|
||||
const int scale_k = (K + args.chunk_size - 1) / args.chunk_size;
|
||||
const int scale_k = ceil_div(K, args.chunk_size);
|
||||
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
|
||||
implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0));
|
||||
@@ -659,14 +663,15 @@ public:
|
||||
return cute::make_tuple(gA_mkl, gB_nkl);
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
const int scale_mn = SwapAB ? N : M;
|
||||
auto scale_k = mainloop_params.scale_k;
|
||||
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l)
|
||||
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(scale_mn,scale_k,L));
|
||||
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl);
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l)
|
||||
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(scale_mn,scale_k,L));
|
||||
Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl);
|
||||
}
|
||||
@@ -1217,8 +1222,8 @@ public:
|
||||
Params const& mainloop_params,
|
||||
int32_t next_group,
|
||||
ProblemShape_MNKL problem_shape_mnkl) {
|
||||
const uint32_t M = get<0>(problem_shape_mnkl);
|
||||
const uint32_t N = get<1>(problem_shape_mnkl);
|
||||
const uint32_t M = (SwapAB? get<1>(problem_shape_mnkl) : get<0>(problem_shape_mnkl));
|
||||
const uint32_t N = (SwapAB? get<0>(problem_shape_mnkl) : get<1>(problem_shape_mnkl));
|
||||
const uint32_t K = get<2>(problem_shape_mnkl);
|
||||
|
||||
// Replace all dims for consistency
|
||||
@@ -1245,14 +1250,14 @@ public:
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
NonVoidElementScale const* ptr_S = nullptr;
|
||||
auto scale_k = 1;
|
||||
auto scale_k = ceil_div(K, mainloop_params.chunk_size);
|
||||
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]);
|
||||
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale,
|
||||
prob_shape_scale, prob_stride_scale);
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
ElementZero const* ptr_Z = nullptr;
|
||||
auto scale_k = 1;
|
||||
auto scale_k = ceil_div(K, mainloop_params.chunk_size);
|
||||
Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]);
|
||||
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero,
|
||||
prob_shape_zero, prob_stride_zero);
|
||||
|
||||
@@ -426,7 +426,7 @@ public:
|
||||
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
auto scale_k = (K + args.group_size - 1) / args.group_size;
|
||||
auto scale_k = ceil_div(K, args.group_size);
|
||||
ElementScale const* ptr_S = args.ptr_S;
|
||||
StrideScale dS = args.dS;
|
||||
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS));
|
||||
@@ -483,7 +483,7 @@ public:
|
||||
}
|
||||
else if constexpr (ModeHasScales) {
|
||||
const int scale_mn = SwapAB ? N : M;
|
||||
const int scale_k = (K + args.group_size - 1) / args.group_size;
|
||||
const int scale_k = ceil_div(K, args.group_size);
|
||||
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
|
||||
check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
|
||||
check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
|
||||
|
||||
@@ -622,6 +622,11 @@ public:
|
||||
impl_.producer_acquire(state, barrier_token);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
|
||||
impl_.producer_expect_transaction(state, transaction_bytes);
|
||||
}
|
||||
|
||||
// NOP for TMA based mainloop
|
||||
CUTLASS_DEVICE
|
||||
void producer_commit(PipelineState state, uint32_t bytes) {
|
||||
|
||||
@@ -452,6 +452,11 @@ public:
|
||||
return producer_get_barrier(state.index());
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
|
||||
producer_expect_transaction(state.index(), transaction_bytes);
|
||||
}
|
||||
|
||||
////////////////////
|
||||
// Consumer APIs
|
||||
////////////////////
|
||||
@@ -519,6 +524,14 @@ private:
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void producer_expect_transaction(uint32_t stage, uint32_t transaction_bytes) {
|
||||
detail::pipeline_check_is_producer(params_.role);
|
||||
if (params_.is_leader) {
|
||||
full_barrier_ptr_[stage].expect_transaction(transaction_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
// NOP for TMA based mainloop
|
||||
CUTLASS_DEVICE
|
||||
void producer_commit(uint32_t stage, uint32_t bytes) {
|
||||
|
||||
@@ -9,15 +9,15 @@ efficient SM100 GEMM kernels targeting these new mma instructions.
|
||||
|
||||
Blackwell SM100 has 7 new `tcgen05.mma` instructions. These instructions are 2x to 4x faster then Hopper Architecture's WGMMA instructions.
|
||||
|
||||
| Ptx Instruction | Throughput | Notes |
|
||||
|----------------------------------------------------------------------------------|----------------------------|-------|
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts |
|
||||
|tcgen05.mma.cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts |
|
||||
| Ptx Instruction | Throughput | Notes |
|
||||
|---------------------------------------------------------------------------------------|----------------------------|-------|
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts |
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts |
|
||||
|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts |
|
||||
|
||||
For more detailed information see [`tcgen05.mma` PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensorcore-5th-generation-family-instructions).
|
||||
|
||||
@@ -27,7 +27,7 @@ For more detailed information see [`tcgen05.mma` PTX documentation](https://docs
|
||||
|
||||
Instructions with `kind` modifiers `mxf8f6f4`, `mxf4`, and `nvf4mxf4` perform matrix multiplication operations with scale
|
||||
factors of the form $D = C +( A \times SFA) * (B \times SFB)$. Scale factors are applied to GEMM-K dimension such that
|
||||
every 16 or 32 elements of $A$ and $B$ matrices in K dimension have an associated scale factor. For example, an $M\times K$,
|
||||
every 16 or 32 elements of $A$ and $B$ matrices in K dimension have an associated scale factor (32 or 64 elements for sparse as sparse gemm compress 2x along k-dim). For example, an $M\times K$,
|
||||
$A$ matrix has an associated $M \times \lceil K/32 \rceil$ SFA matrix; and an $N\times K$ $B$, matrix has an associated
|
||||
$N \times \lceil K/32 \rceil$ SFB matrix. For block scaled GEMMs, an entry of output D matrix is
|
||||
$D_{ij} = C_{ij} + \sum_{k} (A_{i,k} \times SFA_{i,k/SV}) \times (B_{j,k}\times SFB_{j,k/SV})$, in index notation, we SV is the scale factor vector size (16 or 32).
|
||||
@@ -57,12 +57,12 @@ See [PTX documentation for narrow precision data types](https://docs.nvidia.com/
|
||||
Block scaled MMAs use `mx` and `nv` types which are a pair of float8_t, float6_t, float4_t with 2 of the scale factor data types with a predetermined scale factor vector size. `mx` types follow OCP specification (see [OCP Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)). The following types provided by CUTLASS can be used as inputs to collective builders to generate the block scaled kernels:
|
||||
|
||||
**Blackwell Block Scaled Narrow Precision Data Types**
|
||||
| Mx/Nv Data Type |Scale Factor Type | SF Vector Size | OCP Compliant |
|
||||
|----------------------------|------------------|----------------|---------------|
|
||||
| mx_float8_t\<Any F8type\> |float_ue8m0_t |32 | Yes |
|
||||
| mx_float6_t\<Any F6Type\> |float_ue8m0_t |32 | Yes |
|
||||
| mx_float4_t |float_ue8m0_t |32 | Yes |
|
||||
| nv_float4_t |float_ue4m3_t |16 | No |
|
||||
| Mx/Nv Data Type |Scale Factor Type | SF Vector Size (Dense) | SF Vector Size (Sparse)| OCP Compliant |
|
||||
|----------------------------|------------------|------------------------|------------------------|---------------|
|
||||
| mx_float8_t\<Any F8type\> |float_ue8m0_t |32 |64 | Yes |
|
||||
| mx_float6_t\<Any F6Type\> |float_ue8m0_t |32 |64 | Yes |
|
||||
| mx_float4_t |float_ue8m0_t |32 |64 | Yes |
|
||||
| nv_float4_t |float_ue4m3_t |16 |32 | No |
|
||||
|
||||
## Layouts, Tensor Alignment Requirements to Target `tcgen05.mma` Instructions
|
||||
|
||||
@@ -74,13 +74,18 @@ For legacy types (`tf32`, `f16`, `bf16`, `i8` and `u8`) alignment requirements f
|
||||
All four layouts (TT, NN, NT, TT) are supported for all legacy data types.
|
||||
|
||||
**Table 1: Valid Data Type, Alignment, and Layout Combinations For MMAs with Legacy Types** <a id="legacy_gemm_table" name="legacy_gemm_table"></a>
|
||||
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|
||||
|-------------------------------|------------|------------|----------------|-------------|-------------|-------------------------|-----------|
|
||||
|1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
|
||||
|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|
||||
|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|
||||
|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
|
||||
|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
|
||||
| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|
||||
|-------------------------------|----------------|------------|------------|----------------|------------------|-------------|-------------------------|---------- |
|
||||
|[1](#legacy_rows) | Dense | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
|
||||
|[2](#legacy_rows) | Dense | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu) |
|
||||
|[3](#legacy_rows) | Dense | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
|
||||
|[4](#legacy_rows) | Dense | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu) |
|
||||
|[5](#legacy_rows) | Dense | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu) |
|
||||
|[6](#legacy_rows) | Sparse | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 (N) / 8 (T) | 4 | tf32 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f32_f32_f32_f32_f32_tfmma.cu) |
|
||||
|[7](#legacy_rows) | Sparse | half_t | half_t | TN, NN, NT, TT | 8 (N) / 16 (T) | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f16_f16_f32_f16_f16_hmma.cu) |
|
||||
|[8](#legacy_rows) | Sparse | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 (N) / 16 (T) | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f16_f16_f32_f16_f16_hmma.cu) |
|
||||
|[9](#legacy_rows) | Sparse | int8_t | int8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_s8_s8_s32_s8_s8_imma.cu) |
|
||||
|[10](#legacy_rows) | Sparse | uint8_t | uint8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_s8_s8_s32_s8_s8_imma.cu) |
|
||||
|
||||
For narrow precision Mmas, not all A/B type, and A/B layout combinations are supported by every `tcgen05.mma` instructions.
|
||||
Furthermore, tensor copy instructions for subbyte types impose additional alignment requirements while loading narrow-precision
|
||||
@@ -91,203 +96,298 @@ Below tables list valid layout, and alignment values for each A and B data type
|
||||
instructions supported by CUTLASS.
|
||||
|
||||
**Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling** <a id="non_bs_gemm_table" name="non_bs_gemm_table"></a>
|
||||
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|
||||
|-------------------------------|----------|----------|----------------|-------------|-------------|-------------------------|-----------|
|
||||
|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|
||||
|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|
||||
|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|
||||
|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|
||||
|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)|
|
||||
| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|
||||
|-------------------------------|----------------|----------|----------|----------------|-------------------|-------------|-------------------------|-----------|
|
||||
|[1](#nonbs_rows_1_2_3_6) | Dense | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[2](#nonbs_rows_1_2_3_6) | Dense | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[3](#nonbs_rows_1_2_3_6) | Dense | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[4](#nonbs_rows_4_7) | Dense | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|
||||
|[5](#nonbs_rows_5_8) | Dense | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|
||||
|[6](#nonbs_rows_1_2_3_6) | Dense | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu) <br> [NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu) <br> [TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) |
|
||||
|[7](#nonbs_rows_4_7) | Dense | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) |
|
||||
|[8](#nonbs_rows_5_8) | Dense | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu) <br> [NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) |
|
||||
|[9](#nonbs_rows_9) | Dense | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)|
|
||||
|[10](#nonbs_rows_1_2_3_6) | Sparse | float4_t | float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu) |
|
||||
|[11](#nonbs_rows_1_2_3_6) | Sparse | float4_t | float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu) |
|
||||
|[12](#nonbs_rows_1_2_3_6) | Sparse | float6_t | float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu) |
|
||||
|[13](#nonbs_rows_4_7) | Sparse | float4_t | float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu) |
|
||||
|[14](#nonbs_rows_5_8) | Sparse | float8_t | float4_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu) |
|
||||
|[15](#nonbs_rows_1_2_3_6) | Sparse | float6_t | float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu) |
|
||||
|[16](#nonbs_rows_4_7) | Sparse | float6_t | float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu) |
|
||||
|[17](#nonbs_rows_5_8) | Sparse | float8_t | float6_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu) |
|
||||
|[18](#nonbs_rows_9) | Sparse | float8_t | float8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f8_f8_f32_f16_f16_qmma.cu) |
|
||||
|
||||
|
||||
**Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs** <a id="bs_gemm_table" name="bs_gemm_table"></a>
|
||||
| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test|
|
||||
|-------------------------|-------------|-------------|----------------|-------------|-------------|-------------------------|------|
|
||||
|[1](#bs_rows_1) | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)|
|
||||
|[2](#bs_rows_2) | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)|
|
||||
|[3](#bs_rows_3) | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)|
|
||||
|[4](#bs_rows_4_5_7_8_10) | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)|
|
||||
|[5](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)|
|
||||
|[6](#bs_rows_6_9_11) | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)|
|
||||
|[7](#bs_rows_4_5_7_8_10) | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)|
|
||||
|[8](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)|
|
||||
|[9](#bs_rows_6_9_11) | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)|
|
||||
|[10](#bs_rows_4_5_7_8_10)| mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)|
|
||||
|[11](#bs_rows_6_9_11) | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)|
|
||||
| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test|
|
||||
|--------------------------|----------------|-------------|-------------|----------------|-------------------|-------------|-------------------------|---------|
|
||||
|[1](#bs_rows_1) | Dense | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)|
|
||||
|[2](#bs_rows_2) | Dense | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)|
|
||||
|[3](#bs_rows_3) | Dense | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)|
|
||||
|[4](#bs_rows_4_5_7_8_10) | Dense | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)|
|
||||
|[5](#bs_rows_4_5_7_8_10) | Dense | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)|
|
||||
|[6](#bs_rows_6_9_11) | Dense | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)|
|
||||
|[7](#bs_rows_4_5_7_8_10) | Dense | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)|
|
||||
|[8](#bs_rows_4_5_7_8_10) | Dense | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)|
|
||||
|[9](#bs_rows_6_9_11) | Dense | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)|
|
||||
|[10](#bs_rows_4_5_7_8_10) | Dense | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)|
|
||||
|[11](#bs_rows_6_9_11) | Dense | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)<br>[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)|
|
||||
|[12](#bs_rows_1) | Sparse | nv_float4_t | nv_float4_t | TN | 32 (N) / 64 (T) | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_nvf4_nvf4_f32_void_f16_o_tnn.cu) |
|
||||
|[13](#bs_rows_2) | Sparse | mx_float4_t | mx_float4_t | TN | 32 (N) / 64 (T) | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf4_f32_f16_f16_o_tnn.cu) |
|
||||
|[14](#bs_rows_3) | Sparse | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf4_f32_f16_f16_q_tnt.cu) |
|
||||
|[15](#bs_rows_4_5_7_8_10) | Sparse | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu) |
|
||||
|[16](#bs_rows_4_5_7_8_10) | Sparse | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu) |
|
||||
|[17](#bs_rows_6_9_11) | Sparse | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu) |
|
||||
|[18](#bs_rows_4_5_7_8_10) | Sparse | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu) |
|
||||
|[19](#bs_rows_4_5_7_8_10) | Sparse | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu) |
|
||||
|[20](#bs_rows_6_9_11) | Sparse | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu) |
|
||||
|[21](#bs_rows_4_5_7_8_10) | Sparse | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu) |
|
||||
|[22](#bs_rows_6_9_11) | Sparse | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf8_f32_f16_f16_q_tnn.cu) |
|
||||
|
||||
## MMA tile shapes supported
|
||||
|
||||
The alignment restrictions also limit the options for Mma Tile Shapes. Tables below list the supported/valid `MmaTileShape`,
|
||||
Layout, and Dispatch Policy combinations for each row of [Table 1](#legacy_gemm_table), [Table 2](#non_bs_gemm_table), and [Table 3](#bs_gemm_table).
|
||||
|
||||
**Table 4: Valid Tile Shapes and Dispatch Policies for lagacy types (All rows of Table 1)** <a id="legacy_rows" name="legacy_rows"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|------------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
**Table 4: Valid Tile Shapes and Dispatch Policies for legacy types (All rows of Table 1)** <a id="legacy_rows" name="legacy_rows"></a>
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|--------------------|----|----|----|----|------------------------------------------|
|
||||
| Dense | 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 1SM | 128x64x(2/4*MMA-K) | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x128x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x192x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 2SM | 256x64x(2/4*MMA-K) | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x128x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x192x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6 of Table 2)** <a id="nonbs_rows_1_2_3_6" name="nonbs_rows_1_2_3_6"></a>
|
||||
**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6,10,11,12,and 15 of Table 2)** <a id="nonbs_rows_1_2_3_6" name="nonbs_rows_1_2_3_6"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
|
||||
| Dense | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 1SM | 128x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 2SM | 256x128x128 | N | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5,8 of Table 2)** <a id="nonbs_rows_5_8" name="nonbs_rows_5_8"></a>
|
||||
**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5,8,14,and 17 of Table 2)** <a id="nonbs_rows_5_8" name="nonbs_rows_5_8"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
|
||||
| Dense | 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 2SM | 256x128x128 | Y | Y | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7 of Table 2)** <a id="nonbs_rows_4_7" name="nonbs_rows_4_7"></a>
|
||||
**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7,13,and 16 of Table 2)** <a id="nonbs_rows_4_7" name="nonbs_rows_4_7"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
|
||||
| Dense | 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 1SM | 128x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 2SM | 256x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9 of Table 2)** <a id="nonbs_rows_9" name="nonbs_rows_9"></a>
|
||||
**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9,18 of Table 2)** <a id="nonbs_rows_9" name="nonbs_rows_9"></a>
|
||||
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|----------------|----|----|----|----|------------------------------------|
|
||||
| 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|------------------------------------------|
|
||||
| Dense | 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
|
||||
| Dense | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` |
|
||||
| Sparse | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
| Sparse | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` |
|
||||
|
||||
**Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 and 12 of Table 3)** <a id="bs_rows_1" name="bs_rows_1"></a>
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|----------------------------------------------|
|
||||
| Dense | 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Dense | 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Dense | 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Dense | 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| Dense | 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| Dense | 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
|
||||
**Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 of Table 3)** <a id="bs_rows_1" name="bs_rows_1"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|----------------------------------------|
|
||||
| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 and 13 of Table 3)** <a id="bs_rows_2" name="bs_rows_2"></a>
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|----------------------------------------------|
|
||||
| Dense | 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| Dense | 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| Dense | 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| Dense | 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| Dense | 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| Dense | 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` |
|
||||
|
||||
**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 of Table 3)** <a id="bs_rows_2" name="bs_rows_2"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|----------------------------------------|
|
||||
| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` |
|
||||
| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` |
|
||||
**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 and 14 of Table 3)** <a id="bs_rows_3" name="bs_rows_3"></a>
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x192x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 of Table 3)** <a id="bs_rows_3" name="bs_rows_3"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|--------------------------------------------|
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10, 15, 16, 18, 19, and 21 of Table 3)** <a id="bs_rows_4_5_7_8_10" name="bs_rows_4_5_7_8_10"></a>
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x192x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10 of Table 3)** <a id="bs_rows_4_5_7_8_10" name="bs_rows_4_5_7_8_10"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|--------------------------------------------|
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11 of Table 3)** <a id="bs_rows_6_9_11" name="bs_rows_6_9_11"></a>
|
||||
| 1/2 SM | Mma Tile Shape | TN| TT | NT | NN | Dispatch Policy |
|
||||
|--------|---------------|----|----|----|----|--------------------------------------------|
|
||||
| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11, 17, 20, and 22 of Table 3)** <a id="bs_rows_6_9_11" name="bs_rows_6_9_11"></a>
|
||||
| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|
||||
|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
|
||||
| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
|
||||
|
||||
## Epilogue config supported
|
||||
|
||||
**Table 14: Epilogue Dispatch Policy** <a id="epi_dispatch" name="epi_dispatch"></a>
|
||||
| 1/2 SM | Epilogue Dispatch Policy |
|
||||
|--------|------------------------------------------|
|
||||
| 1SM | cutlass::epilogue::TmaWarpSpecialized1Sm |
|
||||
| 1SM | cutlass::epilogue::NoSmemWarpSpecialized1Sm |
|
||||
| 2SM | cutlass::epilogue::TmaWarpSpecialized2Sm |
|
||||
| 2SM | cutlass::epilogue::NoSmemWarpSpecialized2Sm |
|
||||
| Dense / Sparse | Legacy / Narrow Precision | 1/2 SM | Epilogue Dispatch Policy |
|
||||
|----------------|-----------------------------|--------|----------------------------------------------------|
|
||||
| Dense | Legacy & Narrow Precision | 1SM | `cutlass::epilogue::TmaWarpSpecialized1Sm` |
|
||||
| Dense | Legacy & Narrow Precision | 1SM | `cutlass::epilogue::NoSmemWarpSpecialized1Sm` |
|
||||
| Dense | Legacy & Narrow Precision | 2SM | `cutlass::epilogue::TmaWarpSpecialized2Sm` |
|
||||
| Dense | Legacy & Narrow Precision | 2SM | `cutlass::epilogue::NoSmemWarpSpecialized2Sm` |
|
||||
| Sparse | Legacy | 1SM | `cutlass::epilogue::TmaWarpSpecialized1Sm` |
|
||||
| Sparse | Legacy | 1SM | `cutlass::epilogue::NoSmemWarpSpecialized1Sm` |
|
||||
| Sparse | Legacy | 2SM | `cutlass::epilogue::TmaWarpSpecialized2Sm` |
|
||||
| Sparse | Legacy | 2SM | `cutlass::epilogue::NoSmemWarpSpecialized2Sm` |
|
||||
| Sparse | Narrow Precision (nvf4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmNvf4` |
|
||||
| Sparse | Narrow Precision (nvf4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmNvf4` |
|
||||
| Sparse | Narrow Precision (mxf4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmMxf4` |
|
||||
| Sparse | Narrow Precision (mxf4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmMxf4` |
|
||||
| Sparse | Narrow Precision (mxf8f6f4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4` |
|
||||
| Sparse | Narrow Precision (mxf8f6f4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4` |
|
||||
|
||||
**Table 15: Epilogue PerSmTileShape_MNK** <a id="epi_persmtileshape" name="epi_persmtileshape"></a>
|
||||
| 1/2 SM | MMA tile Shape | PerSmTileShape_MNK |
|
||||
@@ -314,14 +414,16 @@ MMA_TileShape_K is is generally 4 * MMA-Instruction-K. It depends on the config
|
||||
### Auto Kernel Dispatch Policies
|
||||
|
||||
In addition to direct dispatch policies listed above, the user can also use auto policies for both non-block scaled narrow-precision
|
||||
GEMMs, and block scaled narrow-precision GEMMs.
|
||||
GEMMs (both sparse and dense), and block scaled narrow-precision GEMMs (only dense).
|
||||
|
||||
CUTLASS will do its best to find the most efficient kernel for given parameters, however, the preferred method for building
|
||||
these kernels is to use direct kernel dispatch policies shown in the above tables.
|
||||
|
||||
* `cutlass::gemm::collective::KernelScheduleAuto`: For a given Mma Tile Size, data type and layout combinations choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) and 1/2 SM `tcgen05.mma`.
|
||||
* `cutlass::gemm::collective::KernelScheduleAuto`: For a given Mma Tile Size, data type and layout combinations choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) and 1/2 SM `tcgen05.mma(.sp)`.
|
||||
* `KernelTmaWarpSpecialized1SmBlockScaledSm100`: Use 1 SM `tcgen05.mma` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
|
||||
* `KernelTmaWarpSpecialized2SmBlockScaledSm100`: Use 2 SM `tcgen05.mma` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
|
||||
* `KernelSparseTmaWarpSpecialized1SmBlockScaledSm100`: Use 1 SM `tcgen05.mma.sp` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
|
||||
* `KernelSparseTmaWarpSpecialized2SmBlockScaledSm100`: Use 2 SM `tcgen05.mma.sp` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
|
||||
|
||||
Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueScheduleAuto`.
|
||||
|
||||
@@ -330,16 +432,23 @@ Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueSche
|
||||
For non-blockscaled dense GEMM refer to [quick start page](quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel). An example dense GEMM can be found:
|
||||
1. [Blackwell FP16 GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/).
|
||||
|
||||
An example sparse GEMM can be found:
|
||||
1. [Blackwell FP16 Sparse GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/83_blackwell_sparse_gemm/).
|
||||
|
||||
Narrow precision and block scaled narrow precision kernels can be built using CUTLASS 3.x collective builder interface
|
||||
(as described in [CUTLASS 3.0 GEMM API](gemm_api_3x.md#cutlass-30-gemm-api)). However, special attention needs to be given to
|
||||
A and B matrix layouts, alignment requirements, and dispatch policies to obtain a functionally correct and performant kernel
|
||||
which are listed above.
|
||||
|
||||
Several examples of block scaled kernels can be found in [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory:
|
||||
Several examples of block scaled dense GEMM kernels can be found in [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory:
|
||||
1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
2. [NVF4 Gemm with block scaling and NVF4 output matrix](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
|
||||
Several examples of block scaled sparse GEMM kernels can be found in [examples/84_blackwell_narrow_precision_sparse_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm) directory:
|
||||
1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
|
||||
2. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
|
||||
|
||||
Collective builder interface expects the same arguments as any other CUTLASS 3.x kernels as described
|
||||
[here](gemm_api_3x.md#collective-builder-for-collectivemmas) with a small difference for Collective MMA builder interface.
|
||||
As in all Blackwell kernels, the `TileShape_MNK` argument expects the `MmaTileShape_MNK` which is the tile shape needed
|
||||
|
||||
@@ -28,7 +28,6 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
@@ -59,7 +58,10 @@ std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &deviceProperti
|
||||
|
||||
int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
|
||||
if (deviceMajorMinor) {
|
||||
int32_t clock_MHz = deviceProperties.clockRate / 1000;
|
||||
int32_t clock_MHz;
|
||||
int32_t clock_KHz;
|
||||
cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
|
||||
clock_MHz = clock_KHz / 1000;
|
||||
out << "GPU(compute_"
|
||||
<< deviceMajorMinor << ", "
|
||||
<< deviceProperties.multiProcessorCount << " SMs @ " << clock_MHz << " MHz)";
|
||||
|
||||
@@ -29,22 +29,25 @@
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -57,7 +60,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -70,7 +73,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -83,7 +86,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -96,7 +99,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -109,7 +112,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -127,7 +130,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -140,7 +143,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -148,10 +151,32 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
sm100_bssp_gemm_mxf8_mxf4_f32_f16_mxf8_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf8_mxf4_f32_f32_f32_q_tnt.cu
|
||||
|
||||
sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@@ -162,7 +187,16 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -26,18 +26,19 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
add_subdirectory(narrow_precision)
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_sm100_sp
|
||||
cutlass_test_unit_gemm_device_sm100_sparse
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_sm100_sp_general
|
||||
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
|
||||
cutlass_test_unit_gemm_device_sm100_sp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_general
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_streamk
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sp_general
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_general
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
@@ -52,23 +53,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu
|
||||
sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu
|
||||
sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu
|
||||
|
||||
sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu
|
||||
sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu
|
||||
sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_streamk
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_narrow_precision
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
|
||||
sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu
|
||||
sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu
|
||||
sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
|
||||
|
||||
sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu
|
||||
sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu
|
||||
sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
|
||||
|
||||
sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu
|
||||
sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
|
||||
|
||||
sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu
|
||||
sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
|
||||
|
||||
sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu
|
||||
sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
@@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 1,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
|
||||
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
|
||||
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
|
||||
1, 0,
|
||||
test::gemm::device::CheckEquality::RELATIVE,
|
||||
test::gemm::device::ScalarLoc::ON_DEVICE,
|
||||
test::gemm::device::VectorScale::ENABLED,
|
||||
{256, 2560}));
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@@ -31,7 +31,8 @@
|
||||
/* \file
|
||||
\brief Command line options for performance test program
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <set>
|
||||
@@ -165,9 +166,11 @@ void Options::Device::print_usage(std::ostream &out) const {
|
||||
break;
|
||||
}
|
||||
else {
|
||||
int32_t clock_KHz;
|
||||
cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
|
||||
out << " [" << idx << "] - "
|
||||
<< prop.name << " - SM " << prop.major << "." << prop.minor << ", "
|
||||
<< prop.multiProcessorCount << " SMs @ " << (prop.clockRate / 1000.0) << " MHz, "
|
||||
<< prop.multiProcessorCount << " SMs @ " << (clock_KHz / 1000.0) << " MHz, "
|
||||
<< "L2 cache: " << (prop.l2CacheSize >> 20) << " MB, Global Memory: " << (prop.totalGlobalMem >> 30) << " GB"
|
||||
<< std::endl;
|
||||
}
|
||||
@@ -216,9 +219,11 @@ void Options::Device::print_options(std::ostream &out, int indent) const {
|
||||
for (int device : devices) {
|
||||
out << device << ',';
|
||||
}
|
||||
int32_t clock_KHz;
|
||||
cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
|
||||
out
|
||||
<< "\n"
|
||||
<< indent_str(indent) << "clock: " << int(double(properties[0].clockRate) / 1000.0) << "\n"
|
||||
<< indent_str(indent) << "clock: " << int(double(clock_KHz) / 1000.0) << "\n"
|
||||
<< indent_str(indent) << "compute-capability: " << compute_capability(0) << "\n";
|
||||
}
|
||||
|
||||
|
||||
@@ -109,7 +109,8 @@ bool BlockCompareEqual(
|
||||
Element const *ptr_B,
|
||||
size_t capacity,
|
||||
int grid_size = 0,
|
||||
int block_size = 0) {
|
||||
int block_size = 0,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
int equal_flag = 1;
|
||||
int *device_equal_flag = nullptr;
|
||||
@@ -146,7 +147,9 @@ bool BlockCompareEqual(
|
||||
dim3 grid(grid_size, 1, 1);
|
||||
dim3 block(block_size, 1, 1);
|
||||
|
||||
kernel::BlockCompareEqual<Element><<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity);
|
||||
kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
|
||||
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
if (cudaMemcpy(
|
||||
&equal_flag,
|
||||
@@ -175,7 +178,8 @@ bool BlockCompareRelativelyEqual(
|
||||
Element epsilon,
|
||||
Element nonzero_floor,
|
||||
int grid_size = 0,
|
||||
int block_size = 0) {
|
||||
int block_size = 0,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
int equal_flag = 1;
|
||||
int *device_equal_flag = nullptr;
|
||||
@@ -212,7 +216,7 @@ bool BlockCompareRelativelyEqual(
|
||||
dim3 grid(grid_size, 1, 1);
|
||||
dim3 block(block_size, 1, 1);
|
||||
|
||||
kernel::BlockCompareRelativelyEqual<Element><<< grid, block >>>(
|
||||
kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
|
||||
device_equal_flag,
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
@@ -221,6 +225,8 @@ bool BlockCompareRelativelyEqual(
|
||||
nonzero_floor
|
||||
);
|
||||
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
if (cudaMemcpy(
|
||||
&equal_flag,
|
||||
device_equal_flag,
|
||||
|
||||
@@ -232,6 +232,8 @@ ComputeType TensorTransformReduce(
|
||||
workspace, identity, workspace_size, reduce
|
||||
);
|
||||
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
if (copy_out) {
|
||||
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
|
||||
if (result != cudaSuccess) {
|
||||
@@ -285,6 +287,8 @@ ComputeType TensorTransformReduce(
|
||||
workspace, identity, workspace_size, reduce
|
||||
);
|
||||
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
if (copy_out) {
|
||||
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
Reference in New Issue
Block a user