mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4837 (commit 6316035)
[CK TILE] Unification of sparse MFMA/WMMA policy structs (#4837)

## Motivation
The existing unification work supports DENSE intrinsics. In this PR we enable support for SPARSE as well as SCALE intrinsics and add an example SPARSE implementation.

## Technical Details
Mostly trivial changes. One framework change is that the desired `MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant commit explains, we do not support a fallback family at the moment, but it is something we can consider.

## Test Plan
Added a new test for the relevant sparse specializations.

## Test Result
Test should pass.

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6e558658ea
commit
03ce21ddcb
@@ -4,8 +4,54 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
* elements into lower part of a_vec to half its effective size.
|
||||
* @param a_vec Vector to be compressed.
|
||||
* @tparam ADataType The data type of a_vec
|
||||
* @tparam CompressedSize The target compression size
|
||||
* @tparam AVec The vector type of a_vec (deduced)
|
||||
* @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields.
|
||||
* Each field encodes the original position (0–3) of the corresponding
|
||||
* non‑zero element in the input. If fewer than CompressedSize
|
||||
* non‑zeros are found, remaining fields default to 2 (see below).
|
||||
*/
|
||||
template <typename ADataType, index_t CompressedSize, typename AVec>
|
||||
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
|
||||
{
|
||||
// idx holds one 2‑bit index per output element (total CompressedSize entries).
|
||||
// It is initialized to the pattern 0b10 for every field. This matches
|
||||
// what the hardware expects when there are fewer than two non‑zero values
|
||||
// in a 4‑element group – the unused output is treated as coming from slot 2.
|
||||
// The loop below will clear and set each field as real non‑zeros are seen.
|
||||
int32_t idx = 0;
|
||||
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
|
||||
|
||||
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
// clear the two‑bit field for this output and insert j
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmSmfmacImpl
|
||||
{
|
||||
@@ -41,37 +87,10 @@ struct WarpGemmSmfmacImpl
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
/// elements into lower part of a_vec to half its effective size.
|
||||
///
|
||||
/// @param a_vec Vector to be compressed.
|
||||
///
|
||||
/// @return Four 2-bit indexes of non-zero elements locations
|
||||
///
|
||||
template <typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
|
||||
template <index_t CompressedSize, typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
|
||||
{
|
||||
int32_t idx = 0b11101110;
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
return compress_a_impl<ADataType, CompressedSize>(a_vec);
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
@@ -84,10 +103,11 @@ struct WarpGemmSmfmacImpl
|
||||
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using AVecCompressed =
|
||||
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
static constexpr index_t CompressedSize =
|
||||
ATensor::get_thread_buffer_size() / CompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
@@ -95,8 +115,9 @@ struct WarpGemmSmfmacImpl
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
const int32_t idx = compress_a(a_vec);
|
||||
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
|
||||
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user