diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index f8549d8816..1db04d9e1f 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -352,7 +352,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase else { // FIXME: temporarily the tile distribution replicates all block's - // scales to all threads - need to figure out the index manually + // scales to all threads -> need to calculate the index manually // here from nIter and warp id const index_t n_idx_of_warp = nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN; @@ -369,16 +369,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); - // if(threadIdx.x % 64 == 0 && blockIdx.x == 0) - // { - // printf("warp_id: %d, mIter: %d, nIter: %d, kQScale: %d, reg_offset: %d, scale_reg_f: %f\n", - // get_warp_id(), - // mIter.value, - // nIter.value, - // kQScale.value, - // reg_offset, - // scale_reg_f); - // } + static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) { c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 008e6e585d..779fa0d9f9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -397,10 +397,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV bq_copy_dram_window, bq_dram_tile_window_step); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // { - // printf("---- pipeline loop %d ----\n", i); - // } block_gemm( c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 270579344b..05909c1d9c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -192,7 +192,7 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding { if constexpr(YPerQ == 1) { - // YPerQ == 1 implementation + // YPerQ == 1 implementation - each row of B has independent scale constexpr index_t X = XPerTile; constexpr index_t XR = 2; constexpr index_t Y0 = NIterPerWarp; @@ -211,18 +211,10 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - // YPerQ > 1 implementation + // YPerQ > 1 implementation - each group of YPerQ rows share the same scale // TODO: do not repeat everything to all threads - - // return make_static_tile_distribution( - // tile_distribution_encoding, - // tuple, sequence>, - // tuple, sequence<0, 1>>, - // tuple, sequence<1, 1>>, - // sequence<1, 2>, - // sequence<0, 0>>{}); return make_static_tile_distribution( - tile_distribution_encoding, + tile_distribution_encoding, tuple, sequence>, tuple, sequence<0>>, tuple, sequence<2>>, diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index 89306d8dc6..eb6114d08a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -33,22 +33,22 @@ using AQuantTypes = ::testing::Types< std::tuple, std::tuple, - // // PreshuffleQuant = false && TransposeC = true - // std::tuple, - // std::tuple, - // std::tuple, - // std::tuple, + // PreshuffleQuant = false && TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple, - // // PreshuffleQuant = true && TransposeC = false - // std::tuple, - // std::tuple, - // std::tuple, - // std::tuple, + // PreshuffleQuant = true && TransposeC = false + std::tuple, + std::tuple, + std::tuple, + std::tuple, - // // PreshuffleQuant = true && TransposeC = true - // std::tuple, - // std::tuple, - // std::tuple, + // PreshuffleQuant = true && TransposeC = true + std::tuple, + std::tuple, + std::tuple, std::tuple >; // clang-format on @@ -56,30 +56,34 @@ using AQuantTypes = ::testing::Types< // clang-format off using BQuantTypes = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - // std::tuple, - // std::tuple, - // std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, - // std::tuple, - // std::tuple, - // std::tuple, - // std::tuple // 2d cases with grouping also on the n axis - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on // clang-format off using BPreshuffleBQuantTypes = ::testing::Types< std::tuple, - // std::tuple, - // std::tuple, - // std::tuple, + std::tuple, + std::tuple, + std::tuple, - // std::tuple, - // std::tuple, - // std::tuple, + std::tuple, + std::tuple, + std::tuple, std::tuple >; // clang-format on