diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 40a4166126..bcc0fcc044 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -11,9 +11,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_quant.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp - gemm_bquant_quantgrouped_bf8i4.cpp - gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf8.cpp + # gemm_bquant_quantgrouped_bf8i4.cpp + # gemm_bquant_quantgrouped_fp8i4.cpp + # gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb.cpp gemm_bquant_quantgrouped_preshufflequant.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp index 898316fa6b..acfdb92d63 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -9,51 +9,94 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + + // lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", + // "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // using QuantGroupSize = ck_tile::QuantGroupShape>; + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings( + // {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // using QuantGroupSize = ck_tile::QuantGroupShape>; + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings( + // {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // using QuantGroupSize = ck_tile::QuantGroupShape>; + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp index e26edb6501..5d49bd46a7 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp @@ -21,37 +21,39 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + // lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] + // = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] + // = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index 82967d5be2..8771447eed 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -21,39 +21,40 @@ void bquant_quantgrouped_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; + // lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", + // "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings( + // {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; + // lut[hash_multiple_strings( + // {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + // [](const ck_tile::ArgParser& arg_parser) { + // using TypeConfig = decltype(GemmQuantTypeConfig{}); + // return run_gemm_example_prec_type, + // TypeConfig, + // QuantGroupSize, + // ck_tile::QuantType::BQuantGrouped>(arg_parser); + // }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 216e3cdfb8..fd1b87cda2 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -17,9 +17,9 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("h", "false", "Print help message") - .insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") + .insert("m", "128", "m dimension") + .insert("n", "128", "n dimension") + .insert("k", "128", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row or Column") .insert("b_layout", "C", "B tensor data layout - Row or Column") .insert("bq_layout", "C", "Bq tensor data layout - Row or Column") @@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[]) "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " "or bf8i4") - .insert("warmup", "50", "Number of iterations before benchmarking the kernel") - .insert("repeat", "1000", "Number of iterations to benchmark the kernel") + .insert("warmup", "1", "Number of iterations before benchmarking the kernel") + .insert("repeat", "0", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "SplitK value") .insert("device", "0", "Device id that will be used to run the kernel") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("flush_cache", "true", "Flush cache before running the kernel") - .insert("rotating_count", "1000", "Rotating count") + .insert("rotating_count", "0", "Rotating count") .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") .insert("preshuffleb", "false", "Enable preshuffle of tensor B") .insert("preshufflequant", "false", "Enable preshuffle of quant tensor") @@ -91,12 +91,12 @@ void aquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_fp8_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf8i4_instance_factory( - std::unordered_map>& lut); +// void bquant_quantgrouped_bf8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_fp8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_bf8i4_instance_factory( +// std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_preshufflequant_instance_factory( @@ -125,9 +125,9 @@ int main(int argc, char* argv[]) aquant_quantgrouped_instance_factory(lut); aquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_fp8_instance_factory(lut); - bquant_quantgrouped_bf8_instance_factory(lut); - bquant_quantgrouped_fp8i4_instance_factory(lut); - bquant_quantgrouped_bf8i4_instance_factory(lut); + // bquant_quantgrouped_bf8_instance_factory(lut); + // bquant_quantgrouped_fp8i4_instance_factory(lut); + // bquant_quantgrouped_bf8i4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); bquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index fcc5c00327..0d2d5d4d23 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -210,7 +210,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 4389744acf..baf83bfc74 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -481,7 +481,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, + 3.0f /*, fill_seed(gen)*/}(b_k_n); } ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); @@ -543,7 +544,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, b_k_n.SetZero(); bq_tensor_ptr->SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -600,6 +600,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); + printf("b_k_n_dev.get_lengths(): %lu, %lu, %lu, %lu, %lu, %lu, %lu\n", + b_k_n_dev.get_lengths()[0], + b_k_n_dev.get_lengths()[1], + b_k_n_dev.get_lengths()[2], + b_k_n_dev.get_lengths()[3], + b_k_n_dev.get_lengths()[4], + b_k_n_dev.get_lengths()[5], + b_k_n_dev.get_lengths()[6]); } else { @@ -624,8 +632,44 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) { printf("Preshuffle BQ with TiledMMAPermuteN \n"); + for(int i = 0; i < static_cast((*bq_tensor_ptr).get_lengths()[0]); i++) + { + for(int j = 0; j < static_cast((*bq_tensor_ptr).get_lengths()[1]); j++) + { + printf("(*bq_tensor_ptr)[%d][%d]: %f\n", i, j, (*bq_tensor_ptr)(i, j)); + } + } ck_tile::HostTensor bq_permuted_host = ck_tile::bq_permuteN(*bq_tensor_ptr); + printf("bq_permuted_host.get_lengths(): %lu, %lu, %lu, %lu, %lu\n", + bq_permuted_host.get_lengths()[0], + bq_permuted_host.get_lengths()[1], + bq_permuted_host.get_lengths()[2], + bq_permuted_host.get_lengths()[3], + bq_permuted_host.get_lengths()[4]); + for(int i = 0; i < static_cast(bq_permuted_host.get_lengths()[0]); i++) + { + for(int j = 0; j < static_cast(bq_permuted_host.get_lengths()[1]); j++) + { + for(int k = 0; k < static_cast(bq_permuted_host.get_lengths()[2]); k++) + { + for(int l = 0; l < static_cast(bq_permuted_host.get_lengths()[3]); l++) + { + for(int m = 0; m < static_cast(bq_permuted_host.get_lengths()[4]); + m++) + { + printf("bq_permuted_host[%d][%d][%d][%d][%d]: %f\n", + i, + j, + k, + l, + m, + bq_permuted_host(i, j, k, l, m)); + } + } + } + } + } if constexpr(GemmConfig::PreshuffleQuant) { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index ea459417d2..1efd5d18eb 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1197,7 +1197,7 @@ struct tile_window_with_static_lengths using ThreadBuf = thread_buffer; auto buf = tensor_view.template get_vectorized_elements(coord, 0); auto value = buf.at(number<0>{}); // Extract first element from thread buffer - printf(" %s[%d,%d] = %f", label, i, j, type_convert(value)); + printf(" %s[%d,%d] = %f\n", label, i, j, type_convert(value)); } printf("\n"); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index ee56a2f988..967caa1ab7 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -28,7 +28,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg using QuantGroupSize = remove_cvref_t; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); - static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); + // static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -204,14 +204,40 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg } else { - constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + // constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = cvt_scale_to_fp32(scale_reg); + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("scale_reg_f: %f, reg_offset: %d\n", scale_reg_f, reg_offset); + printf("nIter: %d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, " + "KPerBlockBQ: %d, kQScale: %d\n", + static_cast(nIter), + NWarp, + WG::kN, + static_cast(QuantGroupSize::kN), + static_cast(KPerBlockBQ), + static_cast(kQScale)); + } static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * scale_reg_f; + // if(get_block_id() == 0 && get_thread_id() == 0) { + // printf("acc_val: %f, scale_reg_f: %f\n", acc_val, scale_reg_f); + // } + c_ref = c_ref + acc_val * scale_reg_f; }); } }); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 08a0b04942..9de0ce3bf5 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -654,7 +654,10 @@ struct QuantGemmKernel (splitk_batch_offset.splitted_k / TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; - + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("kFlatN: %d, kFlatK: %d\n", kFlatN, kFlatK); + } return make_naive_tensor_view( b_ptr, make_tuple(kFlatN, kFlatK), @@ -989,10 +992,19 @@ struct QuantGemmKernel { static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("TilePartitioner::KPerBlock %d, QuantGroupSize::kK: %d, " + "TilePartitioner::NPerBlock %d, QuantGroupSize::kN: %d\n", + TilePartitioner::KPerBlock, + QuantGroupSize::kK, + TilePartitioner::NPerBlock, + QuantGroupSize::kN); + } return make_tile_window( bq_pad_view, - make_tuple(number{}, - number{}), + make_tuple(number{}, // 1 + number{}), // 16 {0, i_n / QuantGroupSize::kN}); } } @@ -1152,6 +1164,11 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(I3); + if(get_block_id() == 0 && get_thread_id() == 0) + { + bq_block_window.template print_tile_window_range( + 0, 1, 0, 32, "bq block window"); + } return GemmPipeline{}.template operator()(a_block_window, b_block_window, bq_block_window,