diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index bcc0fcc044..40a4166126 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -11,9 +11,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_quant.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp - # gemm_bquant_quantgrouped_bf8i4.cpp - # gemm_bquant_quantgrouped_fp8i4.cpp - # gemm_bquant_quantgrouped_bf8.cpp + gemm_bquant_quantgrouped_bf8i4.cpp + gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb.cpp gemm_bquant_quantgrouped_preshufflequant.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp index f8667712d4..9f214fb34e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -33,19 +33,6 @@ void bquant_quantgrouped_preshuffleb_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", @@ -73,43 +60,158 @@ void bquant_quantgrouped_preshuffleb_instance_factory( ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - // lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", - // "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // using QuantGroupSize = ck_tile::QuantGroupShape>; - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; - // lut[hash_multiple_strings( - // {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // using QuantGroupSize = ck_tile::QuantGroupShape>; - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; - // lut[hash_multiple_strings( - // {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // using QuantGroupSize = ck_tile::QuantGroupShape>; - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp index 5d49bd46a7..e26edb6501 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp @@ -21,39 +21,37 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - // lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; - // lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] - // = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; - // lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] - // = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index 8771447eed..82967d5be2 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -21,40 +21,39 @@ void bquant_quantgrouped_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; - // lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", - // "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; - // lut[hash_multiple_strings( - // {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; - // lut[hash_multiple_strings( - // {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - // [](const ck_tile::ArgParser& arg_parser) { - // using TypeConfig = decltype(GemmQuantTypeConfig{}); - // return run_gemm_example_prec_type, - // TypeConfig, - // QuantGroupSize, - // ck_tile::QuantType::BQuantGrouped>(arg_parser); - // }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index fd1b87cda2..216e3cdfb8 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -17,9 +17,9 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("h", "false", "Print help message") - .insert("m", "128", "m dimension") - .insert("n", "128", "n dimension") - .insert("k", "128", "k dimension") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row or Column") .insert("b_layout", "C", "B tensor data layout - Row or Column") .insert("bq_layout", "C", "Bq tensor data layout - Row or Column") @@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[]) "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " "or bf8i4") - .insert("warmup", "1", "Number of iterations before benchmarking the kernel") - .insert("repeat", "0", "Number of iterations to benchmark the kernel") + .insert("warmup", "50", "Number of iterations before benchmarking the kernel") + .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "SplitK value") .insert("device", "0", "Device id that will be used to run the kernel") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("flush_cache", "true", "Flush cache before running the kernel") - .insert("rotating_count", "0", "Rotating count") + .insert("rotating_count", "1000", "Rotating count") .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") .insert("preshuffleb", "false", "Enable preshuffle of tensor B") .insert("preshufflequant", "false", "Enable preshuffle of quant tensor") @@ -91,12 +91,12 @@ void aquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_fp8_instance_factory( std::unordered_map>& lut); -// void bquant_quantgrouped_bf8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_fp8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_bf8i4_instance_factory( -// std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_preshufflequant_instance_factory( @@ -125,9 +125,9 @@ int main(int argc, char* argv[]) aquant_quantgrouped_instance_factory(lut); aquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_fp8_instance_factory(lut); - // bquant_quantgrouped_bf8_instance_factory(lut); - // bquant_quantgrouped_fp8i4_instance_factory(lut); - // bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); bquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index ea04d83e29..9c1273069e 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -457,9 +457,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant) { - bq_tensor_ptr = - std::make_unique>(ck_tile::host_tensor_descriptor( - BQK, N / QuantGroupSize::kN, stride_BQ, is_row_major(bq_layout))); // 1x8 + bq_tensor_ptr = std::make_unique>( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { @@ -482,12 +481,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else { - ck_tile::FillUniformDistribution{-2.0f, - 3.0f /*, fill_seed(gen)*/}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); } - ck_tile::FillUniformDistribution{-2.0f, - 2.0f /*, fill_seed(gen)*/}(*bq_tensor_ptr); - ck_tile::FillUniformDistribution{-5.0f, 5.0f /*, fill_seed(gen)*/}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { @@ -524,30 +522,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillConstant{static_cast(0x38)}(a_m_k); ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - // ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - if(bq_tensor_ptr) - { - BQDataType value = 1.0f; - for(int i = 0; i < BQK; i++) - { - for(int j = 0; j < N / QuantGroupSize::kN; j += (16 / QuantGroupSize::kN)) - { - for(int k = 0; k < 16 / QuantGroupSize::kN; k++) - { - (*bq_tensor_ptr)(i, j + k) = value; - } - value += static_cast(1.0f); - } - } - } - // for(int i = 0; i < BQK; i++) - // { - // for(int j = 0; j < N / QuantGroupSize::kN; j++) - // { - // printf("%.2f ", (*bq_tensor_ptr)(i, j)); - // } - // printf("\n"); - // } + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); } else { @@ -620,18 +595,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PreshuffleB) { - if constexpr(GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::TiledMMAPermuteN && + QuantGroupSize::kN == 1) // temporarily only for non-grouped quant { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); - printf("b_k_n_dev.get_lengths(): %lu, %lu, %lu, %lu, %lu, %lu, %lu\n", - b_k_n_dev.get_lengths()[0], - b_k_n_dev.get_lengths()[1], - b_k_n_dev.get_lengths()[2], - b_k_n_dev.get_lengths()[3], - b_k_n_dev.get_lengths()[4], - b_k_n_dev.get_lengths()[5], - b_k_n_dev.get_lengths()[6]); } else { @@ -653,47 +621,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) { - if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN && + QuantGroupSize::kN == 1) // temporarily only for non-grouped quant { - printf("Preshuffle BQ with TiledMMAPermuteN \n"); - for(int i = 0; i < static_cast((*bq_tensor_ptr).get_lengths()[0]); i++) - { - for(int j = 0; j < static_cast((*bq_tensor_ptr).get_lengths()[1]); j++) - { - printf("(*bq_tensor_ptr)[%d][%d]: %f\n", i, j, (*bq_tensor_ptr)(i, j)); - } - } + printf("PreshuffleBQuant with TiledMMAPermuteN\n"); ck_tile::HostTensor bq_permuted_host = ck_tile::bq_permuteN(*bq_tensor_ptr, QuantGroupSize::kN); - printf("bq_permuted_host.get_lengths(): %lu, %lu, %lu, %lu, %lu\n", - bq_permuted_host.get_lengths()[0], - bq_permuted_host.get_lengths()[1], - bq_permuted_host.get_lengths()[2], - bq_permuted_host.get_lengths()[3], - bq_permuted_host.get_lengths()[4]); - for(int i = 0; i < static_cast(bq_permuted_host.get_lengths()[0]); i++) - { - for(int j = 0; j < static_cast(bq_permuted_host.get_lengths()[1]); j++) - { - for(int k = 0; k < static_cast(bq_permuted_host.get_lengths()[2]); k++) - { - for(int l = 0; l < static_cast(bq_permuted_host.get_lengths()[3]); l++) - { - for(int m = 0; m < static_cast(bq_permuted_host.get_lengths()[4]); - m++) - { - printf("bq_permuted_host[%d][%d][%d][%d][%d]: %f\n", - i, - j, - k, - l, - m, - bq_permuted_host(i, j, k, l, m)); - } - } - } - } - } if constexpr(GemmConfig::PreshuffleQuant) { @@ -708,12 +641,16 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(GemmConfig::PreshuffleQuant) { + printf("PreshuffleBQuant without TiledMMAPermuteN\n"); ck_tile::HostTensor bq_shuffle_host = ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK); bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); } else + { + printf("No PreshuffleBQuant\n"); bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); + } } invoke_gemm; auto buf = tensor_view.template get_vectorized_elements(coord, 0); auto value = buf.at(number<0>{}); // Extract first element from thread buffer - printf(" %s[%d,%d] = %f\n", label, i, j, type_convert(value)); + printf(" %s[%d,%d] = %f", label, i, j, type_convert(value)); } printf("\n"); } diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index f8368eb2f9..1ebb5e4f5c 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -111,49 +111,17 @@ auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; // 128 - int bqk_ = t.get_lengths()[0]; // 1 x 128 - constexpr int NRepeat = - GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; // 128/16/4 = 2 + int n_ = t.get_lengths()[1]; + int bqk_ = t.get_lengths()[0]; + constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; ck_tile::HostTensor t_view({n_ / (GemmConfig::N_Tile / group_n), GemmConfig::N_Warp, GemmConfig::N_Warp_Tile / group_n, NRepeat, - bqk_}); //{1, 4, 16, 2, 1}, group_n:16 {1, 4, 1, 2, 1} + bqk_}); std::copy(t.begin(), t.end(), t_view.begin()); - printf("I am inside bq_permuteN\n"); - printf("t.get_lengths(): %lu, %lu, %lu, %lu, %lu\n", - t_view.get_lengths()[0], - t_view.get_lengths()[1], - t_view.get_lengths()[2], - t_view.get_lengths()[3], - t_view.get_lengths()[4]); - for(int i = 0; i < static_cast(t.get_lengths()[0]); i++) - { - for(int j = 0; j < static_cast(t_view.get_lengths()[1]); j++) - { - for(int k = 0; k < static_cast(t_view.get_lengths()[2]); k++) - { - for(int l = 0; l < static_cast(t_view.get_lengths()[3]); l++) - { - for(int m = 0; m < static_cast(t_view.get_lengths()[4]); m++) - { - printf("t_view[%d][%d][%d][%d][%d]: %f\n", - i, - j, - k, - l, - m, - t_view(i, j, k, l, m)); - } - } - } - } - } - printf("I am inside bq_permuteN\n"); - return ck_tile::reference_permute( - t_view, {0, 3, 1, 2, 4}); // {1, 2, 4, 16, 1}, group_n 16 {1, 2, 4, 1, 1} + return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); } template diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index b5646688a8..2e4f60d0b2 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg using QuantGroupSize = remove_cvref_t; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); - // static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -215,35 +214,13 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg return nIter * KPerBlockBQ + kQScale; } }(); - // constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = cvt_scale_to_fp32(scale_reg); - // if(get_block_id() == 0 && get_thread_id() == 1) - //{ - // printf("get_block_id(): %d, get_warp_id(): %d, get_thread_id(): %d, - // nIter: " - // "%d, NWarp: %d, WG::kN: %d, QuantGroupSize::kN: %d, " - // "KPerBlockBQ: %d, kQScale: %d, scale_reg_f: %f, reg_offset: %d\n", - // get_block_id(), - // get_warp_id(), - // get_thread_id(), - // static_cast(nIter), - // NWarp, - // WG::kN, - // static_cast(QuantGroupSize::kN), - // static_cast(KPerBlockBQ), - // static_cast(kQScale), - // scale_reg_f, - // reg_offset); - //} static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - // if(get_block_id() == 0 && get_thread_id() == 0) { - // printf("acc_val: %f, scale_reg_f: %f\n", acc_val, scale_reg_f); - // } - c_ref = c_ref + acc_val * scale_reg_f; + c_ref = c_ref + acc_val * scale_reg_f; }); } }); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 9cfdf38d20..4175dcff3f 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -653,10 +653,6 @@ struct QuantGemmKernel (splitk_batch_offset.splitted_k / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("kFlatN: %d, kFlatK: %d\n", kFlatN, kFlatK); - } return make_naive_tensor_view( b_ptr, make_tuple(kFlatN, kFlatK), @@ -991,20 +987,10 @@ struct QuantGemmKernel { static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - if(get_block_id() == 0 && get_thread_id() == 0) - { - printf("TilePartitioner::KPerBlock %d, QuantGroupSize::kK: %d, " - "TilePartitioner::NPerBlock %d, QuantGroupSize::kN: %d\n", - TilePartitioner::KPerBlock, - QuantGroupSize::kK, - TilePartitioner::NPerBlock, - QuantGroupSize::kN); - } return make_tile_window( bq_pad_view, - make_tuple(number{}, // 1 - number{}), // 128/16 = 8 + make_tuple(number{}, + number{}), {0, i_n / QuantGroupSize::kN}); } } @@ -1164,11 +1150,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(I3); - if(get_block_id() == 0 && get_thread_id() == 0) - { - bq_block_window.template print_tile_window_range( - 0, 1, 0, 128, "bq block window"); - } return GemmPipeline{}.template operator()(a_block_window, b_block_window, bq_block_window, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 9061090132..4f792e9de8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC tile_distribution_encoding_pattern_bq; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 6fc76e8694..6cd8dc3e0f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -169,9 +169,9 @@ struct tile_distribution_encoding_pattern_aq_transposed_c template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -255,18 +255,16 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding else if constexpr(XPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto XR = - XPerQ / WarpGemm::kN; // Scale replication factor //16/16 = 1 - constexpr auto X1 = NWarps / XR; // Warps per unique scale //4/1 = 4 - constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension //8/4 = 2 + constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto X1 = NWarps / XR; // Warps per unique scale + constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension return make_static_tile_distribution( - tile_distribution_encoding< - sequence, // 1, 1, 64 - tuple, sequence>, // 1, (2, 4) - tuple, sequence<0>>, //(1, 4, 1) (64) - tuple, sequence<2>>, - sequence<2, 1>, //(2, 1(in Y dimension)) - sequence<0, 0>>{}); + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); } else // XPerQ > WarpGemm::kN * NWarps {