diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index f4d823e91a..919641302f 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -3,6 +3,6 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 5f2c2a5aab..d5c0190d1a 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -16,6 +16,7 @@ template @@ -29,32 +30,19 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con constexpr int kBlockPerCu = 2; // This part comes from the Codegen -#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16) - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 128; + static_assert(sizeof(ADataType) == 2 || sizeof(ADataType) == 1); + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 4; - constexpr ck_tile::index_t K_Warp = 1; + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 16 : 32; - constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 16 : 32; - constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 64 : 16; + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; -#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 128; - - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 8; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 32 : 32; - constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 32 : 32; - constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 32 : 16; -#endif using CodegenFlatmmShape = ck_tile::TileFlatmmShape, ck_tile::sequence, @@ -134,6 +122,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con #include "run_flatmm_example.inc" +#if 0 int run_flatmm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -176,5 +165,5 @@ int run_flatmm_example(int argc, char* argv[]) } return -1; } - +#endif int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index bbce978724..a3691ec395 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -31,7 +31,133 @@ #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif -template +template +struct GemmConfig +{ + #if 0 + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = sizeof(DataType) == 2 ? 64 : 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64; + #endif + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = sizeof(DataType) == 2 ? 64 : 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64; +}; + +template <> +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + +template <> +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + +template <> +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + +template <> +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 8; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + +template <> +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 8; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + +template <> +struct GemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 8; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + + +template struct GemmBasicTypeConfig; template <> @@ -126,7 +252,8 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); + .insert("split_k", "1", "splitK value") + .insert("cfg_ver", "0", "gemm config version"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 15a9df2c0c..76180b9812 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -24,38 +24,28 @@ static constexpr inline auto is_row_major(Layout layout_) } // mfma_type, 0:32x32, 1:16x16 -template -auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type) +template +auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + if constexpr(GemmConfig::N_Warp_Tile == 32) { - ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + ck_tile::HostTensor t_view( + {n_ / 32, 32, k_ / GemmConfig::K_Warp_Tile, 2, GemmConfig::K_Warp_Tile / 2}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + else { - ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + static_assert(GemmConfig::N_Warp_Tile == 16); + ck_tile::HostTensor t_view( + {n_ / 16, 16, k_ / GemmConfig::K_Warp_Tile, 4, GemmConfig::K_Warp_Tile / 4}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0) - { - ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1) - { - ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - return t; } template @@ -83,6 +73,7 @@ template @@ -112,7 +103,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = flatmm_calc( + float ave_time = + flatmm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -129,7 +121,11 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, return ave_time; } -template @@ -143,11 +139,6 @@ int run_flatmm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; - ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -182,17 +173,11 @@ int run_flatmm_example_with_layouts(int argc, c_rslt_host.SetZero(); // do pre-shuffle - std::string mfma = arg_parser.get_str("prec"); -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) - ck_tile::index_t mfma_type = 1; -#else - ck_tile::index_t mfma_type = 0; -#endif - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host); ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); - invoke_flatmm( + invoke_flatmm( a_dev_buf, b_shuffle_dev_buf, c_dev_buf, @@ -219,8 +204,9 @@ int run_flatmm_example_with_layouts(int argc, a_host, b_origin_host, c_ref_host); const float max_accumulated_value = *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_rslt_host, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, c_ref_host, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -277,8 +263,9 @@ int run_flatmm_example_with_layouts(int argc, c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); const float max_accumulated_value = *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_rslt_host, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -292,3 +279,200 @@ int run_flatmm_example_with_layouts(int argc, return pass; } + +int run_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + ck_tile::index_t cfg_ver = arg_parser.get_int("cfg_ver"); + + if(a_layout == "R" && b_layout == "C") + { + #if 0 + if(data_type == "fp16") + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + #endif + if(data_type == "fp8") + { + if (cfg_ver == 0) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if (cfg_ver == 1) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if (cfg_ver == 2) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if (cfg_ver == 3) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if (cfg_ver == 4) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if (cfg_ver == 5) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if (cfg_ver == 6) + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported config version for this operation !!!"); + } + } + #if 0 + else if(data_type == "bf8") + { + using Types = GemmBasicTypeConfig; + using ADataType = Types::ADataType; + using BDataType = Types::BDataType; + using AccDataType = Types::AccDataType; + using CDataType = Types::CDataType; + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + #endif + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 0b9956cd01..1bd1edd7b3 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -29,10 +29,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); + // r.x = __builtin_amdgcn_readfirstlane(r.x); + // r.y = __builtin_amdgcn_readfirstlane(r.y); + // r.z = __builtin_amdgcn_readfirstlane(r.z); + // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 2ff9d1ebf0..13740ba9eb 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -7,6 +7,8 @@ #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" +template struct Debug; + namespace ck_tile { template @@ -75,6 +77,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { +#if 0 constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -90,6 +93,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; + + Debug xx0; + Debug> xx1; // constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num; #if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { @@ -147,6 +153,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA }); __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA +#endif + #endif } diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 474924ec84..d28ef49f77 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -49,20 +49,22 @@ struct UniversalFlatmmPipelineAgBgCrPolicy return a_lds_block_desc; #elif defined(USING_MFMA_32x32x16) + using ADataType = remove_cvref_t; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = GetSmemPackA(); + //constexpr index_t kKPack = GetSmemPackA(); + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), - number{}, + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * K1>{}, number{}, number<1>{}), + number{}, number<1>{}); constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / K1, K1))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -138,6 +140,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy return Problem::VectorLoadSize / sizeof(typename Problem::ADataType); } + template + CK_TILE_HOST_DEVICE static constexpr auto GetK1() + { + using TileShape = typename Problem::BlockGemmShape; + if constexpr(TileShape::WarpTile::at(TileShape::idxN) == 32) + { + return TileShape::WarpTile::at(TileShape::idxK) / 2; + } + else + { + static_assert(TileShape::WarpTile::at(TileShape::idxN) == 16); + return TileShape::WarpTile::at(TileShape::idxK) / 4; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { @@ -189,7 +206,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } else { - constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); constexpr index_t K0 = KPerBlock / K1; constexpr index_t M2 = get_warp_size() / K0; // coalesce reading for each blocks @@ -232,19 +249,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { - using BDataType = remove_cvref_t; - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = - Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt - constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KBPerLoad = GetK1(); + constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1;