From c245d569d5423dba88e70cff77114e3509661e7b Mon Sep 17 00:00:00 2001 From: "BingYuan.Zhou" Date: Fri, 21 Mar 2025 13:01:14 +0800 Subject: [PATCH] fix ck_tile/basic_gemm build error (#1988) [ROCm/composable_kernel commit: 5a0d693b8648b48d9e2c30ac0a25d52e0d2c8969] --- example/ck_tile/03_gemm/gemm_basic.cpp | 0 .../03_gemm/script/benchmark_basic_bf16.sh | 14 ++++++++ .../03_gemm/script/benchmark_basic_bf8.sh | 14 ++++++++ .../ck_tile/03_gemm/script/run_full_test.sh | 3 ++ include/ck_tile/ops/common/utils.hpp | 1 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 35 ++++++++++++------- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 4 +-- 7 files changed, 57 insertions(+), 14 deletions(-) mode change 100644 => 100755 example/ck_tile/03_gemm/gemm_basic.cpp mode change 100644 => 100755 include/ck_tile/ops/common/utils.hpp mode change 100644 => 100755 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh index e69de29bb2..d7e5d4640a 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh index e69de29bb2..466f6bb4e1 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/run_full_test.sh b/example/ck_tile/03_gemm/script/run_full_test.sh index 2448acbad2..12ea6f0bf8 100755 --- a/example/ck_tile/03_gemm/script/run_full_test.sh +++ b/example/ck_tile/03_gemm/script/run_full_test.sh @@ -32,6 +32,9 @@ function print_log_header(){ } # run verification tests +for dtype in fp16 bf16 fp8 bf8; do + example/ck_tile/03_gemm/script/benchmark_basic_$dtype.sh +done example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh # run performance benchmarks diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp old mode 100644 new mode 100755 index 8592f93e0f..b422a0a896 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -18,6 +18,7 @@ template <> struct typeToStr { static constexpr const char * name = "bf1 template <> struct typeToStr { static constexpr const char * name = "fp8"; }; template <> struct typeToStr { static constexpr const char * name = "bf8"; }; template <> struct typeToStr { static constexpr const char * name = "int8"; }; +template <> struct typeToStr { static constexpr const char * name = "pk_int4"; }; // clang-format on template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 2a10389ce6..217408fffa 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -12,7 +12,7 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template +template struct GemmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; @@ -182,11 +182,11 @@ struct GemmPipelineAGmemBGmemCRegV1 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(std::is_same_v) + if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegBlockDistribution()); - shuffle_tile(a_shuffle_tmp, a_block_tile); + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp); } @@ -196,11 +196,11 @@ struct GemmPipelineAGmemBGmemCRegV1 } // LDS write 0 - if constexpr(std::is_same_v) + if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDistribution()); - shuffle_tile(b_shuffle_tmp, b_block_tile); + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); store_tile(b_copy_lds_window, b_block_tile_tmp); } @@ -229,15 +229,26 @@ struct GemmPipelineAGmemBGmemCRegV1 move_tile_window(b_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, a_block_tile); + store_tile(a_copy_lds_window, + tile_elementwise_in(a_element_func, a_shuffle_tmp_loop)); + } + else + { + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } // LDS write i + 1 - if constexpr(std::is_same_v) + if constexpr(is_b_row_major) { auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDistribution()); - shuffle_tile(b_shuffle_tmp_loop, b_block_tile); + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, b_block_tile); store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp old mode 100644 new mode 100755 index c7115c8eb4..6bb14af9e6 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -129,7 +129,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t KPack = GetSmemPackA(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * M0)) + if constexpr(get_warp_size() >= (K2 * M0)) { constexpr index_t K1 = get_warp_size() / (K2 * M0); constexpr index_t K0 = BlockSize / get_warp_size(); @@ -219,7 +219,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t KPack = GetSmemPackB(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) + if constexpr(get_warp_size() >= (K2 * N0)) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = BlockSize / get_warp_size();