[rocm-libraries] ROCm/rocm-libraries#4837 (commit 6316035)

[CK TILE] Unification of sparse MFMA/WMMA policy structs
 (#4837)

## Motivation

The existing unification work supports DENSE intrinsics. In this PR we
enable support for SPARSE as well as SCALE intrinsics and add an example
SPARSE implementation.

## Technical Details

Mostly trivial changes. One framework change is that the desired
`MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant
commit explains, we do not support a fallback family at the moment, but
it is something we can consider.

## Test Plan

Added a new test for the relevant sparse specializations.

## Test Result

Test should pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
chris-tsiaousis-hpc
2026-03-05 19:53:16 +00:00
committed by assistant-librarian[bot]
parent 6e558658ea
commit 03ce21ddcb
23 changed files with 1173 additions and 89 deletions

View File

@@ -4,8 +4,54 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
namespace ck_tile {
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename WarpGemmAttribute_>
struct WarpGemmSmfmacImpl
{
@@ -41,37 +87,10 @@ struct WarpGemmSmfmacImpl
return WarpGemmAttribute_::get_num_of_access();
}
//----------------------------------------------------------------------------------------------
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
/// elements into lower part of a_vec to half its effective size.
///
/// @param a_vec Vector to be compressed.
///
/// @return Four 2-bit indexes of non-zero elements locations
///
template <typename AVec>
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
template <index_t CompressedSize, typename AVec>
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
{
int32_t idx = 0b11101110;
static_for<0, 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
return compress_a_impl<ADataType, CompressedSize>(a_vec);
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
@@ -84,10 +103,11 @@ struct WarpGemmSmfmacImpl
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using AVecCompressed =
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
static constexpr index_t CompressedSize =
ATensor::get_thread_buffer_size() / CompressionRatio;
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -95,8 +115,9 @@ struct WarpGemmSmfmacImpl
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
const int32_t idx = compress_a(a_vec);
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
static_assert(CompressedSize == 4);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};