enable preshuffle quant with permuteN

This commit is contained in:
khuagarw
2025-11-12 21:32:02 +00:00
parent 0f79fa5aed
commit b8b5709ccf
4 changed files with 51 additions and 16 deletions

View File

@@ -246,6 +246,8 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill : public GemmConfigBase
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool PreshuffleQuant = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>

View File

@@ -624,9 +624,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq_permuteN<GemmConfig>(*bq_tensor_ptr);
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr);
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq(&bq_permuted_host, GemmConfig::K_Tile / QuantGroupSize::kK);
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
}
else
{
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
}
}
else if constexpr(GemmConfig::PreshuffleQuant)
{

View File

@@ -24,20 +24,43 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
template <typename T>
auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
{
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int bqk_ = t->get_lengths()[0];
int n_ = t->get_lengths()[1];
const auto& lengths = t->get_lengths();
const size_t rank = lengths.size();
if(bqk_ % block_bq_k != 0)
// Validate block_bq_k divisibility based on rank
int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
if(bqk_dim < 0)
{
throw std::runtime_error("shuffle_bq needs a bqk of multiple times of block_bq_k.");
throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
std::to_string(rank));
}
if(bqk_dim % block_bq_k != 0)
{
throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
}
if(rank == 5)
{
// Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
static_cast<int>(lengths[1]),
static_cast<int>(lengths[2]),
static_cast<int>(lengths[3]),
bqk_dim / block_bq_k,
block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
}
else // rank == 2
{
// Handle 2D tensor: [bqk, n]
int n_ = lengths[1];
ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
ck_tile::HostTensor<T> t_view({n_, bqk_ / block_bq_k, block_bq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig, typename T>
@@ -57,7 +80,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}
template <typename GemmConfig, typename T>
auto shuffle_bq_permuteN(const ck_tile::HostTensor<T>& t)
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);

View File

@@ -431,7 +431,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else if constexpr(GemmConfig::PreshuffleQuant)