diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 085dcaa94e..f3ddbae9fe 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -18,7 +18,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # gemm_bquant_quantgrouped_fp8i4.cpp # gemm_bquant_quantgrouped_bf16mxfp4.cpp # gemm_bquant_quantgrouped_bf8.cpp - # gemm_bquant_quantgrouped_fp8.cpp + gemm_bquant_quantgrouped_fp8.cpp # gemm_bquant_quantgrouped_preshuffleb.cpp gemm_bquant_quantgrouped_preshufflequant.cpp # gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 280029033b..8ad5529d19 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ @@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index 2750112683..e1c8ed3273 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -4,8 +4,8 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigPreshuffleBQuantPrefill; // GemmConfigPreshuffleQuantDecode; - // //GemmConfigPreshuffleBQuantPrefill; +using GemmConfig = GemmConfigPreshuffleQuantDecode; +// //GemmConfigPreshuffleBQuantPrefill; void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) @@ -24,31 +24,33 @@ void bquant_quantgrouped_preshufflequant_instance_factory( ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "non-preshuffleb", - "preshufflequant", - "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + // lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", + // "1x8x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // using QuantGroupSize = ck_tile::QuantGroupShape>; + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings({"fp8", + // "bquant", + // "non-preshuffleb", + // "preshufflequant", + // "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = + // decltype(GemmQuantTypeConfig{}); + // using QuantGroupSize = ck_tile::QuantGroupShape>; + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", @@ -75,20 +77,19 @@ void bquant_quantgrouped_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - // lut[hash_multiple_strings({"fp8", - // "bquant", - // "non-preshuffleb", - // "preshufflequant", - // "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = - // decltype(GemmQuantTypeConfig{}); - // using QuantGroupSize = ck_tile::QuantGroupShape>; - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - //}; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; // lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", // "1x1x128"})] = // [](const ck_tile::ArgParser& arg_parser) { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index de9d691a01..2f5c133366 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -89,8 +89,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) // std::unordered_map>& lut); // void aquant_quantgrouped_preshufflequant_instance_factory( // std::unordered_map>& lut); -// void bquant_quantgrouped_fp8_instance_factory( -// std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); // void bquant_quantgrouped_bf8_instance_factory( // std::unordered_map>& lut); // void bquant_quantgrouped_fp8i4_instance_factory( @@ -126,7 +126,7 @@ int main(int argc, char* argv[]) std::unordered_map> lut; // aquant_quantgrouped_instance_factory(lut); // aquant_quantgrouped_preshufflequant_instance_factory(lut); - // bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); // bquant_quantgrouped_bf8_instance_factory(lut); // bquant_quantgrouped_fp8i4_instance_factory(lut); // bquant_quantgrouped_bf8i4_instance_factory(lut); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index eb2ad448ae..92eff8436f 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -349,7 +349,24 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase if constexpr(PreshuffleQuant) { - constexpr index_t reg_offset = nIter; + // constexpr index_t reg_offset = nIter; + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::QuantGroupSize::kN >= + (NWarp * WarpGemm::kN)) + { + if constexpr(Traits::NPerBlock == + GemmTraits::QuantGroupSize::kN) + return kQScale; + else + return nIter; // for prefill needs kQscale, for decode needs + // nIter + } + else + { + return nIter; + } + }(); + auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; @@ -373,13 +390,14 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); // printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d, - // lane_id(): " + // reg_offset: %d, lane_id(): " // "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, " // "scale_reg_f: %f\n", // get_block_id(), // get_warp_id(), // get_thread_id(), // static_cast(nIter), + // reg_offset, // __lane_id(), // static_cast(kQScale), // pull_from_lane, diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 6d8aaa2c45..2d9cff3c87 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -346,7 +346,7 @@ struct QuantGemmKernel if(get_block_id() == 0 && get_thread_id() == 0) { - printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:" + printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockBQ: %d, KPerBlockBQ: %d, wave_tile_size:" "%d, wave_tile_count_x: %d\n", pad_bq_x, WarpTileN, @@ -1119,7 +1119,9 @@ struct QuantGemmKernel { static_assert(std::is_same_v); constexpr auto block_n = - TilePartitioner::NPerBlock / QuantGroupSize::kN; // 128/32 = 4 + (QuantGroupSize::kN <= TilePartitioner::NPerBlock) + ? TilePartitioner::NPerBlock / QuantGroupSize::kN + : QuantGroupSize::kN / TilePartitioner::NPerBlock; // 128/32 = 4 constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); // 16 constexpr auto warpPerGroup = (QuantGroupSize::kN < warp_n) @@ -1133,29 +1135,47 @@ struct QuantGemmKernel ? block_n / warpPerGroup : block_n; // 4 / 2 = 2 - auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2 - - if(get_thread_id() == 0) + auto block_n_idx = + i_n / + TilePartitioner::NPerBlock; // 0,1,2 (i_n - TilePartitioner::NPerBlock) / + // TilePartitioner::NPerBlock + // For decode shapes GN: 128, Blocks needs to access 0,0,1,1,2,2 ... + if(QuantGroupSize::kN > TilePartitioner::NPerBlock) { - printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n"); - printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, " - "bqk_per_block: %d, block_n_idx: %d, " - "tile_window_width: %d, tile_window_height: %d, i_n: %d\n", - get_block_id(), - static_cast(block_n), - static_cast(warp_n), - static_cast(warpPerGroup), - static_cast(bqk_per_block), - static_cast(block_n_idx), - tile_window_width, - static_cast(tile_window_height), - static_cast(i_n)); + block_n_idx = block_n_idx >> 1; + } + // if(get_thread_id() == 0) + // { + // printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n"); + // printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, " + // "bqk_per_block: %d, block_n_idx: %d, " + // "tile_window_width: %d, tile_window_height: %d, i_n: %d\n", + // get_block_id(), + // static_cast(block_n), + // static_cast(warp_n), + // static_cast(warpPerGroup), + // static_cast(bqk_per_block), + // static_cast(block_n_idx), + // tile_window_width, + // static_cast(tile_window_height), + // static_cast(i_n)); + // } + if(QuantGroupSize::kN > TilePartitioner::NPerBlock) + { + return make_tile_window( + bq_pad_view, + make_tuple(number{}, number{}), + {block_n_idx, 0}); + } + else + { + return make_tile_window( + bq_pad_view, + make_tuple(number{}, number{}), + {block_n_idx * tile_window_height, + 0}); // normally needs block_n_idx * tile_window_height, for decode GN + // : 128 needs 0,0, 1,1, 2,2 ... } - - return make_tile_window( - bq_pad_view, - make_tuple(number{}, number{}), - {block_n_idx * tile_window_height, 0}); } else { @@ -1255,15 +1275,15 @@ struct QuantGemmKernel { n = kargs.N; } - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n"); - // To print Tile window after bq_pad0_desc - // bq_block_window.template print_tile_window_range( - // 0, 128, 0, 2, "bq block window"); - bq_block_window.template print_tile_window_range( - 0, 8, 0, 32, "bq block window"); - } + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n"); + // // To print Tile window after bq_pad0_desc + // // bq_block_window.template print_tile_window_range( + // // 0, 128, 0, 2, "bq block window"); + // bq_block_window.template print_tile_window_range( + // 0, 8, 0, 16, "bq block window"); + // } return GemmPipeline{}.template operator()( a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index c570d4a131..2ede432a89 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -29,11 +29,11 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase= 1, "NPerBlock must be >= QuantGroupSize"); + // static_assert(NPerBlockBQ == 1, "NPerBlock must be >= QuantGroupSize"); static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); - static_assert(NPerBlock % QuantGroupSize::kN == 0, - "NPerBlock must be a multiple of QuantGroupSize::kN"); + // static_assert(NPerBlock % QuantGroupSize::kN == 0, + // "NPerBlock must be a multiple of QuantGroupSize::kN"); static_assert(KPerBlock % QuantGroupSize::kK == 0, "KPerBlock must be a multiple of QuantGroupSize::kK"); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index fbfa281465..7230e21652 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -65,7 +65,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index eea34d998a..fab7221819 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -297,14 +297,27 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2 + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 1 constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1 constexpr auto N2 = 1; - constexpr auto NR1 = 32; // 32 - constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1 + constexpr auto NR1 = 32; // NPerQ; // 32 + constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*1*1*32)=1 + constexpr auto K1 = KPerTile; + + if(get_block_id() == 0 && get_thread_id() == 0) + { + // Debug print to verify values + printf("PreshuffleQuant Fine-grained: KPerQ: %d, NPerQ: %d, N1=%d, NR0=%d, " + "KPerTile: %d \n", + KPerQ, + NPerQ, + N1, + NR0, + KPerTile); + } return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, sequence>, tuple, sequence<0, 2, 0, 2>>, tuple, sequence<2, 0, 3, 1>>, sequence<1, 2>, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 0005eab52f..7e1b6de3a1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -73,7 +73,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase