diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..d10a871652 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -146,6 +146,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, int main(int argc, char* argv[]) { - int result1 = !run_grouped_gemm_example(argc, argv); + int result1 = !run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..9b013063e9 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -101,6 +101,22 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfigComputeTurbo : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + template struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..671e4f35f8 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -259,15 +259,6 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - AQK = 0; // No A quantization - BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize - if(K % QuantGroupSize::kK != 0) - { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); - } - } stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); @@ -284,12 +275,6 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - stride_AQs[i] = 0; // No A quantization - stride_BQs[i] = - ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); - } a_m_k_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); @@ -311,13 +296,6 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); - bq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); - } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc @@ -349,16 +327,7 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_dev_buf.push_back( std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); - if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::HostTensor b_shuffle_host = - ck_tile::shuffle_b(b_k_n_tensors[i]); - b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); - } - else - { - b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); - } + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); @@ -444,17 +413,6 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::reference_gemm_quant( - a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); - } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -505,6 +463,19 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a QuantMode>( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); + } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); @@ -541,13 +512,6 @@ int run_grouped_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } - else if(quant_mode == "bquant") - { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); - } else { throw std::runtime_error("Unsupported quantization mode!"); @@ -569,13 +533,6 @@ int run_grouped_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } - else if(quant_mode == "bquant") - { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); - } else { throw std::runtime_error("Unsupported quantization mode!");