diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index cabc0ec02c..b31c9736c2 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -103,6 +103,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode +{ + static constexpr bool PreshuffleQuant = true; +}; + template class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase> { @@ -429,6 +434,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase(bq_bqk_bqn); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } + else if constexpr(GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor bq_shuffle_host = + ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK); + bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); + } else { bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data()); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index 9ca2451d3b..27d52230dd 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -111,7 +111,12 @@ using BPreshuffleBQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on