diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index e0e0a64416..62ca34b057 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, @@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 398a61f368..607c53d9af 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -540,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 16a0835b1d..313e449c7b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr constexpr index_t reg_offset = nIter; auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; // cross lane ops uint32_t scale_reg_dword; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 4f79361037..004fb18e0b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -280,12 +280,13 @@ struct QuantGemmKernel // Helper: Create Pre-shuffled Quantization Tensor Descriptor // =================================================================== template CK_TILE_DEVICE static auto - MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B) + MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B) { // Step 1: Calculate base BQ tensor dimensions // ---------------------------------------------------------- @@ -304,8 +305,9 @@ struct QuantGemmKernel // ---------------------------------------------------------- // Pad the X dimension to be a multiple of block_tile_size to ensure // each thread block can process complete tiles without edge cases - const auto block_tile_size = NPerBlock * KPerBlockBQ; - const auto bq_pad0_desc = transform_tensor_descriptor( + const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; + + const auto bq_pad0_desc = transform_tensor_descriptor( bq_desc, make_tuple(make_pass_through_transform(bq_y), make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), @@ -318,7 +320,7 @@ struct QuantGemmKernel // This separates the work into tiles that can be processed by // individual warps/waves const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = WarpTileN * KPerBlockBQ; + const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( @@ -813,12 +815,18 @@ struct QuantGemmKernel static_assert(std::is_same_v, "PreshuffleQuant with BQuantGrouped currently only supports " "ColumnMajor BQ layout"); + using QuantGroupSize = remove_cvref_t; return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlockBQ, GemmPipeline::NPerBlock, TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); + GemmPipeline::GetVectorSizeBQ()>( + bq_ptr, + ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN), + QuantGroupSize::kN, + kargs.QK_B); } else { @@ -879,13 +887,38 @@ struct QuantGemmKernel if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = + constexpr auto block_n = + TilePartitioner::NPerBlock / + QuantGroupSize::kN; // Number of N-dimension quantization groups per block + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at( + I1); // Number of N-dimension elements per warp + constexpr auto warp_per_group = + (QuantGroupSize::kN < + warp_n) // Determine how many warps share the same scale in N-dimension + ? (warp_n / QuantGroupSize::kN) + : (QuantGroupSize::kN / warp_n); + constexpr auto bqk_per_block = + TilePartitioner::KPerBlock / + QuantGroupSize::kK; // Number of K-dimension quantization groups per block + constexpr auto + tile_window_width = // The pre-shuffled layout flattens warp_n × + // bqk_per_block scales per row, Padded up to warp_size + // to ensure coalesced memory access. ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; + + // Adapts based on fine vs coarse quantization granularity: + // - Fine-grained (QuantGroupSize::kN < warp_n): + // Multiple quant groups per warp → fewer rows needed per block. + // height = block_n / warp_per_group + // + // - Coarse-grained (QuantGroupSize::kN >= warp_n): + // Each row represents one quant group. + // height = block_n + constexpr auto tile_window_height = + (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; + auto block_n_idx = + i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a + // block index. return make_tile_window( bq_tensor_view, @@ -1125,596 +1158,6 @@ struct QuantGemmKernel return true; } - template - CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) - { - - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& aq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; - const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; - const auto aq_desc = - make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), - make_tuple(aq_x, 1), - number{}, - number<1>{}); - - const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; - const auto aq_pad0_desc = transform_tensor_descriptor( - aq_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = - GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; - const auto wave_tile_count_x = - ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); - - const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( - aq_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto aq_pad1_desc = transform_tensor_descriptor( - aq_unmerge_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_pass_through_transform(wave_tile_count_x), - make_right_pad_transform( - wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - const auto pad_wave_size = - ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); - const auto aq_merge_pad1_desc = transform_tensor_descriptor( - aq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), - make_pass_through_transform(pad_wave_size)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view(aq_ptr, aq_merge_pad1_desc); - } - else if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - !PreshuffleQuant) - { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); - } - else // Column major AQ - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions - make_tuple(kargs.stride_AQ, 1), // Same stride pattern - number{}, - number<1>{}); - } - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, 0), // broadcasting over n - number<1>{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(PreshuffleB) - { - index_t kFlatK = GemmPipeline::flatKPerWarp * - (splitk_batch_offset.splitted_k / - GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& bq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - bq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(0, 1), // broadcasting over m - number<1>{}, - number<1>{}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v, - "PreshuffleQuant with BQuantGrouped currently only supports " - "ColumnMajor BQ layout"); - - return MakePreshuffledQuantTensorView< - GemmPipeline::KPerBlockBQ, - GemmPipeline::NPerBlock, - TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); - } - else - { - using QuantGroupSize = remove_cvref_t; - - if constexpr(std::is_same_v) - { - // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] - // Dimensions: [K/QuantGroupK, N/QuantGroupN] - // Strides: [N/QuantGroupN, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), - integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), - number{}, - number<1>{}); - } - else - { - static_assert(std::is_same_v); - // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] - // Dimensions: [N/QuantGroupN, K/QuantGroupK] - // Strides: [K/QuantGroupK, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), - integer_divide_ceil(kargs.K, QuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), - number{}, - number<1>{}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple( - a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& aq_pad_view = [&]() { return views.at(I1); }(); - - const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I2); - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& bq_pad_view = [&]() { return views.at(I3); }(); - - // TODO vector write in for C in ColMajor - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I4); - if constexpr(std::is_same_v) - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(PreshuffleB) - { - - return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view); - } - else - { - return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - - const auto& a_pad_view = views.at(I0); - const auto& aq_pad_view = views.at(I1); - const auto& b_pad_view = views.at(I2); - const auto& bq_pad_view = views.at(I3); - const auto& c_pad_view = views.at(I4); - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& aq_block_window = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_m / warp_m; - auto block_m_idx = i_m / block_m; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {block_m_idx * tile_window_height, 0}); - } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) - { - using QuantGroupSize = remove_cvref_t; - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto block_m = TilePartitioner::MPerBlock; - if constexpr(std::is_same_v) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else // Column major AQ - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {0, i_m}); - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return nullptr; // TODO: use some other "empty" type? - } - }(); - - const auto& b_block_window = [&]() { - if constexpr(PreshuffleB) - { - - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto& bq_block_window = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(bq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - using QuantGroupSize = remove_cvref_t; - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; - - return make_tile_window( - bq_pad_view, - make_tuple(number{}, number{}), - {block_n_idx * tile_window_height, 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); - } - else - { - static_assert(std::is_same_v); - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - else - { - return nullptr; // TODO: use some other "empty" type here - } - }(); - - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple( - a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window); - } - /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 39f0cbdbd3..a4bba6cf76 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - constexpr index_t VecLoadSize = GetVectorSizeBQ(); constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; using WarpTile = typename Problem::BlockGemmShape::WarpTile; @@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC BlockSize, NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), - VecLoadSize, + Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); @@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC KPerBlockBQ, // Logical K dimension NPerBlockBQ, // Logical N dimension Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index b43066cdc5..13d400d5fc 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), - 0) + (PreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) : is_bq_row_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 0ec8942426..34f815ed27 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -192,6 +192,7 @@ template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern @@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding static_assert(num_warps == MWarps * NWarps * KWarps); static_assert(KWarps == 1); - /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) - /// - /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative - /// to warp dimensions. - /// - /// Three distinct distribution patterns are handled: - /// - /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): - /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast - /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp - /// - /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): - /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN - /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 - /// - /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): - /// - Quantization group spans multiple warps - /// - All warps share the same scale value - /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale - /// - /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { // Preshuffle only supported for ColumnMajor currently @@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding if constexpr(PreshuffleQuant) { - // ColumnMajor only for preshuffle - constexpr index_t X1 = warp_size; - constexpr index_t X0 = NPerTile / warp_size; - constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = KPerTile / Y1; + // ============================================================================= + // PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION + // ============================================================================= + // For pre-shuffled quantization, the BQ scale tensor has been reorganized + // (pre-shuffled) to optimize memory access patterns during dequantization. + // + // Tile Dimensions: + // - K-axis (Y in encoding): Corresponds to the K-dimension iteration + // - N-axis (X in encoding): Flattened scale index combining N and K groups + // + // The encoding distributes work across threads such that each thread loads + // the correct pre-shuffled scale for its corresponding B-matrix elements. + // ============================================================================= + if constexpr(NPerQ <= WarpGemm::kN) + { + // ========================================================================= + // CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN) + // ========================================================================= + // Multiple quantization scales exist within a single warp's N-dimension. + // Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp. + // + // Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256 + // → 2 scales per warp in N, 2 K-groups per block + constexpr auto N1 = BlockGemmShape::kK / + KPerQ; // Number of K-dimension quantization groups per block, + // Each K-group of KPerQ elements shares the same scale. + constexpr auto N0 = + WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ + // <= WarpGemm::kN, each warp handles multiple scales. + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension + constexpr auto NR0 = + warp_size / + (N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization + constexpr auto K1 = NWarps; // Number of warps distributed along this dimension + constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile + constexpr auto KR = 1; // No replication in K-dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<1>>, - sequence<1, 2>, - sequence<0, 0>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2, 0>>, + tuple, sequence<1, 0, 2, 1, 3>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else if constexpr(NPerQ < WarpGemm::kN * NWarps) + { + // ========================================================================= + // CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN * + // NWarps) + // ========================================================================= + // Each warp handles exactly one quantization scale in N-dimension. + // Some warps share the same scale (KR > 1 creates warp grouping). + // + // Example: NPerQ=32, WarpGemm::kN=16, NWarps=4 + // → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups) + + constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale + constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales) + constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN) + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ) + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<1, 0, 2, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + // ========================================================================= + // CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps) + // ========================================================================= + // The quantization group spans ALL warps in N-dimension. + // All warps share the same scale value for their N-tiles. + // + // Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 + // → 128 >= 16*4=64, so all 4 warps use the same scale + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Minimal (1) since scale is shared across N + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = 32; // Fixed broadcast size + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<2, 0, 3, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } } else { + /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) + /// + /// This function determines the optimal thread distribution pattern for loading and + /// applying quantization scales to the B matrix based on the quantization group size + /// (NPerQ) relative to warp dimensions. + /// + /// Three distinct distribution patterns are handled: + /// + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): + /// - Multiple quantization groups exist within a single warp's N-dimension + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale + /// broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// + /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): + /// - Each warp handles exactly one quantization scale + /// - Scales are distributed across warps with replication factor XR = NPerQ / + /// WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// + /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): + /// - Quantization group spans multiple warps + /// - All warps share the same scale value + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// + /// @return A static tile distribution encoding for the BQ scale tensor if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 43f37ec4d8..e4de7e4211 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; + static constexpr index_t NPerBlockBQ = + integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN); static constexpr index_t KPerBlockBQ = integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = @@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -427,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else