diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 3134eb62d9..085dcaa94e 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -16,7 +16,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # gemm_aquant_quantgrouped_preshufflequant.cpp # gemm_bquant_quantgrouped_bf8i4.cpp # gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp + # gemm_bquant_quantgrouped_bf16mxfp4.cpp # gemm_bquant_quantgrouped_bf8.cpp # gemm_bquant_quantgrouped_fp8.cpp # gemm_bquant_quantgrouped_preshuffleb.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index 6fdf5827fa..a48e013b51 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -35,30 +35,45 @@ void bquant_quantgrouped_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; // lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", // "1x1x128"})] = // [](const ck_tile::ArgParser& arg_parser) { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index de9d691a01..9c8d8eba50 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[]) "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " "bf8i4 or bf16fp4") - .insert("warmup", "1", "Number of iterations before benchmarking the kernel") - .insert("repeat", "0", "Number of iterations to benchmark the kernel") + .insert("warmup", "50", "Number of iterations before benchmarking the kernel") + .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "SplitK value") .insert("device", "0", "Device id that will be used to run the kernel") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("flush_cache", "true", "Flush cache before running the kernel") - .insert("rotating_count", "0", "Rotating count") + .insert("rotating_count", "1000", "Rotating count") .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") .insert("preshuffleb", "false", "Enable preshuffle of tensor B") .insert("preshufflequant", "false", "Enable preshuffle of quant tensor") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index f1f0aa25e6..077b069992 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -357,9 +357,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, << " C_Type = " << ck_tile::DataTypeTraits::name << " QuantMode = " << quant_type_to_string(QuantMode) << " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : " - << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : " - << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + << " PreshuffleB = \n" + << (GemmConfig::PreshuffleB ? "true" : "false") << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 84f3822e0c..157c9d38c2 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -349,6 +349,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase constexpr index_t reg_offset = nIter; auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; // cross lane ops uint32_t scale_reg_dword; @@ -368,19 +369,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); - printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d, lane_id(): " - "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, " - "scale_reg_f: %f\n", - get_block_id(), - get_warp_id(), - get_thread_id(), - static_cast(nIter), - __lane_id(), - static_cast(kQScale), - pull_from_lane, - scale_reg, - scale_reg_f); - + // printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d, + // lane_id(): " + // "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, " + // "scale_reg_f: %f\n", + // get_block_id(), + // get_warp_id(), + // get_thread_id(), + // static_cast(nIter), + // __lane_id(), + // static_cast(kQScale), + // pull_from_lane, + // scale_reg, + // scale_reg_f); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( [&](auto c_row) { c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index acf3598c1a..94a4ade13c 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -298,15 +298,15 @@ struct QuantGemmKernel const auto bq_x = N * KPerBlockBQ; // 2x2 = 4 const auto bq_y = QK_B / KPerBlockBQ; // 4/2 = 2 - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("N:%d, QK_B:%d\n", N, QK_B); - printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n", - bq_x, - bq_y, - GetVectorSizeBQ, - KPerBlockBQ); - } + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // printf("N:%d, QK_B:%d\n", N, QK_B); + // printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n", + // bq_x, + // bq_y, + // GetVectorSizeBQ, + // KPerBlockBQ); + // } const auto bq_desc = make_naive_tensor_descriptor(make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), @@ -319,10 +319,10 @@ struct QuantGemmKernel // each thread block can process complete tiles without edge cases const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 2x2 = 4 - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("block_tile_size:%d \n", block_tile_size); - } + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // printf("block_tile_size:%d \n", block_tile_size); + // } const auto bq_pad0_desc = transform_tensor_descriptor( bq_desc, @@ -337,22 +337,25 @@ struct QuantGemmKernel // Split the X dimension into [wave_tile_count_x, wave_tile_size] // This separates the work into tiles that can be processed by // individual warps/waves - const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4 - const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; // 32/16 x 2 = 4 = 2 + const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4 + const auto wave_tile_size = + ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1 /*QN_B/WarpTileN*/) * + KPerBlockBQ; // 32/16 x 2 = 4 = 2 const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2 - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: " - "%d, wave_tile_count_x: %d\n", - pad_bq_x, - WarpTileN, - NPerBlockBQ, - KPerBlockBQ, - wave_tile_size, - wave_tile_count_x); - } + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: + // " + // "%d, wave_tile_count_x: %d\n", + // pad_bq_x, + // WarpTileN, + // NPerBlockBQ, + // KPerBlockBQ, + // wave_tile_size, + // wave_tile_count_x); + // } const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( bq_pad0_desc, @@ -383,13 +386,16 @@ struct QuantGemmKernel // where merged_outer_dim = bq_y * wave_tile_count_x // This layout facilitates efficient block-to-data mapping const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("pad_wave_size:%d\n", pad_wave_size); - } + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // printf("pad_wave_size:%d\n", pad_wave_size); + // printf("Final bq tensor lengths: %d x %d \n", + // bq_y * wave_tile_count_x, + // pad_wave_size); + // } const auto bq_merge_pad1_desc = transform_tensor_descriptor( bq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 2 + make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 4 make_pass_through_transform(pad_wave_size)), // 64 make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -1115,13 +1121,33 @@ struct QuantGemmKernel if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + constexpr auto block_n = + TilePartitioner::NPerBlock / QuantGroupSize::kN; // 64 / 32 = 2 + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; constexpr auto tile_window_width = ck_tile::integer_least_multiple( - warp_n * bqk_per_block, get_warp_size()); // 128 - constexpr auto tile_window_height = block_n / warp_n; // 2 - auto block_n_idx = i_n / block_n; + warp_n * bqk_per_block, get_warp_size()); // 128 + constexpr auto tile_window_height = + min(block_n, + TilePartitioner::BlockGemmShape::BlockWarps::at( + I1)); // block_n / warp_n; // 2 / 4 = 0 + auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2 + + // if(get_thread_id() == 0) + // { + // printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n"); + // printf("block_id: %d, block_n: %d, warp_n: %d, bqk_per_block: %d, + // block_n_idx: %d, " + // "tile_window_width: %d, tile_window_height: %d, i_n: %d\n", + // get_block_id(), + // static_cast(block_n), + // static_cast(warp_n), + // static_cast(bqk_per_block), + // static_cast(block_n_idx), + // tile_window_width, + // static_cast(tile_window_height), + // static_cast(i_n)); + // } return make_tile_window( bq_pad_view, @@ -1226,15 +1252,15 @@ struct QuantGemmKernel { n = kargs.N; } - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n"); - // To print Tile window after bq_pad0_desc - // bq_block_window.template print_tile_window_range( - // 0, 128, 0, 2, "bq block window"); - bq_block_window.template print_tile_window_range( - 0, 8, 0, 64, "bq block window"); - } + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n"); + // // To print Tile window after bq_pad0_desc + // // bq_block_window.template print_tile_window_range( + // // 0, 128, 0, 2, "bq block window"); + // bq_block_window.template print_tile_window_range( + // 0, 8, 0, 64, "bq block window"); + // } return GemmPipeline{}.template operator()( a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 2c191cc2b4..193494d542 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -196,7 +196,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), - 0) + (PreshuffleQuant) + ? make_array(((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) : is_bq_row_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 67baffaa68..eea34d998a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -245,56 +245,70 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding if constexpr(NPerQ <= WarpGemm::kN) { constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2 - constexpr auto N0 = WarpGemm::kN / NPerQ; //BlockGemmShape::kN / KPerQ; // 1 + constexpr auto N0 = WarpGemm::kN / NPerQ; // BlockGemmShape::kN / KPerQ; // 1 constexpr auto N2 = 1; - constexpr auto NR1 = NPerQ; // 16 + constexpr auto NR1 = NPerQ; // 16 constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*16)=2 - constexpr auto K1 = NWarps; // 4 - constexpr auto K0 = KPerTile / K1; // 1 + constexpr auto K1 = NWarps; // 4 + constexpr auto K0 = KPerTile / K1; // 1 constexpr auto KR = 1; - - return make_static_tile_distribution( tile_distribution_encoding< sequence, tuple, sequence>, - tuple, sequence<0, 2, 0, 2, 0>>, // (Mwarp, Nwarp),(XR0, X0, XR1, X1, YR) + tuple, sequence<0, 2, 0, 2, 0>>, // (Mwarp, Nwarp),(XR0, X0, + // XR1, X1, YR) tuple, sequence<1, 0, 2, 1, 3>>, // (1, 4), (2, 1, 16, 2, 1) sequence<1, 2>, sequence<0, 2>>{}); } - else if constexpr(NPerQ <= WarpGemm::kN * NWarps) + else if constexpr(NPerQ < WarpGemm::kN * NWarps) { - constexpr auto KR = NPerQ / WarpGemm::kN; // Scale replication factor 32/16 = 2 - constexpr auto K1 = NWarps / KR; // Warps per unique scale 4/2 = 2 - constexpr auto K0 = KPerTile / K1; // Iterations to cover N dimension 4/2 = 2 + constexpr auto KR = NPerQ / WarpGemm::kN; // Scale replication factor 64/16 = 4 + constexpr auto K1 = NWarps / KR; // Warps per unique scale 4/4 = 1 + constexpr auto K0 = KPerTile / K1; // Iterations to cover N dimension 4/1 = 4 constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2 - constexpr auto N0 = 1; //NPerQ/WarpGemm::kN; // 2 + constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1 constexpr auto N2 = 1; - constexpr auto NR1 = NPerQ; // 32 + constexpr auto NR1 = NPerQ; // 32 constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1 - if(get_block_id() == 0 && get_thread_id() == 0) - { - // Debug print to verify values - printf("PreshuffleQuant Medium-grained: MWarps: %d, K1=%d, KR=%d, get_warp_size(): %d, K0=%d, N0=%d\n", - MWarps, - K1, - KR, - get_warp_size(), - K0, - N0); - } - + // if(get_block_id() == 0 && get_thread_id() == 0) + // { + // // Debug print to verify values + // printf("PreshuffleQuant Medium-grained: MWarps: %d, K1=%d, KR=%d, + // get_warp_size(): %d, K0=%d, N0=%d\n", + // MWarps, + // K1, + // KR, + // get_warp_size(), + // K0, + // N0); + // } + return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 2, 0, 2>>, - tuple, sequence<1, 0, 2, 1>>, - sequence<1, 2>, - sequence<0, 2>>{}); - + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<1, 0, 2, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2 + constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1 + constexpr auto N2 = 1; + constexpr auto NR1 = 32; // 32 + constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1 + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<2, 0, 3, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); } } else