Files
composable_kernel/include/ck_tile/host/tensor_shuffle_utils.hpp
Cong Ma d06f35027a [rocm-libraries] ROCm/rocm-libraries#4354 (commit d41f08a)
[CK TILE] fix numerical errors of preshuffle_b

This pull request introduces several improvements and fixes related to
quantized grouped GEMM (General Matrix Multiply) pipelines and their
supporting utilities.

# The numerical issue

## Steps to reproduce
```bash
Run
./bin/tile_example_gemm_weight_preshuffle -prec=fp8
./bin/tile_example_gemm_weight_preshuffle -prec=int4
```

# Solution
The main changes address type correctness, improve data layout and
shuffling logic, and expand test coverage to better validate different
GEMM configurations.

**Key changes include:**

### Data layout and shuffling logic

* Refactored the logic in `shuffle_b_permuteN` to use `constexpr`
variables for `KLane` and `ItemsPerAccess`, simplifying tile view
construction and correcting the permutation order for improved
efficiency and correctness (`tensor_shuffle_utils.hpp`).
* Fixed the calculation of `KLaneBytes` in weight preshuffle pipeline
policies to account for internal data type conversion (e.g., from
`pk_int4_t` to `fp8`), ensuring accurate memory access and alignment in
quantized GEMM policies (`wp_pipeline_agmem_bgmem_creg_base_policy.hpp`,
`gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp`).
[[1]](diffhunk://#diff-93f16cd76e6e24404777e682a5ac8e039913ddd6a438c7efd61fdda42276e4efL274-R275)
[[2]](diffhunk://#diff-9c3d0fc3c014feed435bfd93ba1f8f9fb3e054dcc322deada3addf70bee5a58cL100-R105)

### Test infrastructure enhancements

* Unit tests did not catch this issue since there were no tests for fp8.
Added new configuration structs (`config_mn_16x16`, `config_mn_32x32`)
to support additional GEMM tile shapes and updated tests to run with
these configurations for broader coverage
(`test_gemm_pipeline_util.hpp`).
[[1]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8R86-R103)
[[2]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8L255-R269)

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
2026-02-11 07:05:46 +00:00

187 lines
7.1 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "device_prop.hpp"
#include <stdexcept>
namespace ck_tile {
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
{
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename T>
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
{
const auto& lengths = t->get_lengths();
const size_t rank = lengths.size();
// Validate block_bq_k divisibility based on rank
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
if(bqk_dim < 0)
{
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
std::to_string(rank));
}
if(bqk_dim % block_bq_k != 0)
{
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
}
// For TilePermuteN
if(rank == 5)
{
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
static_cast<int>(lengths[1]),
static_cast<int>(lengths[2]),
static_cast<int>(lengths[3]),
bqk_dim / block_bq_k,
block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
}
else // rank == 2
{
// Handle 2D tensor: [bqk, n]
int n_ = lengths[1];
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t, GemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else if(ck_tile::is_gfx11_supported())
{
int divisor = 1;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else
{
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess =
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
return shuffle_b(t, GemmConfig{});
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
}
else
{
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess =
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
return shuffle_b_permuteN(t, GemmConfig{});
}
} // namespace ck_tile