mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
enable preshuffle quant with permuteN
This commit is contained in:
@@ -246,6 +246,8 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill : public GemmConfigBase
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -624,9 +624,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(*bq_tensor_ptr);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr);
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_permuted_host, GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
|
||||
}
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
|
||||
@@ -24,20 +24,43 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
||||
template <typename T>
|
||||
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
{
|
||||
if(t->get_lengths().size() != 2)
|
||||
{
|
||||
throw std::runtime_error("Host tensor is not rank 2 tensor.");
|
||||
}
|
||||
int bqk_ = t->get_lengths()[0];
|
||||
int n_ = t->get_lengths()[1];
|
||||
const auto& lengths = t->get_lengths();
|
||||
const size_t rank = lengths.size();
|
||||
|
||||
if(bqk_ % block_bq_k != 0)
|
||||
// Validate block_bq_k divisibility based on rank
|
||||
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
|
||||
|
||||
if(bqk_dim < 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_bq needs a bqk of multiple times of block_bq_k.");
|
||||
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
|
||||
std::to_string(rank));
|
||||
}
|
||||
|
||||
if(bqk_dim % block_bq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
|
||||
}
|
||||
|
||||
if(rank == 5)
|
||||
{
|
||||
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
|
||||
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
|
||||
static_cast<int>(lengths[1]),
|
||||
static_cast<int>(lengths[2]),
|
||||
static_cast<int>(lengths[3]),
|
||||
bqk_dim / block_bq_k,
|
||||
block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
|
||||
}
|
||||
else // rank == 2
|
||||
{
|
||||
// Handle 2D tensor: [bqk, n]
|
||||
int n_ = lengths[1];
|
||||
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_, bqk_ / block_bq_k, block_bq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
@@ -57,7 +80,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
|
||||
@@ -431,7 +431,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
|
||||
Reference in New Issue
Block a user