From 7d669440a6a7b25ac539648ce77fe5a7ae87a657 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 23 Jun 2025 12:29:15 +0800 Subject: [PATCH 01/21] [CK_TILE] Fix compilation errors introduced in #2320, #2219 and #2214 (#2388) * Fix compilation errors * Fix more ck_tile example compilation errors --- example/ck_tile/02_layernorm2d/generate.py | 20 ++-- example/ck_tile/05_reduce/reduce.hpp | 2 +- example/ck_tile/10_rmsnorm2d/generate.py | 22 ++-- .../ck_tile/12_smoothquant/smoothquant.hpp | 20 ++-- .../14_moe_smoothquant/moe_smoothquant.hpp | 20 ++-- .../17_grouped_gemm/grouped_gemm_tileloop.cpp | 3 + .../core/tensor/tile_window_linear.hpp | 19 ++- .../fused_moe/kernel/fused_moegemm_shape.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 110 +++++++++--------- .../norm_reduce/block/block_norm_reduce.hpp | 4 +- 10 files changed, 112 insertions(+), 110 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 2dc9ccbd77..d77582630a 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp index 50ffb9c1c7..6fbb0b4274 100644 --- a/example/ck_tile/05_reduce/reduce.hpp +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -35,7 +35,7 @@ struct Reduce2dShape static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); static constexpr index_t BlockSize = - WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); }; template ; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); @@ -712,4 +712,4 @@ if __name__ == "__main__": if args.list_blobs: list_blobs(args) else: - gen_blobs(args) \ No newline at end of file + gen_blobs(args) diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp index 265399c276..5f8254a664 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.hpp +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -49,22 +49,22 @@ struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -72,13 +72,13 @@ struct smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index b29295f175..36cf477a42 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -38,22 +38,22 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -61,13 +61,13 @@ struct moe_smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp index 5c0cb92683..4107181520 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -116,9 +116,12 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 56c5066774..596584f3cc 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -314,8 +314,7 @@ struct tile_window_linear constexpr auto tile_dstr = typename Base::TileDstr{}; - auto dst_tensor = - make_static_distributed_tensor(tile_dstr); + auto dst_tensor = make_static_distributed_tensor(tile_dstr); auto issue = [&](auto i_access_) { constexpr auto IAccess = number{}; @@ -348,8 +347,9 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< - typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = + vec_value + .template get_as()[j / Base::Traits::PackedSize]; }); }; @@ -400,8 +400,9 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< - typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = + vec_value + .template get_as()[j / Base::Traits::PackedSize]; }); }; @@ -804,8 +805,7 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - vec_value.template get_as()( - j / Base::Traits::PackedSize) = + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -860,8 +860,7 @@ struct tile_window_linear constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Base::Traits::PackedSize; - vec_value.template get_as()( - j / Base::Traits::PackedSize) = + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp index 336bdc806f..92f6a48648 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp @@ -101,7 +101,7 @@ struct FusedMoeGemmShape static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; - static constexpr index_t BlockSize = WarpSize * NumWarps; + static constexpr index_t BlockSize = get_warp_size() * NumWarps; // some assert static_assert(Block_M0 == Block_M1); diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 3e2e100025..5da675ae42 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -388,7 +388,7 @@ struct MoeSortingKernel } // reduce single pixel within a wave - template + template __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -625,7 +625,7 @@ struct MoeSortingKernel { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 - static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile; + static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile; { index_t eid = tid / experts_per_wave; index_t expert_offset = cumsum[eid] + @@ -693,7 +693,7 @@ struct MoeSortingKernel void* smem) const { const index_t tid = static_cast(threadIdx.x); - const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size()); const index_t lid = __lane_id(); constexpr index_t block_size = 256; // blockDim.x; const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; @@ -798,7 +798,7 @@ struct MoeSortingKernel // NOTE: under this block can never use __syncthreads! int i_e_ = 0; int local_cumsum_ = 0; - for(; i_e_ < num_experts; i_e_ += WarpSize) + for(; i_e_ < num_experts; i_e_ += get_warp_size()) { int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); int local_cnt = smem_cumsum(i_e_ + lid + 1); @@ -843,7 +843,7 @@ struct MoeSortingKernel // cumsum padded in case local cumsum is zero, but // pre_sumsum has value, which will result int // zero local cumsum(but we want at least padded) - wave_cumsum(local_cumsum_); + wave_cumsum(local_cumsum_); if((i_e_ + lid) < num_experts) smem_cumsum(i_e_ + lid + 1) = local_cumsum_; @@ -851,7 +851,7 @@ struct MoeSortingKernel if constexpr(Problem::LocalExpertMasking) { local_masking += pre_cumsum_masking; - wave_cumsum(local_masking); + wave_cumsum(local_masking); if((i_e_ + lid) < num_experts) smem_cumdup(i_e_ + lid + 1) = local_masking; } @@ -861,7 +861,7 @@ struct MoeSortingKernel // than 0(which is not we want) __builtin_amdgcn_s_waitcnt(0xc07f); } - if((lid + i_e_ - WarpSize) == (num_experts - 1)) + if((lid + i_e_ - get_warp_size()) == (num_experts - 1)) { *p_total_tokens_post_pad = local_cumsum_; } @@ -1109,7 +1109,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() return chunk * sizeof(index_t); }; -template +template CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -1504,7 +1504,7 @@ struct MoeSortingMultiPhaseKernel_P1 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return BLOCK_SIZE / WarpSize * sizeof(IndexType); + return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1546,8 +1546,8 @@ struct MoeSortingMultiPhaseKernel_P1 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % WarpSize; - index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % get_warp_size(); + index_t wave_id = threadIdx.x / get_warp_size(); // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P1 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) { c += s[i]; } @@ -1660,7 +1660,7 @@ struct MoeSortingMultiPhaseKernel_P01 // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { - return BLOCK_SIZE / WarpSize * sizeof(IndexType); + return BLOCK_SIZE / get_warp_size() * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1786,8 +1786,8 @@ struct MoeSortingMultiPhaseKernel_P01 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % WarpSize; - index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % get_warp_size(); + index_t wave_id = threadIdx.x / get_warp_size(); // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1801,7 +1801,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++) { c += s[i]; } @@ -1880,7 +1880,7 @@ struct MoeSortingMultiPhaseKernel_P2 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { // return 2 * BLOCK_SIZE * sizeof(IndexType); - return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1905,8 +1905,8 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / WarpSize; - index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / get_warp_size(); + index_t lane_id = threadIdx.x % get_warp_size(); IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -1951,22 +1951,22 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { - s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2083,7 +2083,7 @@ struct MoeSortingMultiPhaseKernel_P3 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType); + return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -2110,8 +2110,8 @@ struct MoeSortingMultiPhaseKernel_P3 } }(); int eid = blockIdx.x; - int wave_id = threadIdx.x / WarpSize; - int lane_id = threadIdx.x % WarpSize; + int wave_id = threadIdx.x / get_warp_size(); + int lane_id = threadIdx.x % get_warp_size(); int e_start = p_expert_cumsum[eid]; int e_end = p_expert_cumsum[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2141,17 +2141,17 @@ struct MoeSortingMultiPhaseKernel_P3 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2196,7 +2196,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) { constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 const index_t expert_cumsum_elem = num_experts_ + 1; - return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int); + return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int); } } // namespace impl @@ -2303,15 +2303,15 @@ struct MoeSortingMultiPhaseKernel_P23 const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); IndexType* p_total_tokens_post_pad = reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / WarpSize; - index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / get_warp_size(); + index_t lane_id = threadIdx.x % get_warp_size(); IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -2356,22 +2356,22 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { - s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2441,13 +2441,13 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType* s = reinterpret_cast(smem); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size(); const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); int eid = blockIdx.x; - int wave_id = threadIdx.x / WarpSize; - int lane_id = threadIdx.x % WarpSize; + int wave_id = threadIdx.x / get_warp_size(); + int lane_id = threadIdx.x % get_warp_size(); int e_start = p_expert_cumsum_smem[eid]; int e_end = p_expert_cumsum_smem[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2518,17 +2518,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2569,17 +2569,17 @@ struct MoeSortingMultiPhaseKernel_P23 cumsum_store += i_show[j]; }); int cumsum = cumsum_store; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2624,17 +2624,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk_1 = x1 - 1; // topk of this token int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not int cumsum = i_show_0 + i_show_1; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == WarpSize - 1) + if(lane_id == get_warp_size() - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 26437c7126..88da6be86e 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync // | w0 | w1 | w2 | w3 | -----> | w0123 | // // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); return num_warps * 4 * thread_buf_size * sizeof(float); } @@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); - constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); const index_t smem_offset = warp_id; // skip if nonthing to do From 61eb622e8590fc7d78aa183e437aec4c32977a66 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 23 Jun 2025 15:53:58 +0800 Subject: [PATCH 02/21] update the way to compute fmha fwd tflop, include mask type (#2386) * update the way to compute fwd tflop, include mask type Signed-off-by: JL-underdog * remove unneccessary comment * add necessary comment * remove some comment --------- Signed-off-by: JL-underdog Co-authored-by: root --- example/ck_tile/01_fmha/fmha_fwd.cpp | 4 ++-- example/ck_tile/01_fmha/mask.hpp | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) mode change 100644 => 100755 example/ck_tile/01_fmha/fmha_fwd.cpp mode change 100644 => 100755 example/ck_tile/01_fmha/mask.hpp diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp old mode 100644 new mode 100755 index bb1f495c4e..8958c0c96e --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -542,8 +542,8 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_k = real_seqlen_k; } - flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + - static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + + static_cast(2) * mask.get_unmaskarea() * hdim_v); num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + sizeof(KDataType) * real_seqlen_k * hdim_q + diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp old mode 100644 new mode 100755 index c77b700b16..b96482f535 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -21,6 +21,8 @@ enum class mask_enum struct mask_info { mask_enum type; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right @@ -42,6 +44,8 @@ struct mask_info ck_tile::index_t x_total = seqlen_k; ck_tile::index_t y_total = seqlen_q; mask_info tmp; + tmp.seqlen_q = seqlen_q; + tmp.seqlen_k = seqlen_k; auto found_0 = str.find(':'); if(found_0 != std::string::npos) { @@ -148,7 +152,22 @@ struct mask_info } return tmp; } - + ck_tile::index_t get_unmaskarea() const + { + if(type == mask_enum::no_mask) + return seqlen_q * seqlen_k; + ck_tile::index_t area = 0; + for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) + { + ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) { mi.serialize(os); From 47ae4b0955582432a667b713865f13ec48a634ed Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 23 Jun 2025 07:24:36 -0700 Subject: [PATCH 03/21] Shard several of the most costly targets. (#2373) * Shard several of the most costly targets. Introduces a filter_tuple_by_modulo to break up tuples. Drops build time of target from 21 minutes to under 14 minutes with 64 build processes, or 11 minutes with 128 build processes. time ninja -j 64 device_grouped_conv3d_fwd_instance * fix clang format * Fix build errors in instantiation code. I wasn't sure how to test the header-only instantiation code on my initial commit. From Jenkins CI test results, I see that there is a test target that depends on these headers: ninja -j 128 test_grouped_convnd_fwd This allowed me to test the build locally. I found three mistakes I made, mostly related to early experiments on I tried on the code. This was hard to find earlier because this PR is really too large. I also discovered that there are five 2D convolution targets that now dominate the compilation time. I will likely address those in a later PR, rather than adding even more changes to this PR. * Fix link errors from mismatched declarations. Our pattern for instantiating MIOpen templates uses duplicate declarations (instead of headers). This is fragile, and I didn't notice that my last commit had a bunch of link errors. I fixed these mistakes, and the bin/test_grouped_conv_fwd test target binary now links correctly. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for the intantiating the kerenels and to generate the calling function. * Shard the longest 2D convolution builds Now that we have automated the shard instantiation, we can shard the 2D convolution targets that take the longest to build. The target test_grouped_conv2d_fwd now compiles in 15 minutes. * Use PROJECT_SOURCE_DIR for submodule compatibility I used CMAKE_SOURCE_DIR to refer to the top-level source directory in the ShardInstantiation.cmake file, but this can cause issues with git submodules. Instead, we should use PROJECT_SOURCE_DIR to ensure compatibility when this project is used as a submodule in another project. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for the intantiating the kerenels and to generate the calling function. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for the intantiating the kerenels and to generate the calling function. * Remove accidental copy of a file * Remove accidental copies of template files. --------- Co-authored-by: illsilin --- .gitignore | 3 + cmake/ShardInstantiation.cmake | 116 ++++++++++++++ cmake/call_shard.in | 15 ++ cmake/instantiate_shard.in | 9 ++ include/ck/utility/filter_tuple.hpp | 66 ++++++++ .../gpu/grouped_convolution_forward_xdl.inc | 3 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 51 ++++++- ...l_ngchw_gkcyx_ngkhw_bf16_comp_instance.in} | 38 ++--- ...wd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in} | 40 ++--- ...fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in} | 64 ++++---- ...gc_gkyxc_nhwgk_int8_mem_inter_instance.cpp | 66 -------- ...wgc_gkyxc_nhwgk_int8_mem_inter_instance.in | 80 ++++++++++ ...gc_gkyxc_nhwgk_int8_mem_intra_instance.cpp | 66 -------- ...wgc_gkyxc_nhwgk_int8_mem_intra_instance.in | 80 ++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 109 +++++++++++-- ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 111 -------------- ...ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in | 143 ++++++++++++++++++ ...ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp | 111 -------------- ..._ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in | 143 ++++++++++++++++++ ...gcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp | 54 ------- ...ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in | 65 ++++++++ ...ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp | 54 ------- ..._ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in | 63 ++++++++ ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 53 ------- ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in} | 53 ++++--- ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 53 ------- ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in} | 53 ++++--- ...w_gkczyx_ngkdhw_bf16_mem_inter_instance.in | 64 ++++++++ ...w_gkczyx_ngkdhw_bf16_mem_intra_instance.in | 65 ++++++++ ...w_gkczyx_ngkdhw_f16_mem_inter_instance.in} | 69 +++++---- ...w_gkczyx_ngkdhw_f16_mem_intra_instance.in} | 75 ++++----- ...w_gkczyx_ngkdhw_f32_mem_inter_instance.in} | 69 +++++---- ...w_gkczyx_ngkdhw_f32_mem_intra_instance.in} | 69 +++++---- 33 files changed, 1346 insertions(+), 827 deletions(-) create mode 100644 cmake/ShardInstantiation.cmake create mode 100644 cmake/call_shard.in create mode 100644 cmake/instantiate_shard.in create mode 100644 include/ck/utility/filter_tuple.hpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in} (53%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in} (71%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in} (64%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in} (64%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in} (64%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in} (57%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in} (59%) diff --git a/.gitignore b/.gitignore index 599ef99e35..e4dd8f7513 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,6 @@ build*/ # Python cache __pycache__/ + +.cache/ + diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake new file mode 100644 index 0000000000..47a5d0c48c --- /dev/null +++ b/cmake/ShardInstantiation.cmake @@ -0,0 +1,116 @@ +# Function to generate templated instantiation functions and caller function. + +# In order to reduce build times, we split the instantiation of template functions into multiple files. +# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, +# which can be placed the TEMPLATE_FILE (typically a .in file). + +# This CMake function generates the instantiation functions and a caller function that calls all the instantiation +# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of +# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, +# and generates a caller function that calls all the instantiation functions. + +# The explicit instatiation pattern requires the use of `extern template` to avoid implicit instantiation +# of the template functions in the caller function, and that code is automatically generated by this function. + +# In addition to the user-supplied template, this CMake function uses two generic templates: +# +# 1. `instantiate_shard.in`: This is the template for the instantiation functions. +# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. + +# This function takes the following arguments: +# +# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCE_NAMES}`). +# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. +# - NUM_SHARDS: The number of shards to generate. +# - OUTPUT_DIR: The build directory where the generated source files will be placed. +# - SRC_LIST: The list of source files to which the generated source files will be added. + + +function(generate_sharded_instantiations) + cmake_parse_arguments( + GEN_SHARDED + # No boolean arguments + "" + # Single-value arguments + "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" + # No multi-value arguments. + "" + ${ARGN} + ) + if (NOT GEN_SHARDED_INSTANCES_NAME) + message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_TEMPLATE_FILE) + message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_NUM_SHARDS) + message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") + endif() + if(NOT GEN_SHARDED_OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_SRC_LIST) + message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") + endif() + + file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR}) + + + set(GENERATED_SOURCE_FILES "") + set(EXTERN_TEMPLATE_STATEMENTS "") + set(CALL_STATEMENTS "") + message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + + set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") + + # Generate the inc file with the template function defintions. + # This include file will hold the template function definitions and a using alias for all the shard + # instantiation functions. + configure_file( + "${GEN_SHARDED_TEMPLATE_FILE}" + "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" + @ONLY + ) + + # Generate the sharded instantiation functions. + # This is where the build parallelization happens. + # Each of these source files will contain a single instantiation function for a shard, + # which will be called sequentially by the caller function. + set(INC_DIR "${GEN_SHARDED_INC_DIR}") + math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") + foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) + set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") + set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") + set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") + configure_file( + "${SHARD_FUNCTION_TEMPLATE}" + "${SHARD_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") + set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") + list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") + list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") + endforeach() + + # Join the include statements, the extern template declarations, and the call statements each + # into a single string for variable substitution in the caller function. + string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") + string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") + string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") + + # Generate the caller function. + set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") + set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") + configure_file( + "${FUNCTION_TEMPLATE}" + "${CALLER_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") + + # Add the generated source files to the list of source files. + # This allows the generated source files to be included in the build. + list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) + set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in new file mode 100644 index 0000000000..daba79b055 --- /dev/null +++ b/cmake/call_shard.in @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { + +@EXTERN_TEMPLATE_STATEMENTS@; + +void add_@INSTANCES@( + @INSTANCES@& instances) { +@CALL_STATEMENTS@; +} + +} // namespace ck::tensor_operation::device::instance diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in new file mode 100644 index 0000000000..dbc0af17a9 --- /dev/null +++ b/cmake/instantiate_shard.in @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { +template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( + @INSTANCES@& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/include/ck/utility/filter_tuple.hpp b/include/ck/utility/filter_tuple.hpp new file mode 100644 index 0000000000..c2e378b879 --- /dev/null +++ b/include/ck/utility/filter_tuple.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck::util { + +template +struct filter_tuple_by_modulo +{ + // Validate Stride and Offset. + static_assert(Stride > 0, "Offset must be positive."); + static_assert(Offset >= 0 && Offset < Stride, + "Offset must be positive and less than the stride."); + + // Generate filtered indices for this stride and offset. + static constexpr int new_size = (std::tuple_size_v + Stride - Offset - 1) / Stride; + + template + static constexpr auto to_index(std::index_sequence) + { + return std::index_sequence<(Offset + Is * Stride)...>{}; + } + + using filtered_indices = decltype(to_index(std::make_index_sequence{})); + + // Helper struct to construct the new tuple type from the filtered indices. + template + struct make_filtered_tuple_type_impl; + + template + struct make_filtered_tuple_type_impl> + { + using type = std::tuple...>; + }; + + using type = typename make_filtered_tuple_type_impl::type; +}; + +// Filter a tuple with a stride and offset. +// +// Tuple is a std::tuple or equivalent +// Stride is a positive integer +// Offset is a positive integer smaller than ofset +// +// Evaluates to a smaller tuple type from elements of T with stride M and offset I. +// +// Can be used to filter a tuple of types for sharded instantiations. +template +using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo::type; + +// Example compile-time test: +// using OriginalTuple = +// std::tuple; +// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t; +// static_assert(std::is_same_v>, +// "Test Case 1 Failed: Every 3rd from 2nd"); + +} // namespace ck::util diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index b018737932..a3f2515099 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -688,7 +688,6 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); - void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard([[maybe_unused]] + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances) { add_device_operation_instances( instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in similarity index 71% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in index 4ca1b2b85e..88c84adfe2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances = std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances_shard( + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwdDefault>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in index e3a12fd5f4..13fb583725 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances = std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances_shard( + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1P0>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1S1P0>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp deleted file mode 100644 index f667481fa4..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in new file mode 100644 index 0000000000..d8b35bda68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances = + std::vector>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances_shard( + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp deleted file mode 100644 index 2ff2c7f51f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in new file mode 100644 index 0000000000..125e16139d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances = + std::vector>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances_shard( + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index f8efa5a7c1..1d9d75a104 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -11,8 +11,6 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -32,23 +30,13 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp +xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp @@ -71,6 +59,99 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) +# Add generated files for sharded instantiations. +include(ShardInstantiation) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in + NUM_SHARDS 8 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl +) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in + NUM_SHARDS 8 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp deleted file mode 100644 index a94f687ef8..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in new file mode 100644 index 0000000000..9d0eba6a6c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances = + std::vector>>; + +template +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp deleted file mode 100644 index 0c63345e7f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in new file mode 100644 index 0000000000..ccabc2090a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances = + std::vector>>; + +template +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + } +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp deleted file mode 100644 index 43241454a5..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in new file mode 100644 index 0000000000..4c67e4912c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp deleted file mode 100644 index d02d9f6778..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in new file mode 100644 index 0000000000..0fbefa3bbc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp deleted file mode 100644 index 060eebebc1..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in index f3eccc7dc8..c87783eed9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp deleted file mode 100644 index 85b088f416..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in index abea0bea81..ca6d571be1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in new file mode 100644 index 0000000000..2586bc0f16 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in new file mode 100644 index 0000000000..7405f86a5f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in index ba5d9fb1de..24d6b66976 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in similarity index 57% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in index fac3098341..91a2444241 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in @@ -3,53 +3,60 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in index 5a2c4a0d5b..7571dff883 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in index 701b8eb4a4..38ed240fab 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance From dbfe70e72a5f2f0317b715cd4c7f7fb662affbe5 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 23 Jun 2025 09:31:46 -0500 Subject: [PATCH 04/21] Add accelerated stochastic rounding on gfx950 (#2355) * Add native prand generation support for gfx950 * Update seed calculation --- include/ck/utility/amd_ck_fp8.hpp | 65 +++++++++++++--- include/ck/utility/mxf8_utils.hpp | 10 ++- include/ck/utility/type_convert.hpp | 114 ++++++++++++++++++---------- 3 files changed, 134 insertions(+), 55 deletions(-) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index d079639c6a..cdc2a4fbda 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" #include "ck/utility/enable_if.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/functional.hpp" #include "ck/utility/type.hpp" @@ -1396,12 +1397,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } return cast_to_f8_from_f32( f, rng); @@ -1416,12 +1423,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) @@ -1487,12 +1500,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f[0]); + rng = prand_generator(reinterpret_cast(&f), f[0]); #else rng = prand_generator(reinterpret_cast(&f), f[0]); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } return cast_to_f8_from_f32( f, rng); @@ -1532,12 +1551,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC rng = prand_generator(reinterpret_cast(&x), x); #else rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_f16(reinterpret_cast(&x), x[0]); #else rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_f16(reinterpret_cast(&x), static_cast(x)); #else rng = prand_generator(reinterpret_cast(&x), static_cast(x)); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_bf16(reinterpret_cast(&x), + rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); #else rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_bf16(reinterpret_cast(&f), f); + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); } return cast_to_f8_from_f32_scaled(f, rng, scale); } @@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const uint32_t rng = 0; if constexpr(stochastic_rounding) { - constexpr int seed = 1254739; - rng = prand_generator(reinterpret_cast(&f), f[0]); + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); } return cast_to_f8_from_f32_scaled(f, rng, scale); } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 5865f1dd78..2208a73860 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/f8_utils.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/mxf4_utils.hpp" #include "ck/utility/mxf6_utils.hpp" #include "ck/utility/random_gen.hpp" @@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) #if defined(__gfx94__) union { @@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) #if defined(__gfx94__) union { @@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 // convert fp32 to fp4 with stochastic rounding inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { uint32_t bitwise; @@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) value.bitwise, float_values.float2_array, rng, scale, 0); return value.f4_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -1475,13 +1491,10 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) // convert vector of 2 fp32 to vector of 2 fp4 with sr inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { uint32_t bitwise; @@ -1499,6 +1512,12 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) #endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION return value.f4x2_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif union { uint32_t bitwise; @@ -1514,13 +1533,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) // convert vector of 32 fp32 to vector of 32 fp4 with sr inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { __uint128_t bitwise; @@ -1546,6 +1562,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f return f4_values.f4x32_array; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif union { __uint128_t bitwise; @@ -1776,13 +1798,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 */ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { float32_t float_vector; @@ -1799,6 +1818,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) return out.f6_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -1815,6 +1840,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) */ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); + return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); +#else constexpr int seed = 1254739; union { @@ -1828,9 +1859,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif -#if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); -#else union { float32_t float_vector; @@ -2044,13 +2072,10 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 */ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { float32_t float_vector; @@ -2067,6 +2092,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) return out.bf6_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -2085,6 +2116,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) */ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); + return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); +#else constexpr int seed = 1254739; union { @@ -2098,9 +2135,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1. uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif -#if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); -#else union { float32_t float_vector; From b8212864cf569b347f26816bfd44a50cadd60e28 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 24 Jun 2025 01:33:31 +0800 Subject: [PATCH 05/21] [CK_TILE] FMHA Support hdim_v to as a Multiple of 32 (#2114) * 160+192 * Add splitkv d160 * cleanup * fix * Add change log * Fix CHANGELOG * Use static_cast * Update ignored instance --------- Co-authored-by: asleepzzz --- CHANGELOG.md | 1 + example/ck_tile/01_fmha/README.md | 1 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 45 +++++++-------- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 5 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 43 ++++----------- include/ck_tile/core/tensor/shuffle_tile.hpp | 7 ++- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 55 +++++++++++++++---- 7 files changed, 89 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 368d1e502d..ab2076c0d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) * Added benchmarking support for tile engine GEMM. * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 12414a20ed..72109a660b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -71,6 +71,7 @@ args: -drop_seed seed for random number generator (default:1) -drop_offset offset for random number generator (default:0) -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) ``` diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 7cbbdb9034..37a1b7329b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -282,18 +282,19 @@ class FmhaFwdApiPool: # TODO: do we need to check duplication? if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.pool[trait.dtype][hdim].append(copy.copy(trait)) @property def api(self) -> str: per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -306,7 +307,7 @@ class FmhaFwdApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -435,18 +436,20 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1), + (192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1), + (256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -454,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! @@ -463,7 +466,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - if hdim == 256: + if hdim == 256 and hdim_v == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) @@ -507,15 +510,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()): + for pipeline in get_pipelines(dtype, hdim, hdim_v): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if hdim == 192 and tile.F_bn1 == 128: + if (hdim, hdim_v) == (192, 128) or hdim == 160: # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3ae0e28be3..2d2d71555d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + # 160: 160, 256: 256 } @@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d '64' : FmhaFwdSplitKVCombineTileSize(32, -1), ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if dtype in ['fp16', 'bf16']: for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): # TODO: use async pipeline when compiler is more stable - if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: + if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]: # if True: pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8958c0c96e..972653c218 100755 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -178,50 +178,30 @@ auto get_elimit(std::string init_method) } } -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if(batch_nhead_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + max_splits = std::min({max_splits, num_SMs}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || - ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) { - efficiency.push_back(0.f); - } - else - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); + max_efficiency = eff; } + efficiency.push_back(eff); } for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) - { - continue; - } if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits chosen = %d\n", num_splits); @@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int override_num_splits_if_necessary( int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) { + (void)hdim_v; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) @@ -250,15 +231,13 @@ int override_num_splits_if_necessary( // tile size should match the generate.py const int kM0 = 64; - const int kN1 = hdim_v; const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); if(num_splits < 1 && p_drop == 0.0f) { return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); } return num_splits; diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index 55e3274cde..84c2b7d2fa 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT // set output vectors static_for<0, num_vec_out, 1>{}([&](auto i) { constexpr auto idx_y_out_tmp = generate_array( - [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, + [&](auto ii) { + return ii == y_dim_vec_in ? static_cast(idx_y_start[ii]) + i + : static_cast(idx_y_start[ii]); + }, number{}); constexpr auto idx_y_out = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 30d07a4754..0b8e5836cd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); - static_assert(kKPack % K3 == 0); + constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K3 = total_pixels / N1; constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - if constexpr(get_warp_size() % (K2 * N0) == 0) + if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible + { + constexpr index_t kNPack = 32; + static_assert(kNPerBlock % kNPack == 0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = 2; + constexpr index_t N1_m = kNPack / N2; + constexpr index_t N0_m = kNPerBlock / kNPack; + constexpr index_t K1 = get_warp_size() / N1_m; + constexpr index_t K2_m = kKPerBlock / K1; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2, 1>>, // K0, K1 N0 + tuple, sequence<1, 1>>, + sequence<1, 2, 1>, // N0 K2 N2 + sequence<0, 2, 2>>{}); + } + else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); @@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); constexpr index_t N0 = kNPerBlock / N1; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - if constexpr(get_warp_size() % (K2 * N0) == 0) + if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible + { + constexpr index_t kNPack = 32; + static_assert(kNPerBlock % kNPack == 0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = 2; + constexpr index_t N1_m = kNPack / N2; + constexpr index_t N0_m = kNPerBlock / kNPack; + constexpr index_t K1 = get_warp_size() / N1_m; + constexpr index_t K2_m = kKPerBlock / K1; + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, // K0, K1 N0 + tuple, sequence<1, 1>>, + sequence<1, 1, 2>, // N0 K2 <-> N2 + sequence<0, 2, 2>>{}); + } + else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); From bb571a033019fd5a8ba6de31119395c3621a4235 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 24 Jun 2025 14:51:29 +0800 Subject: [PATCH 06/21] fix moe i4 bug from aiter (#2339) --- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 4f7b8e768c..29750b8baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -122,7 +122,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< using Base::B_K1; using Base::I0; using Base::I1; - using Base::KGroup; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -154,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack / KGroup; + constexpr index_t K2 = KPack; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat * KGroup; + constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -291,14 +290,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< block_sync_lds(); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -391,15 +388,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -483,14 +477,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -596,7 +588,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), - Sequence<1, 1, 1, 1, 1, KPack / KGroup>, + Sequence<1, 1, 1, 1, 1, KPack>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, From 9e74ae7c8955c2f7f42c8b49bb6c0d01878e671d Mon Sep 17 00:00:00 2001 From: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com> Date: Tue, 24 Jun 2025 16:28:13 +0200 Subject: [PATCH 07/21] Implement batched gemm wmma (RDNA batched gemm) based on wmma cshuffle v3 (#2319) * Some prep work for adding batched_gemm_wmma_universal. Moved batched_gemm in general to gfx11 and gfx12 categories, and split existing batched_gemm test into xdl and wmma versions. Updated profiler and instance factory. For now only adding f16-row-row-row-GemmDefault. For now actual device instance list is empty. * Add DeviceBatchedGemm_Wmma_CShuffleV3 based on DeviceGemm_Wmma_CShuffleV3 and make sure it's used in the instance factory and tests. Currently the new batched device level struct cannot actually handle batching, but it does pass tests with a trivial batch size of 1, meaning that the overall structure is good. * Add custom kernel and Argument type to DeviceBatchedGemm_Wmma_CShuffleV3. Batching arguments not passed to kernel yet. * Implement kernel-level batching logic for DeviceBatchedGemm_Wmma_CShuffleV3. In principle the whole thing works now, just need to add other data types and perhaps do some cleanup. * Add other layouts for batched gemm wmma chufflev3 f16 f16 f16. Now matching XDL (for f16). * Add bf16 bf16 bf16 support for batched gemm wmma cshuffle v3 for all layouts. * Fixup comments and TODOs * Expand test cases for batched gemm wmma cshuffle v3 with more unusual shapes. Some of the original test cases for batched gemm do not work based on cshuffle v3 because the dimensions are too small. * Fix argument order for calls to profile_batched_gemm_impl() ONLY in wmma tests. * Take batching into account when using rotating memory or clearing the C tensor. * Implement small refactors / comments etc. from review. * Port recent gemm wmma updates to batched gemm wmma: V1 pipeline, non-main-k-block-loop, check compute type, packed buffer size calc. Ported new instance lists. * Add MNKPadding instances to batched gemm wmma cshuffle v3, remove incompatible test problems. * Put clearing the C matrix in a pre-process lambda for the non-flush case + small fixups. * Once again switch order of strides and batch strides in calls to profile_batched_gemm_impl() from test_batched_gemm_wmma to match latest definition of that function. --------- Co-authored-by: kiefer --- .../device_batched_gemm_wmma_cshuffle_v3.hpp | 759 ++++++++++++++++++ .../gpu/batched_gemm.hpp | 105 ++- .../gpu/batched_gemm/CMakeLists.txt | 42 +- ...al_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp | 71 ++ ...al_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp | 73 ++ ...al_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp | 76 ++ ...al_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp | 79 ++ ...ersal_f16_f16_f16_gkm_gkn_gmn_instance.cpp | 70 ++ ...ersal_f16_f16_f16_gkm_gnk_gmn_instance.cpp | 72 ++ ...ersal_f16_f16_f16_gmk_gkn_gmn_instance.cpp | 75 ++ ...ersal_f16_f16_f16_gmk_gnk_gmn_instance.cpp | 78 ++ profiler/src/CMakeLists.txt | 4 +- test/batched_gemm/CMakeLists.txt | 9 +- test/batched_gemm/test_batched_gemm_wmma.cpp | 193 +++++ 14 files changed, 1684 insertions(+), 22 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instance.cpp create mode 100644 test/batched_gemm/test_batched_gemm_wmma.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..580a47de14 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,759 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument + karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = batch; + ignore = compute_ptr_offset_of_batch; +#endif +} + +/// @brief \"Universal\" Batched GEMM operation without SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through its design +/// and versatilty. +/// +/// @note This Kernel implementation currently does not support the SplitK algorithm. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). Currently not supported! +template +struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm +{ + // We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and + // permuteB arguments so for now we are not including this functionality. + static_assert(PermuteA == false, + "Permute A functionality not supported by DeviceBatchedGemm operations.\n"); + static_assert(PermuteB == false, + "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by DeviceBatchedGemm base class. + false>; // PermuteB not supported by DeviceBatchedGemm base class. + + // Argument + struct Argument : public GridwiseGemm::Argument + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t Batch_, + index_t k_batch_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} + { + } + + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + // The normal approach to batching would be to increase the grid size by just stretching + // out the grid Z dimension (which is the outermost dimension), but this depends on + // lower level functions not directly using the Z dimension for other calculations. As + // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. + // Therefore, for now we will use the grid Y dimension for batching. This may be a bit + // fragile. + gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + // Note: the grid descriptors and size_a / size_b do *not* take batching into + // account, so we have to manually multiply overall buffer sizes for rotating + // memory by batch. + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.Batch * size_a_buffer, + arg_.Batch * size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + auto clear_workspace = [&]() { + // clear c mem + if(arg.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_c_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // TODO: This is not part of the DeviceBatchedGemm base class but it was part of + // DeviceBatchedGemmV2. Remove? + // index_t GetKPerBlock() override { return KPerBlock; } + // bool GetPermuteA() override { return PermuteA; } + // bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1 /* KBatch */}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1); // KBatch + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemm_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"<>>& + instances); +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances); +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances); +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& + instances); +#endif // CK_ENABLE_FP16 +#ifdef CK_ENABLE_BF16 +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances); +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances); +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances); +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances); +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( std::vector>>& instances); #endif +#endif // CK_USE_XDL + template > op_ptrs; + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances( + op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) @@ -258,6 +360,7 @@ struct DeviceOperationInstanceFactory +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..8ead225c7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..f9e0f610fa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..41ed9bfb3b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..21fee6f321 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..ea9b725286 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..fc0fc45887 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..e67df2cada --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = + std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances{}); + add_device_operation_instances( + instances, + device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index fef09315d5..1e65e9e580 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -67,7 +67,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") list(APPEND PROFILER_OPS profile_gemm_mx.cpp) endif() - list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) @@ -92,6 +91,7 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) + list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) @@ -164,7 +164,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) endif() - list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") @@ -206,6 +205,7 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) + list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 759cf3da67..4c325b2872 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,4 +1,9 @@ -add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp) +add_gtest_executable(test_batched_gemm_xdl test_batched_gemm_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) + target_link_libraries(test_batched_gemm_xdl PRIVATE utility device_batched_gemm_instance) +endif() + +add_gtest_executable(test_batched_gemm_wmma test_batched_gemm_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_wmma PRIVATE utility device_batched_gemm_instance) endif() diff --git a/test/batched_gemm/test_batched_gemm_wmma.cpp b/test/batched_gemm/test_batched_gemm_wmma.cpp new file mode 100644 index 0000000000..18f9db8c39 --- /dev/null +++ b/test/batched_gemm/test_batched_gemm_wmma.cpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_batched_gemm_impl.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + +struct GemmParams +{ + ck::index_t M; + ck::index_t N; + ck::index_t K; + ck::index_t BatchCount; +}; + +class TestBatchedGemm : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + std::vector params; + + template + void Run() + { + using namespace ck::tensor_operation::device; + + bool pass = true; + for(auto& param : params) + { + const auto M = param.M; + const auto N = param.N; + const auto K = param.K; + const auto BatchCount = param.BatchCount; + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = + pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + } + EXPECT_TRUE(pass); + } +}; + +// #ifdef CK_ENABLE_INT8 +// TEST_F(TestBatchedGemm, i8) +// { +// this->params.push_back({64, 64, 64, 2}); +// this->params.push_back({64, 64, 64, 1}); +// this->params.push_back({60, 60, 60, 2}); +// this->params.push_back({68, 68, 68, 2}); +// this->params.push_back({40, 40, 40, 2}); +// this->params.push_back({256, 256, 128, 3}); +// this->template Run(); +// } +// #endif + +#ifdef CK_ENABLE_BF16 +TEST_F(TestBatchedGemm, bf16) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + + // Tests with larger MNK + this->params.push_back({512, 256, 128, 1}); + this->params.push_back({256, 240, 192, 2}); + this->params.push_back({256, 256, 128, 3}); + this->params.push_back({240, 128, 128, 5}); + this->template Run(); +} +#endif + +#ifdef CK_ENABLE_FP16 +TEST_F(TestBatchedGemm, fp16) +{ + this->params.push_back({64, 64, 64, 2}); + this->params.push_back({64, 64, 64, 1}); + this->params.push_back({40, 40, 40, 2}); + this->params.push_back({256, 256, 128, 3}); + + // Tests with larger MNK + this->params.push_back({512, 256, 128, 1}); + this->params.push_back({256, 240, 192, 2}); + this->params.push_back({256, 256, 128, 3}); + this->params.push_back({240, 128, 128, 5}); + this->template Run(); +} +#endif + +// #ifdef CK_ENABLE_FP32 +// TEST_F(TestBatchedGemm, fp32) +// { +// this->params.push_back({64, 64, 64, 2}); +// this->params.push_back({64, 64, 64, 1}); +// this->params.push_back({60, 60, 60, 2}); +// this->params.push_back({68, 68, 68, 2}); +// this->params.push_back({40, 40, 40, 2}); +// this->params.push_back({256, 256, 128, 3}); +// this->template Run(); +// } +// #endif From 42e246e90fa42d7dd745b9e843c62f4d90540af8 Mon Sep 17 00:00:00 2001 From: JonathanLichtnerAMD <195780826+JonathanLichtnerAMD@users.noreply.github.com> Date: Tue, 24 Jun 2025 08:30:42 -0600 Subject: [PATCH 08/21] Fix build error when building with MIOPEN_REQ_LIBS_ONLY=ON (#2383) Co-authored-by: John Shumway --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b0fc725236..6e032a30cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -634,7 +634,7 @@ option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) add_subdirectory(library) -if(NOT GPU_ARCHS AND USER_GPU_TARGETS) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name From 87fdb368a73f1c21c2f556e87981801224c958ef Mon Sep 17 00:00:00 2001 From: JonathanLichtnerAMD <195780826+JonathanLichtnerAMD@users.noreply.github.com> Date: Tue, 24 Jun 2025 08:32:16 -0600 Subject: [PATCH 09/21] Do not build "other" library for MIOpen (#2382) MIOpen only needs the static CK library for convolutions. --- library/src/tensor_operation_instance/gpu/CMakeLists.txt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index dbd503c0bd..aea3359aff 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -295,11 +295,8 @@ FOREACH(subdir_path ${dir_list}) if(MIOPEN_REQ_LIBS_ONLY) message(STATUS "Removing all sources that are not required for MIOpen") - if("${cmake_instance}" MATCHES "gemm" OR - "${cmake_instance}" MATCHES "mha" OR - "${cmake_instance}" MATCHES "contraction" OR - "${cmake_instance}" MATCHES "reduce") - set(add_inst 0) + if(NOT "${cmake_instance}" MATCHES "conv") + set(add_inst 0) endif() endif() @@ -328,7 +325,7 @@ ENDFOREACH() -if(CK_DEVICE_OTHER_INSTANCES) +if(CK_DEVICE_OTHER_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) add_library(device_other_operations ${CK_DEVICE_OTHER_INSTANCES}) add_library(composablekernels::device_other_operations ALIAS device_other_operations) set_target_properties(device_other_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) From 77123600ee4b6fae077a2145b68b00a8b2ce9460 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Tue, 24 Jun 2025 20:45:24 +0600 Subject: [PATCH 10/21] Improve fmha_bwd tests performance (#2376) * Avoid passing indices (std::vector) by value to host tensor's operator() Each access requires 2 allocations and copies of the vector. * Remove 1 unneeded vector copy from the slowest part of fmha_bwd's verification * Compute ds_hp_host_ref in parallel This sequntial ForEach is the slowest part of validation and it benefits from parallel computation. * Do not use ForEach for simple copy and conversion of large tensors These tensors all have the same shape {nhead, real_seqlen_q, real_seqlen_k} and can be copied/converted without complex computations of linear indices. --- example/ck_tile/01_fmha/fmha_bwd.cpp | 47 +++++++++------------- include/ck/library/utility/host_tensor.hpp | 6 +-- include/ck_tile/host/host_tensor.hpp | 9 +++-- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index eaf99529f3..3b9cf09eb2 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_bwd.hpp" #include "ck_tile/host.hpp" @@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(p_drop > 0) { - p_hp_host_ref.ForEach( - [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + p_dropped_hp_host_ref = p_hp_host_ref; randval_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); }); ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); } else { - p_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); } // O = P * V @@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser) } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { - AccDataType do_dot_o = 0; - for(int o = 0; o < hdim_v; o++) - { - auto idx_gmo = idx_gmn; - idx_gmo[2] = o; - do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * - ck_tile::type_convert(o_host_refs[wb](idx_gmo)); - } - self(idx_gmn) = ck_tile::type_convert( - p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); - }); + ck_tile::make_ParallelTensorFunctor( + [&](auto i0, auto i1, auto i2) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + } + ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( + p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + }, + ds_hp_host_ref.mDesc.get_lengths()[0], + ds_hp_host_ref.mDesc.get_lengths()[1], + ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); if(use_dbias) { - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + dbias_host_ref = ds_hp_host_ref.template CopyAsType(); } - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 06e33afd20..286dffc36c 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -167,7 +167,7 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } - std::size_t GetOffsetFromMultiIndex(std::vector iss) const + std::size_t GetOffsetFromMultiIndex(const std::vector& iss) const { return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } @@ -600,12 +600,12 @@ struct Tensor ck::packed_size_v>]; } - T& operator()(std::vector idx) + T& operator()(const std::vector& idx) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } - const T& operator()(std::vector idx) const + const T& operator()(const std::vector& idx) const { return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index deaa158d50..b8c764809c 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -230,7 +230,7 @@ struct HostTensorDescriptor * @param iss Vector containing the multi-dimensional indices * @return The calculated linear offset as a size_t */ - std::size_t GetOffsetFromMultiIndex(std::vector iss) const + std::size_t GetOffsetFromMultiIndex(const std::vector& iss) const { return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } @@ -540,9 +540,12 @@ struct HostTensor return mData[GetOffsetFromMultiIndex(is...)]; } - T& operator()(std::vector idx) { return mData[GetOffsetFromMultiIndex(idx)]; } + T& operator()(const std::vector& idx) + { + return mData[GetOffsetFromMultiIndex(idx)]; + } - const T& operator()(std::vector idx) const + const T& operator()(const std::vector& idx) const { return mData[GetOffsetFromMultiIndex(idx)]; } From 778ac24376813d18e63c9f77a2dd51cf87eb4a80 Mon Sep 17 00:00:00 2001 From: JiaLuo-CAN Date: Tue, 24 Jun 2025 12:13:18 -0400 Subject: [PATCH 11/21] add a mx_fp8 client example (#2380) * add a mx_fp8 client example * remove verify code and fix date * remove verify code and fix date, type --------- Co-authored-by: root Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: Andriy Roshchenko --- client_example/32_gemm_mx/CMakeLists.txt | 4 + client_example/32_gemm_mx/gemm_mx_fp8.cpp | 330 ++++++++++++++++++++++ client_example/README.md | 2 + 3 files changed, 336 insertions(+) create mode 100644 client_example/32_gemm_mx/CMakeLists.txt create mode 100644 client_example/32_gemm_mx/gemm_mx_fp8.cpp diff --git a/client_example/32_gemm_mx/CMakeLists.txt b/client_example/32_gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..558986bf5a --- /dev/null +++ b/client_example/32_gemm_mx/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx950") + add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp) + target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/32_gemm_mx/gemm_mx_fp8.cpp b/client_example/32_gemm_mx/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..6e14bf2a5f --- /dev/null +++ b/client_example/32_gemm_mx/gemm_mx_fp8.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::half_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; +template +inline constexpr bool is_same_v = ck::is_same::value; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AScaleLayout = Row; +using BScaleLayout = Col; + +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + mem_size_ = mem_size; + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mem_size_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; + + /* Require by mx type*/ + constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + /* Scale stride Calculation */ + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + return static_cast(col); + else + return static_cast(row); + } + else + return static_cast(stride); + }; + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize; + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + SimpleDeviceMem a_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + SimpleDeviceMem b_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + using DeviceOp = + ck::tensor_operation::device::DeviceGemmMX; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + + sizeof(XDataType) * N * K / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/README.md b/client_example/README.md index d9f793434d..34c6733d05 100644 --- a/client_example/README.md +++ b/client_example/README.md @@ -14,8 +14,10 @@ cd client_example/build cmake \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \ +-D GPU_TARGETS="gfx908;gfx90a" \ .. ``` +You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s). ### Build client example ```bash From c5d9181e1bd8c64110941e244b3d3e1e6c5f6385 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Wed, 25 Jun 2025 07:35:54 +0800 Subject: [PATCH 12/21] Fix unmatched K size of WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution on gfx950 (#2393) --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index be5d5690ff..f243aceda8 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -172,7 +172,7 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32>>; #else using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl Date: Tue, 24 Jun 2025 21:46:15 -0700 Subject: [PATCH 13/21] Fix amd_ck_fp8.hpp macro definitions (#2325) * Fix amd_ck_fp8.hpp macro definitions 1. Define CK_USE_FNUZ_FP8 and CK_USE_OCP_FP8 definitions only if they were not defined before. 2. Prefix __assert_fnuz_support and __assert_ocp_support with namespace fp8_impl to avoid redefined error when building with rocm 6.4+ (rocm/6.4.0/include/hip/amd_detail/amd_hip_fp8.h) Co-authored-by: Andriy Roshchenko --- include/ck/utility/amd_ck_fp8.hpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index cdc2a4fbda..b7af32d3dc 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -10,15 +10,11 @@ #include "ck/utility/functional.hpp" #include "ck/utility/type.hpp" -#ifdef CK_USE_FNUZ_FP8 -#define CK_USE_FNUZ_FP8 1 -#else +#ifndef CK_USE_FNUZ_FP8 #define CK_USE_FNUZ_FP8 0 #endif -#ifdef CK_USE_OCP_FP8 -#define CK_USE_OCP_FP8 1 -#else +#ifndef CK_USE_OCP_FP8 #define CK_USE_OCP_FP8 0 #endif @@ -432,7 +428,7 @@ __host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a) namespace fp8_impl { // Assertions to check for supported conversion types -#define __assert_ocp_support(interp) \ +#define __fp8_impl_assert_ocp_support(interp) \ { \ if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \ interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \ @@ -440,7 +436,7 @@ namespace fp8_impl { __hip_assert(false && "type is unsupported by current target device"); \ } \ } -#define __assert_fnuz_support(interp) \ +#define __fp8_impl_assert_fnuz_support(interp) \ { \ if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \ interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \ @@ -454,10 +450,10 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) { #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ #if CK_USE_OCP_FP8 - __assert_ocp_support(interp); + __fp8_impl_assert_ocp_support(interp); #endif #if CK_USE_FNUZ_FP8 - __assert_fnuz_support(interp); + __fp8_impl_assert_fnuz_support(interp); #endif #endif } From 50fad035248b154cdfa4505cf5de7465ce146149 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 25 Jun 2025 15:19:21 +0800 Subject: [PATCH 14/21] [CK_TILE] Add missing parameter 'min_seqlen_q' to the FMHA fwd kernel MakeKargs() interface (#2403) * Rename batch_prerfill interface * Add min_seqlen_q parameter in MakeKargs() --- example/ck_tile/01_fmha/fmha_fwd.hpp | 170 ++++++++--------- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 174 +++++++++--------- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 5 +- 3 files changed, 176 insertions(+), 173 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5ce56d48b5..15b028fa9f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -715,102 +715,102 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, #if 0 // we assume page_block_size=1 for now args.kv_last_page_lens, args.page_block_size, #endif - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_k, - args.batch_stride_v, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, #if 0 // we assume page_block_size=1 for now args.kv_last_page_lens, args.page_block_size, #endif - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 7472c82114..0d0959ba27 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -316,56 +316,56 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + int32_t num_total_pages, + const void* kv_indptr, + const void* kv_page_indices, #if 0 // we assume page_block_size=1 for now const void* kv_last_page_lens, ck_tile::index_t page_block_size, #endif - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -468,51 +468,51 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + int32_t num_total_pages, + const void* kv_indptr, + const void* kv_page_indices, #if 0 // we assume page_block_size=1 for now const void* kv_last_page_lens, ck_tile::index_t page_block_size, #endif - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fe426f925e..6dc014c9de 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -808,6 +808,7 @@ struct FmhaFwdKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple& drop_seed_offset) @@ -847,7 +848,7 @@ struct FmhaFwdKernel window_size_left, window_size_right, mask_type, - 0, // min_seqlen_q + min_seqlen_q, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); @@ -890,6 +891,7 @@ struct FmhaFwdKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple& drop_seed_offset) @@ -929,6 +931,7 @@ struct FmhaFwdKernel window_size_left, window_size_right, mask_type, + min_seqlen_q, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); From 37e1a2753702f003b751425502e037f2384aaa5f Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 25 Jun 2025 16:07:45 +0800 Subject: [PATCH 15/21] [CK_TILE] Refine fp8 support in flatmm (#2239) * [CK_TILE] Refine fp8 in flatmm 1. Replace USING_MFMA_16x16x32 & USING_MFMA_16x16x32 with constexpr 2. Add an additional const check to avoid build error in HotLoopScheduler 3. Refine shuffleb to support both tile 32x32 and 16x16 4. Support command option -init 5. Move Gemm warp defintion to a separate struct * fix clang format * fix clang format * keep default bhavior unchanged (warp tile = 16x16) * fix tile engine build error * fix a typo in codegen_utils.py * address review comments * address review comments --------- Co-authored-by: Thomas Ning --- example/ck_tile/18_flatmm/CMakeLists.txt | 2 - example/ck_tile/18_flatmm/flatmm_basic.cpp | 44 +++++-- example/ck_tile/18_flatmm/flatmm_basic.hpp | 109 +++++++++------- .../ck_tile/18_flatmm/run_flatmm_example.inc | 91 +++++++++----- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 119 +++++++++--------- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 115 ++++++++++------- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 13 +- tile_engine/ops/gemm/codegen_utils.py | 3 + tile_engine/ops/gemm/gemm_instance_builder.py | 11 +- tile_engine/ops/gemm/gemm_profiler.hpp | 4 +- 10 files changed, 313 insertions(+), 198 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 58e06f3c0f..6d6b71ea18 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -3,6 +3,4 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -Wno-unused-local-typedef) -#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -Wno-unused-local-typedef) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 8782d2bb6a..f96f558101 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -17,12 +17,12 @@ template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) { - using FlatmmConfig = FlatmmConfig; using CodegenFlatmmShape = ck_tile::TileFlatmmShape< ck_tile::sequence, ck_tile::sequence, @@ -32,18 +32,20 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; - const auto Run = [&](const auto memory_operation_) { + + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using GemmEpilogue = ck_tile::CShuffleEpilogue< @@ -151,6 +153,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con } } +template