diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 76a6e37f0c..f8549d8816 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -356,15 +356,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase // here from nIter and warp id const index_t n_idx_of_warp = nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN; - const index_t row_index = - n_idx_of_warp / Traits::QuantGroupSize::kN; - if(threadIdx.x == 0) - { - printf("n_idx_of_warp: %d, row_index: %d, kQScale: %d\n", - n_idx_of_warp, - row_index, - kQScale.value); - } + const index_t row_index = n_idx_of_warp / Traits::QuantGroupSize::kN; return row_index * Traits::BQPerBlock + kQScale; } }(); @@ -377,6 +369,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); + // if(threadIdx.x % 64 == 0 && blockIdx.x == 0) + // { + // printf("warp_id: %d, mIter: %d, nIter: %d, kQScale: %d, reg_offset: %d, scale_reg_f: %f\n", + // get_warp_id(), + // mIter.value, + // nIter.value, + // kQScale.value, + // reg_offset, + // scale_reg_f); + // } static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) { c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 129096abdc..7beb21b7d0 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -684,9 +684,10 @@ struct QuantGemmKernel else if constexpr(kQuantType == QuantType::BQuantGrouped) { static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; return make_naive_tensor_view( bq_ptr, - make_tuple(kargs.N, kargs.QK_B), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), make_tuple(kargs.stride_BQ, 1), number{}, number<1>{}); @@ -907,9 +908,9 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; return make_tile_window( bq_pad_view, - make_tuple(number{}, + make_tuple(number{}, number{}), - {i_n, 0}); + {i_n / QuantGroupSize::kN, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index d2c0d5ced8..008e6e585d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -120,6 +120,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN; static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK; static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } @@ -258,7 +259,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV constexpr bool is_b_row_major = std::is_same_v; static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); - static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], "Bq block window has incorrect lengths for defined BqLayout!"); @@ -396,6 +397,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV bq_copy_dram_window, bq_dram_tile_window_step); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // { + // printf("---- pipeline loop %d ----\n", i); + // } block_gemm( c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 0245153019..270579344b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -222,10 +222,10 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding // sequence<1, 2>, // sequence<0, 0>>{}); return make_static_tile_distribution( - tile_distribution_encoding, + tile_distribution_encoding, tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<0>>, + tuple, sequence<0>>, + tuple, sequence<2>>, sequence<1, 2>, sequence<0, 0>>{}); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 9695e67c18..f35337ab11 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -29,7 +29,7 @@ struct GemmConfigBase // Default GEMM tile sizes for tests static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; @@ -399,7 +399,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{-0.5f, 0.5f}(a_m_k); ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); - ck_tile::FillUniformDistribution{0.001f, 0.01f}(bq_bqk_bqn); + // ck_tile::FillUniformDistribution{0.001f, 0.01f}(bq_bqk_bqn); + for (size_t i = 0; i < bq_bqk_bqn.size(); ++i) + { + bq_bqk_bqn.mData[i] = static_cast(0.0001f + 0.0001f * static_cast(i)); + } // Allocate device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); @@ -441,7 +445,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase