[rocm-libraries] ROCm/rocm-libraries#4280 (commit b7de1e1)

[CK_TILE] Add blockscale GEMM support for EightWarps on gfx950 (#4280)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

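Add blockscale GEMM support for the EightWarps configuration on gfx950.
Among the changes, `QuantGemmKernel` now handles `ABQuantGrouped` (without
`APreshuffleQuant`) in its own AQ tensor-view and tile-window branches
instead of sharing the `AQuantGrouped` path.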

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [x] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered.
This commit is contained in:
kensclin
2026-02-09 03:55:52 +00:00
committed by assistant-librarian[bot]
parent 731afe535a
commit 5b3e527c88
19 changed files with 1881 additions and 225 deletions

View File

@@ -591,9 +591,7 @@ struct QuantGemmKernel
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
else if constexpr(kQuantType == QuantType::AQuantGrouped && !APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
@@ -610,6 +608,29 @@ struct QuantGemmKernel
aq_ptr,
make_tuple(kargs.QK_A, kargs.M),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else // Column major AQ
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(1, kargs.stride_AQ),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
@@ -647,19 +668,12 @@ struct QuantGemmKernel
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
else if constexpr(kQuantType == QuantType::AQuantGrouped && !APreshuffleQuant)
{
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_tensor_view,
@@ -673,6 +687,16 @@ struct QuantGemmKernel
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_tensor_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_tensor_view,