mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-28 02:57:42 +00:00
[CK TILE] fix numerical errors of preshuffle_b

This pull request introduces several improvements and fixes related to quantized grouped GEMM (General Matrix Multiply) pipelines and their supporting utilities.

# The numerical issue

## Steps to reproduce

Run:

```bash
./bin/tile_example_gemm_weight_preshuffle -prec=fp8
./bin/tile_example_gemm_weight_preshuffle -prec=int4
```

# Solution

The main changes address type correctness, improve data layout and shuffling logic, and expand test coverage to better validate different GEMM configurations.

**Key changes include:**

### Data layout and shuffling logic

* Refactored the logic in `shuffle_b_permuteN` to use `constexpr` variables for `KLane` and `ItemsPerAccess`, simplifying tile view construction and correcting the permutation order for improved efficiency and correctness (`tensor_shuffle_utils.hpp`).
* Fixed the calculation of `KLaneBytes` in weight preshuffle pipeline policies to account for internal data type conversion (e.g., from `pk_int4_t` to `fp8`), ensuring accurate memory access and alignment in quantized GEMM policies (`wp_pipeline_agmem_bgmem_creg_base_policy.hpp`, `gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp`). [[1]](diffhunk://#diff-93f16cd76e6e24404777e682a5ac8e039913ddd6a438c7efd61fdda42276e4efL274-R275) [[2]](diffhunk://#diff-9c3d0fc3c014feed435bfd93ba1f8f9fb3e054dcc322deada3addf70bee5a58cL100-R105)

### Test infrastructure enhancements

* Unit tests did not catch this issue because there were no fp8 tests. Added new configuration structs (`config_mn_16x16`, `config_mn_32x32`) to support additional GEMM tile shapes and updated the tests to run with these configurations for broader coverage (`test_gemm_pipeline_util.hpp`). [[1]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8R86-R103) [[2]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8L255-R269)

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
187 lines
7.1 KiB
C++
187 lines
7.1 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
#include "device_prop.hpp"

#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <string>
|
|
|
|
namespace ck_tile {
|
|
template <typename T>
|
|
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
|
{
|
|
if(t->get_lengths().size() != 2)
|
|
{
|
|
throw std::runtime_error("Host tensor is not rank 2 tensor.");
|
|
}
|
|
int m_ = t->get_lengths()[0];
|
|
int aqk_ = t->get_lengths()[1];
|
|
|
|
if(aqk_ % block_aq_k != 0)
|
|
{
|
|
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
|
|
}
|
|
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
|
|
std::copy(t->begin(), t->end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
|
}
|
|
|
|
template <typename T>
|
|
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
|
{
|
|
const auto& lengths = t->get_lengths();
|
|
const size_t rank = lengths.size();
|
|
|
|
// Validate block_bq_k divisibility based on rank
|
|
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
|
|
|
|
if(bqk_dim < 0)
|
|
{
|
|
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
|
|
std::to_string(rank));
|
|
}
|
|
|
|
if(bqk_dim % block_bq_k != 0)
|
|
{
|
|
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
|
|
}
|
|
|
|
// For TilePermuteN
|
|
if(rank == 5)
|
|
{
|
|
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
|
|
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
|
|
static_cast<int>(lengths[1]),
|
|
static_cast<int>(lengths[2]),
|
|
static_cast<int>(lengths[3]),
|
|
bqk_dim / block_bq_k,
|
|
block_bq_k});
|
|
std::copy(t->begin(), t->end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
|
|
}
|
|
else // rank == 2
|
|
{
|
|
// Handle 2D tensor: [bqk, n]
|
|
int n_ = lengths[1];
|
|
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
|
|
std::copy(t->begin(), t->end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
|
}
|
|
}
|
|
|
|
template <typename GemmConfig, typename T>
|
|
auto shuffle_b(const ck_tile::HostTensor<T>& t, GemmConfig)
|
|
{
|
|
assert(t.get_lengths().size() == 2);
|
|
int n_ = t.get_lengths()[1];
|
|
int k_ = t.get_lengths()[0];
|
|
|
|
if(ck_tile::is_gfx12_supported())
|
|
{
|
|
constexpr int divisor = 2;
|
|
constexpr int kABK1PerLane = 8;
|
|
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
|
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
|
GemmConfig::N_Warp_Tile,
|
|
k_ / GemmConfig::K_Warp_Tile,
|
|
kABK0PerLane,
|
|
divisor,
|
|
kABK1PerLane});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
|
}
|
|
else if(ck_tile::is_gfx11_supported())
|
|
{
|
|
int divisor = 1;
|
|
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
|
GemmConfig::N_Warp_Tile,
|
|
k_ / GemmConfig::K_Warp_Tile,
|
|
divisor,
|
|
GemmConfig::K_Warp_Tile / divisor});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
|
}
|
|
else
|
|
{
|
|
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
|
|
constexpr int ItemsPerAccess =
|
|
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
|
|
|
|
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
|
GemmConfig::N_Warp_Tile,
|
|
k_ / ItemsPerAccess,
|
|
ItemsPerAccess});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
|
|
}
|
|
}
|
|
|
|
template <typename GemmConfig, typename T>
|
|
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
|
{
|
|
return shuffle_b(t, GemmConfig{});
|
|
}
|
|
|
|
template <typename GemmConfig, typename T>
|
|
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
|
{
|
|
assert(t.get_lengths().size() == 2);
|
|
|
|
int n_ = t.get_lengths()[1];
|
|
int bqk_ = t.get_lengths()[0];
|
|
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
|
|
|
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
|
|
GemmConfig::N_Warp,
|
|
GemmConfig::N_Warp_Tile / group_n,
|
|
NRepeat,
|
|
bqk_});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
|
|
}
|
|
|
|
template <typename GemmConfig, typename T>
|
|
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
|
{
|
|
assert(t.get_lengths().size() == 2);
|
|
int n_ = t.get_lengths()[1];
|
|
int k_ = t.get_lengths()[0];
|
|
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
|
|
if(ck_tile::is_gfx12_supported())
|
|
{
|
|
constexpr int divisor = 2;
|
|
constexpr int kABK1PerLane = 8;
|
|
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
|
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
|
gemmConfig.N_Warp,
|
|
gemmConfig.N_Warp_Tile,
|
|
NRepeat,
|
|
k_ / gemmConfig.K_Warp_Tile,
|
|
kABK0PerLane,
|
|
divisor,
|
|
kABK1PerLane});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
|
|
}
|
|
else
|
|
{
|
|
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
|
|
constexpr int ItemsPerAccess =
|
|
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
|
|
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
|
gemmConfig.N_Warp,
|
|
gemmConfig.N_Warp_Tile,
|
|
NRepeat,
|
|
k_ / ItemsPerAccess,
|
|
ItemsPerAccess});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
|
|
}
|
|
}
|
|
|
|
template <typename GemmConfig, typename T>
|
|
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
|
{
|
|
return shuffle_b_permuteN(t, GemmConfig{});
|
|
}
|
|
} // namespace ck_tile
|