diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index a737c7a480..d1ccf5fb8f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -246,6 +246,8 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill : public GemmConfigBase static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; static constexpr bool PreshuffleQuant = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index fe75ed3691..4389744acf 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -624,9 +624,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) { printf("Preshuffle BQ with TiledMMAPermuteN \n"); - ck_tile::HostTensor bq_shuffle_host = - ck_tile::shuffle_bq_permuteN(*bq_tensor_ptr); - bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); + ck_tile::HostTensor bq_permuted_host = + ck_tile::bq_permuteN(*bq_tensor_ptr); + + if constexpr(GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor bq_shuffle_host = + ck_tile::shuffle_bq(&bq_permuted_host, GemmConfig::K_Tile / QuantGroupSize::kK); + bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); + } + else + { + bq_dev_buf_ptr->ToDevice(bq_permuted_host.data()); + } } else if constexpr(GemmConfig::PreshuffleQuant) { diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index cec4a3a708..3023d330ed 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -24,20 +24,43 @@ auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) template auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) { - if(t->get_lengths().size() != 2) - { - throw std::runtime_error("Host tensor is not rank 2 tensor."); - } - int bqk_ = t->get_lengths()[0]; - int n_ = t->get_lengths()[1]; + const auto& lengths = t->get_lengths(); + const size_t rank = lengths.size(); - if(bqk_ % block_bq_k != 0) + // Validate block_bq_k divisibility based on rank + int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1; + + if(bqk_dim < 0) { - throw std::runtime_error("shuffle_bq needs a bqk of multiple times of block_bq_k."); + throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " + + std::to_string(rank)); + } + + if(bqk_dim % block_bq_k != 0) + { + throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k."); + } + + if(rank == 5) + { + // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk] + ck_tile::HostTensor t_view({static_cast(lengths[0]), + static_cast(lengths[1]), + static_cast(lengths[2]), + static_cast(lengths[3]), + bqk_dim / block_bq_k, + block_bq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5}); + } + else // rank == 2 + { + // Handle 2D tensor: [bqk, n] + int n_ = lengths[1]; + ck_tile::HostTensor t_view({n_, bqk_dim / block_bq_k, block_bq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {1, 0, 2}); } - ck_tile::HostTensor t_view({n_, bqk_ / block_bq_k, block_bq_k}); - std::copy(t->begin(), t->end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {1, 0, 2}); } template @@ -57,7 +80,7 @@ auto shuffle_b(const ck_tile::HostTensor& t) } template -auto shuffle_bq_permuteN(const ck_tile::HostTensor& t) +auto bq_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index b31c9736c2..f733729269 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -431,7 +431,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase bq_shuffle_host = - ck_tile::shuffle_bq_permuteN(bq_bqk_bqn); + ck_tile::bq_permuteN(bq_bqk_bqn); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } else if constexpr(GemmConfig::PreshuffleQuant)