diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index e54f50a716..ea04d83e29 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -523,8 +523,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + // ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + if(bq_tensor_ptr) + { + BQDataType value = 1.0f; + for(int i = 0; i < BQK; i++) + { + for(int j = 0; j < N / QuantGroupSize::kN; j += (16 / QuantGroupSize::kN)) + { + for(int k = 0; k < 16 / QuantGroupSize::kN; k++) + { + (*bq_tensor_ptr)(i, j + k) = value; + } + value += static_cast(1.0f); + } + } + } + // for(int i = 0; i < BQK; i++) + // { + // for(int j = 0; j < N / QuantGroupSize::kN; j++) + // { + // printf("%.2f ", (*bq_tensor_ptr)(i, j)); + // } + // printf("\n"); + // } } else { diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index c4332fa6aa..f8368eb2f9 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -111,8 +111,8 @@ auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; // 8 - int bqk_ = t.get_lengths()[0]; // 1 + int n_ = t.get_lengths()[1]; // 128 + int bqk_ = t.get_lengths()[0]; // 1 x 128 constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; // 128/16/4 = 2 @@ -120,9 +120,40 @@ auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) GemmConfig::N_Warp, GemmConfig::N_Warp_Tile / group_n, NRepeat, - bqk_}); //{1, 4, 16, 2, 1} + bqk_}); //{1, 4, 16, 2, 1}, group_n:16 {1, 4, 1, 2, 1} std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); //{1, 2, 4, 16, 1} + printf("I am inside bq_permuteN\n"); + printf("t.get_lengths(): %lu, %lu, %lu, %lu, %lu\n", + t_view.get_lengths()[0], + t_view.get_lengths()[1], + t_view.get_lengths()[2], + t_view.get_lengths()[3], + t_view.get_lengths()[4]); + for(int i = 0; i < static_cast(t.get_lengths()[0]); i++) + { + for(int j = 0; j < static_cast(t_view.get_lengths()[1]); j++) + { + for(int k = 0; k < static_cast(t_view.get_lengths()[2]); k++) + { + for(int l = 0; l < static_cast(t_view.get_lengths()[3]); l++) + { + for(int m = 0; m < static_cast(t_view.get_lengths()[4]); m++) + { + printf("t_view[%d][%d][%d][%d][%d]: %f\n", + i, + j, + k, + l, + m, + t_view(i, j, k, l, m)); + } + } + } + } + } + printf("I am inside bq_permuteN\n"); + return ck_tile::reference_permute( + t_view, {0, 3, 1, 2, 4}); // {1, 2, 4, 16, 1}, group_n 16 {1, 2, 4, 1, 1} } template diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index d4c538c872..b5646688a8 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -220,20 +220,21 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg float scale_reg_f = cvt_scale_to_fp32(scale_reg); // if(get_block_id() == 0 && get_thread_id() == 1) //{ - printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d, nIter: " - "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, " - "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n", - get_block_id(), - get_warp_id(), - get_thread_id(), - static_cast(nIter), - NWarp, - WG::kN, - static_cast(QuantGroupSize::kN), - static_cast(KPerBlockBQ), - static_cast(kQScale), - scale_reg_f, - reg_offset); + // printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d, + // nIter: " + // "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, " + // "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n", + // get_block_id(), + // get_warp_id(), + // get_thread_id(), + // static_cast(nIter), + // NWarp, + // WG::kN, + // static_cast(QuantGroupSize::kN), + // static_cast(KPerBlockBQ), + // static_cast(kQScale), + // scale_reg_f, + // reg_offset); //} static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 5a353714f5..9cfdf38d20 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1167,7 +1167,7 @@ struct QuantGemmKernel if(get_block_id() == 0 && get_thread_id() == 0) { bq_block_window.template print_tile_window_range( - 0, 1, 0, 16, "bq block window"); + 0, 1, 0, 128, "bq block window"); } return GemmPipeline{}.template operator()(a_block_window, b_block_window, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 4f792e9de8..9061090132 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC tile_distribution_encoding_pattern_bq; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 6cd8dc3e0f..6fc76e8694 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -169,9 +169,9 @@ struct tile_distribution_encoding_pattern_aq_transposed_c template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -255,16 +255,18 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding else if constexpr(XPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto X1 = NWarps / XR; // Warps per unique scale - constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension + constexpr auto XR = + XPerQ / WarpGemm::kN; // Scale replication factor //16/16 = 1 + constexpr auto X1 = NWarps / XR; // Warps per unique scale //4/1 = 4 + constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension //8/4 = 2 return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<2, 1>, - sequence<0, 0>>{}); + tile_distribution_encoding< + sequence, // 1, 1, 64 + tuple, sequence>, // 1, (2, 4) + tuple, sequence<0>>, //(1, 4, 1) (64) + tuple, sequence<2>>, + sequence<2, 1>, //(2, 1(in Y dimension)) + sequence<0, 0>>{}); } else // XPerQ > WarpGemm::kN * NWarps {