From c2d17ec24fbb028b6f8040a7e63be827890be8d2 Mon Sep 17 00:00:00 2001 From: solin Date: Wed, 14 May 2025 09:56:13 +0000 Subject: [PATCH] draft for 16*16*128 fp8 --- example/ck_tile/18_flatmm/CMakeLists.txt | 3 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 18 ++++++++-- example/ck_tile/18_flatmm/flatmm_basic.hpp | 17 ++++++++- .../ck_tile/18_flatmm/run_flatmm_example.inc | 35 ++++++------------- .../core/arch/amd_buffer_addressing.hpp | 8 ++--- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 2 ++ ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 25 +++++++++---- 7 files changed, 69 insertions(+), 39 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index f4d823e91a..b169de450b 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -3,6 +3,7 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 5f2c2a5aab..5159c8584b 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -27,7 +27,19 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con constexpr bool kPadK = false; constexpr int kBlockPerCu = 2; +#if defined(USING_MFMA_16x16x128) + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 256; + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 128; +#endif // This part comes from the Codegen #if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16) constexpr ck_tile::index_t M_Tile = 128; @@ -151,11 +163,11 @@ int run_flatmm_example(int argc, char* argv[]) { if(data_type == "fp16") { - run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + //run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else if(data_type == "bf16") { - run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + //run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else if(data_type == "fp8") { @@ -163,7 +175,7 @@ int run_flatmm_example(int argc, char* argv[]) } else if(data_type == "bf8") { - run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + //run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else { diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index bbce978724..f550f5a04d 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -31,6 +31,21 @@ #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif +template +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; +}; template struct GemmBasicTypeConfig; @@ -122,7 +137,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index c191fff7d0..e4ccd74385 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -33,37 +33,28 @@ static constexpr inline auto is_row_major(Layout layout_) // mfma_type, 0:32x32, 1:16x16 template -auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type) +auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + + if constexpr(GemmConfig::N_Warp_Tile == 32) { - ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + ck_tile::HostTensor t_view( + {n_ / 32, 32, k_ / GemmConfig::K_Warp_Tile, 2, GemmConfig::K_Warp_Tile / 2}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + else { - ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + static_assert(GemmConfig::N_Warp_Tile == 16); + ck_tile::HostTensor t_view( + {n_ / 16, 16, k_ / GemmConfig::K_Warp_Tile, 4, GemmConfig::K_Warp_Tile / 4}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0) - { - ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1) - { - ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - return t; } template @@ -189,12 +180,8 @@ int run_flatmm_example_with_layouts(int argc, // do pre-shuffle std::string mfma = arg_parser.get_str("prec"); -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) - ck_tile::index_t mfma_type = 1; -#else - ck_tile::index_t mfma_type = 0; -#endif - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type); + + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host); ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 5d6d6ce348..124af4586b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -38,10 +38,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); + // r.x = __builtin_amdgcn_readfirstlane(r.x); + // r.y = __builtin_amdgcn_readfirstlane(r.y); + // r.z = __builtin_amdgcn_readfirstlane(r.z); + // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index cbd20a6ea3..3074818bb5 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -75,6 +75,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { +#if 0 #if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16) constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); @@ -148,6 +149,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA }); __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA +#endif #endif } diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 1a1b729394..f071ce838b 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -19,11 +19,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { using namespace ck_tile; -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) +#if (defined(USING_MFMA_16x16x32) || defined(USING_MFMA_16x16x128)) && defined(ENABLE_FP8) + using ADataType = remove_cvref_t; /*reduce transform layers,compare with old ck*/ constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); + constexpr index_t KPack = Problem::VectorLoadSize / sizeof(ADataType); constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -138,6 +139,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy return Problem::VectorLoadSize / sizeof(typename Problem::ADataType); } + template + CK_TILE_HOST_DEVICE static constexpr auto GetK1() + { + using TileShape = typename Problem::BlockGemmShape; + if constexpr(TileShape::WarpTile::at(TileShape::idxN) == 32) + { + return TileShape::WarpTile::at(TileShape::idxK) / 2; + } + else + { + static_assert(TileShape::WarpTile::at(TileShape::idxN) == 16); + return TileShape::WarpTile::at(TileShape::idxK) / 4; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { @@ -232,16 +248,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { - using BDataType = remove_cvref_t; - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = - Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt + constexpr index_t KBPerLoad = GetK1(); // dwordx4 load B elem cnt constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = 1;