From e6bcd192d432561642d45ea5b1c759d6f80ace2a Mon Sep 17 00:00:00 2001 From: ZheWang <35656954+eeezio@users.noreply.github.com> Date: Mon, 2 Feb 2026 16:04:40 +0800 Subject: [PATCH 1/9] Mx fp6 flatmm (#3601) * add fp6 data-type and support sync/async dwordx3 load/store * clang-format * pre-commit * 1st commit * default mnk pass ut * fix a distrubution * fix * fix bdram distr * update * pass ut * improve perf * update * clean code * resolve copilot comment * reslove comment * clang-format --------- Co-authored-by: ZheWang --- example/ck_tile/18_flatmm/CMakeLists.txt | 1 - .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 18 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp | 32 ++++ .../18_flatmm/mxgemm/mx_flatmm_instance.cmake | 3 +- .../mxgemm/mx_flatmm_instance.cpp.in | 1 + .../18_flatmm/mxgemm/run_mx_flatmm.inc | 51 ++++-- include/ck_tile/core.hpp | 1 + .../core/arch/amd_buffer_addressing.hpp | 11 +- .../arch/amd_buffer_addressing_builtins.hpp | 61 ++++++- include/ck_tile/core/numeric/pk_fp6.hpp | 109 ++++++++++++ include/ck_tile/core/numeric/type_convert.hpp | 1 + include/ck_tile/core/numeric/vector_type.hpp | 34 ++++ include/ck_tile/core/tensor/buffer_view.hpp | 35 +++- include/ck_tile/host/check_err.hpp | 53 ++++++ .../ck_tile/host/reference/reference_gemm.hpp | 22 +++ include/ck_tile/ops/common/utils.hpp | 1 + ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 90 +++++----- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 109 ++++++++---- .../warp/warp_gemm_attribute_mfma_impl.hpp | 3 +- test/ck_tile/memory_copy/test_copy.cpp | 101 ++++++++--- test/ck_tile/memory_copy/test_copy.hpp | 160 +++++++++++++++++- 21 files changed, 761 insertions(+), 136 deletions(-) create mode 100644 include/ck_tile/core/numeric/pk_fp6.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 7451ee25b0..d77e3c9322 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -21,7 +21,6 @@ if(has_supported_gpu) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") - add_executable(tile_example_flatmm_basic flatmm_basic.cpp) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index d6c84f3064..1141717545 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor& src) const int K = src_lengths[0]; const int N = src_lengths[1]; constexpr int packed_size = ck_tile::numeric_traits::PackedSize; - int KPack = 16 * packed_size; // fp4:32 or fp8:16 - int NLane = N_Warp_Tile; - int KLane = 64 / NLane; - int K0 = K / (KLane * KPack); + int KPack = + std::is_same_v ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16 + int NLane = N_Warp_Tile; + int KLane = 64 / NLane; + int K0 = K / (KLane * KPack); ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1})); @@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[]) } else if(mx_prec == "fp6" || mx_prec == "fp6xfp6") { - throw std::runtime_error("fp6xfp6 is not supported."); + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); } else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index 0b6185590f..d4922bb44c 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct MXfp6_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + struct MXfp8_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake index 5e86cd7133..9250dbe7ae 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake @@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST) set(C_LAYOUT ROW) set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16") + set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16") set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16") # foreach(PERSISTENT false true) # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions. foreach(PERSISTENT false) - foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8) + foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8) set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}}) string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE}) list(GET DATA_TYPE_AB 0 A_DATA_TYPE) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in index 9675d3345b..e6d612f0d6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -19,6 +19,7 @@ using FP4 = ck_tile::pk_fp4_t; using FP8 = ck_tile::fp8_t; +using FP6 = ck_tile::pk_fp6x16_t; using FP16 = ck_tile::fp16_t; using BF16 = ck_tile::bf16_t; diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index b4d1fe237b..54c23e2266 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc, M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout))); ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); + if constexpr(std::is_same_v) + { + auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); + auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); + std::vector random_bufA(a_buffer_bytes); + std::vector random_bufB(b_buffer_bytes); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(1, 4); - if(init_method == 0) - { - ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); - } - else if(init_method == 1) - { - ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + for(size_t i = 0; i < a_buffer_bytes; ++i) + random_bufA[i] = static_cast(dis(gen)); + + for(size_t i = 0; i < b_buffer_bytes; ++i) + random_bufB[i] = static_cast(dis(gen)); + + memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); + memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); } else { - throw std::runtime_error("wrong! Unexpected init_method"); + if(init_method == 0) + { + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + } + else + { + throw std::runtime_error("wrong! Unexpected init_method"); + } } const auto b_shuffled_host = preShuffleWeight(b_origin_host); diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 438e44f5f1..91212292d2 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,6 +54,7 @@ #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 8f9dd30bda..3a7231f71d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1457,6 +1457,15 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index bdc0daaed2..e26ac2e600 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1134,6 +1134,25 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32"); + +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3(dwordx3_union vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset) +{ + int32x3_t v_reg; + v_reg[0] = vdata.as_i32[0]; + v_reg[1] = vdata.as_i32[1]; + v_reg[2] = vdata.as_i32[2]; + llvm_amdgcn_raw_buffer_store_i32x3_(v_reg, rsrc, voffset, soffset, 0); +}; + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, int32x4_t rsrc, @@ -1290,7 +1309,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1330,6 +1349,18 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + dwordx3_union ret; + ret.as_i32[0] = tmp[0]; + ret.as_i32[1] = tmp[1]; + ret.as_i32[2] = tmp[2]; + return bit_cast(ret); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, @@ -1411,15 +1442,19 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && (N == 1)), "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1750,7 +1785,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer(coherence)); } + else if constexpr(N == 12) + { + llvm_amdgcn_raw_buffer_store_i32x3(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); + } else if constexpr(N == 16) { llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), @@ -1859,10 +1901,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + std::is_same::value && (N == 1), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp new file mode 100644 index 0000000000..0de61f6b1f --- /dev/null +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" + +namespace ck_tile { +template +struct pk_fp6_t +{ + static constexpr index_t num_bits_elem = 6; + using element_type = int32_t; // element storage fundamental type + static constexpr index_t packed_size = pk_size; + static constexpr index_t num_bits_vec_elem = + sizeof(element_type) * 8; // 32-bit uint for storage + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + element_type data_[vector_size]; // packed data + using type = pk_fp6_t; + CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0) + { + for(size_t i = 0; i < vector_size; ++i) + { + data_[i] = value; + } + } + CK_TILE_HOST_DEVICE void pack(const int32_t x, const index_t i) + { + int32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + int32_t old_value = data_[arr_index]; + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) + { + int32_t next_value = data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data_[arr_index + 1] = next_value; + } + } + + template + CK_TILE_HOST_DEVICE static int32_t unpack(const T& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + int32_t bits = pk.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); + } + + return bits & 0x3F; + } + + CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); } + + CK_TILE_HOST_DEVICE int32_t operator[](index_t i) const { return data_[i]; } + + CK_TILE_HOST_DEVICE static float fp6_e2m3_to_float(int32_t fp6_bits) + { + fp6_bits = fp6_bits & 0x3F; + + uint32_t sign = (fp6_bits >> 5) & 0x1; // bit 5 + uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3 + uint32_t mantissa = fp6_bits & 0x7; // bits 2-0 + + float result; + if(exponent == 0 && mantissa == 0) + { + result = 0.f; + } + else if(exponent != 0) + { + result = std::exp2f(static_cast(exponent) - 1); + float mantissa_value = 1.0f + mantissa / 8.0f; + result *= mantissa_value; + } + else + { + result = mantissa / 8.0f; + } + return sign == 1 ? -1 * result : result; + } +}; + +using pk_fp6x16_t = pk_fp6_t<16>; +using pk_fp6x32_t = pk_fp6_t<32>; +template <> +struct numeric_traits +{ + static constexpr int PackedSize = 16; +}; +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index deaa9e0bd9..634b845725 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -72,6 +72,7 @@ CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) } // namespace ck_tile #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index def054f415..756bc7f6fc 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -160,6 +160,40 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16))); using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x64_t = int32_t __attribute__((ext_vector_type(64))); +struct int32x3_tt +{ + int32_t data[3]; +}; + +struct int32x6_tt +{ + int32_t data[6]; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 12; + using value_type = int32x3_tt; + using type = int32x3_tt; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 1; + using value_type = int32x3_tt; + using type = int32x3_tt; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 2; + using value_type = int32x6_tt; + using type = int32x6_tt; +}; + // u32 // using uint32_t = ... using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index f3aeed6e61..59f82939b9 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -303,7 +303,6 @@ struct buffer_view>::scalar_type, - scalar_per_t_vector * scalar_per_x_vector>; - // using buf_t = ushort __attribute__((ext_vector_type(8))); - auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); - return bit_cast(rtn); + constexpr index_t load_elts = scalar_per_t_vector * scalar_per_x_vector; + if constexpr(load_elts == 12 && sizeof(typename X::value_type) == 1) + { + auto rtn = reinterpret_cast(p_data_) + (i + linear_offset) / 4; + struct + { + int32_t x, y, z; + } tmp = {rtn[0], rtn[1], rtn[2]}; + return bit_cast(tmp); + } + else + { + using buf_t = ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); + return bit_cast(rtn); + } #endif } else @@ -968,6 +979,7 @@ struct buffer_view, int8x16_t> && std::is_same_v, int8x16_t>) || // int8 on thread buffer (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || @@ -1033,6 +1045,11 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else if constexpr(std::is_same_v, thread_buffer>) + { + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } else if constexpr((std::is_same_v, int8_t> && std::is_same_v, int8x16_t>) || (std::is_same_v, int8_t> && @@ -1075,6 +1092,12 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else + { + static_assert(false, + "wrong! not implemented for this combination, please add " + "implementation"); + } } } else diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 2ba3b1e7c3..a2f6728316 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -720,4 +720,57 @@ std::enable_if_t<(std::is_same_v, ranges::range_val return err_count == 0; } +/** + * @brief Check errors between pk_fp6x16_t ranges + * + * Compares two ranges of pk_fp6x16_t without tolerance. + * This specialization handles ck_tile::pk_fp6x16_t type. + * + * @tparam Range Type of output range + * @tparam RefRange Type of reference range + * @param out Output range to check + * @param ref Reference range to check against + * @param msg Error message to display if check fails + * @return True if check passes, false otherwise + */ +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, pk_fp6x16_t>), + bool> + CK_TILE_HOST check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double = 0) +{ + if(check_size_mismatch(out, ref, msg)) + return false; + + int err_count = 0; + float max_err = 0.0f; + auto update_err = [&](float o, float r, std::size_t index) { + if(std::fabs(o - r) > 1e-8) + { + std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o + << " != " << r << std::endl; + ++err_count; + max_err = max_err < std::fabs(o - r) ? o : max_err; + } + }; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const pk_fp6x16_t o = *std::next(std::begin(out), i); + const pk_fp6x16_t r = *std::next(std::begin(ref), i); + for(std::size_t j = 0; j < numeric_traits::PackedSize; j++) + { + update_err(o.unpack(j), r.unpack(j), i * numeric_traits::PackedSize + j); + } + } + if(err_count > 0) + { + report_error_stats(err_count, max_err, ref.size()); + } + return err_count == 0; +} + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 7830150b63..da6b074b98 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -625,6 +625,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, a_m_k_scaled(m, k) = a_f4_lo * a_scale; a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto a_scale = ck_tile::type_convert(scale_a(m, k / ScaleBlockSize)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + a_m_k_scaled(m, k + k_) = + pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale; + } + } else { a_m_k_scaled(m, k) = @@ -653,6 +664,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, b_k_n_scaled(k, n) = b_f4_lo * b_scale; b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto b_scale = ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + b_k_n_scaled(k + k_, n) = + pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale; + } + } else { b_k_n_scaled(k, n) = diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 425083a9de..4a30e3af16 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -22,6 +22,7 @@ template <> struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp6x16"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index bc7d2323d0..23d7a9fca9 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -118,8 +118,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 + ? 16 + : 16 /*dwordx4*/ * APackedSize / sizeof(ADataType); + static constexpr index_t BK1 = std::is_same_v + ? 16 + : 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload @@ -537,24 +542,26 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_store_lds_window_pong = make_tile_window( // a_lds_block_pong, - make_tuple(number{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); // ping-pong window for A LDS - auto a_warp_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); - auto a_warp_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_ping = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_pong = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); // B flat DRAM window for load @@ -621,7 +628,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { @@ -663,7 +670,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, MIterPerWarp> @@ -683,7 +690,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); __builtin_amdgcn_sched_barrier(0); @@ -750,7 +758,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -760,7 +768,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); @@ -839,7 +848,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_pong, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished @@ -849,7 +858,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); }; @@ -874,7 +884,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 0); } - // TAIL if constexpr(TailNum == TailNumber::Even) { @@ -933,7 +942,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -947,7 +956,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); Last2ndHotLoopScheduler(); @@ -977,12 +987,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_pong, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, + number>{}); } }); LastHotLoopScheduler(); @@ -1014,12 +1024,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_ping, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, + number>{}); } }); LastHotLoopScheduler(); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 34d18cb8e1..7cf6326dfd 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -17,6 +17,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr index_t kDramLoadPackBytes = 128; static constexpr index_t DWORDx4 = 16; + static constexpr index_t DWORDx3 = 12; static constexpr int MXdlPack = 2; static constexpr int NXdlPack = 2; @@ -77,15 +78,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution() { - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize + constexpr index_t K2 = std::is_same_v ? DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + constexpr index_t K0 = + KPerBlock / APackedSize * sizeof(ADataType) / (K1 * K2); // KPerBlock/256/packsize constexpr index_t M2 = WaveSize / K1; // 8 constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + static_assert(K0 * K1 * K2 == KPerBlock / APackedSize * sizeof(ADataType), "K0, K1, K2 must cover whole KPerBlock!"); return make_static_tile_distribution( @@ -107,9 +109,9 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - const index_t K0 = cols / (K1 * K2 * APackedSize); + constexpr index_t K2 = std::is_same_v ? DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + const index_t K0 = cols / (K1 * K2 / sizeof(ADataType) * APackedSize); const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds @@ -138,19 +140,23 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, desc); - auto&& origin_tmp = window_tmp.get_window_origin(); + auto&& origin_tmp = window_tmp.get_window_origin(); + constexpr index_t test1 = APackedSize / sizeof(ADataType); return make_tile_window(byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / APackedSize}, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / test1}, MakeMX_ABytesDramTileDistribution()); } CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor() { - constexpr index_t K2 = AK1 / APackedSize; // 16 - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - constexpr index_t K0 = KPerBlock / (K1 * AK1); // KPerBlock/256 - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + constexpr index_t K2 = std::is_same_v ? DWORDx3 : AK1 / APackedSize; + constexpr index_t K2_Pad = 16; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + constexpr index_t K0 = std::is_same_v + ? KPerBlock / (K1 * K2 / sizeof(ADataType) * APackedSize) + : KPerBlock / (K1 * AK1); // KPerBlock/256 + static_assert(K0 * K1 * K2 / sizeof(ADataType) * APackedSize == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); constexpr index_t M3 = 4; // so that we can use imm offset to load lds @@ -169,12 +175,12 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - number{}, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, number<1>{}), number{}, number<1>{}); @@ -216,7 +222,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - if constexpr(K_Thread == AK1) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -225,7 +231,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<0, 2>>, sequence<2>, sequence<1>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< sequence, @@ -235,6 +241,19 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<1, 2>>, sequence<2, 2>, sequence<0, 2>>{}); + else if constexpr(std::is_same_v) + // K_Lane=4, K_Thread=32 + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<2, 2>, + sequence<1, 2>>{}); + else + static_assert(false, "unsupported datatype"); } CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() @@ -245,17 +264,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - if constexpr(BK1 == K_Thread) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 1 64 32 + sequence>, // 1 64 16 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -265,6 +284,21 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>{}); + else if constexpr(std::is_same_v) + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, // 4 2 + sequence>, // 64 1 2 12 + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<2, 2>, + sequence<2, 3>>{}); + else + static_assert(false, "unsupported datatype"); } template @@ -280,21 +314,27 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple( - flat_n, flat_k / flat_k_per_block, number{})), + make_naive_tensor_descriptor_packed( + make_tuple(flat_n, + flat_k / flat_k_per_block, + number{})), make_tuple(make_pass_through_transform(flat_n), make_merge_transform_v3_division_mod(make_tuple( - flat_k / flat_k_per_block, number{}))), + flat_k / flat_k_per_block, + number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, byte_tensor_desc); auto&& origin_tmp = window_tmp.get_window_origin(); + auto origin_n = origin_tmp[0]; + auto origin_k = static_cast(origin_tmp[1] * sizeof(BDataType) / BPackedSize); return make_tile_window( byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / BPackedSize}, + make_tuple(number{}, + number{}), + {origin_n, origin_k}, MakeMX_BFlatBytesDramTileDistribution()); } @@ -372,7 +412,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + if constexpr(!std::is_same_v) + { + return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } + else + { + return MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 9e23a06b23..24076ca494 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1614,7 +1614,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 return make_tuple(number<0>{}, int32x8_t{}); else if constexpr(std::is_same_v) return make_tuple(number<1>{}, int32x8_t{}); - // else if e2m3 => make_tuple(number<2>{}, int32x6_t{}) + else if constexpr(std::is_same_v) + return make_tuple(number<2>{}, pk_fp6x32_t{}); // else if e3m2 => make_tuple(number<3>{}, int32x6_t{}) else if constexpr(std::is_same_v) return make_tuple(number<4>{}, int32x4_t{}); diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index 2a43b596e4..208b92e702 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -20,6 +20,25 @@ struct MemoryCopyParam ck_tile::index_t warp_id; }; +template +struct type_list +{ +}; + +template +struct type_at; + +template +struct type_at> : type_at> +{ +}; + +template +struct type_at<0, type_list> +{ + using type = Head; +}; + template class TestCkTileMemoryCopy : public ::testing::TestWithParam> { @@ -33,48 +52,47 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam ? 1 : 0; ck_tile::HostTensor x_host({m, n}); ck_tile::HostTensor y_host_dev({m, n}); + ck_tile::HostTensor host_init_buf({x_host.get_element_space_size_in_bytes()}); std::cout << "input: " << x_host.mDesc << std::endl; std::cout << "output: " << y_host_dev.mDesc << std::endl; - ck_tile::index_t value = 1; - for(int i = 0; i < m; i++) - { - value = 1; - for(int j = 0; j < n; j++) - { - value = (value + 1) % 127; - x_host(i, j) = static_cast(value); - } - } - + for(size_t i = 0; i < x_host.get_element_space_size_in_bytes(); i++) + host_init_buf.mData[i] = i % 64; + memcpy(x_host.mData.data(), + host_init_buf.mData.data(), + x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = ck_tile::sequence<64, 8>; - using WaveTile = ck_tile::sequence<64, 8>; - using Vector = ck_tile::sequence<1, dword_bytes / sizeof(DataType)>; + using BlockTileList = type_list, ck_tile::sequence<16, 96>>; + using VectorList = type_list, + ck_tile::sequence<1, 24>>; + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = type_at::type; + using WaveTile = type_at::type; + using Vector = type_at::type; ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; + using Problem = ck_tile::TileCopyProblem; using Kernel = ck_tile::TileCopy; constexpr ck_tile::index_t kBlockSize = 128; constexpr ck_tile::index_t kBlockPerCu = 1; + // when copy fp6x16 buffer, tread it as int8 buffer and recompute n-dim size. + ck_tile::index_t cpy_n = + CpyCfg == 1 ? n * sizeof(DataType) / + (sizeof(int8_t) * ck_tile::numeric_traits::PackedSize) + : n; auto ms = launch_kernel( ck_tile::stream_config{nullptr, true}, @@ -85,21 +103,28 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam(x_buf.GetDeviceBuffer()), static_cast(y_buf.GetDeviceBuffer()), m, - n, + cpy_n, warp_id)); - auto bytes = 2 * m * n * sizeof(DataType); + auto bytes = 2 * m * n * sizeof(DataType) / ck_tile::numeric_traits::PackedSize; std::cout << "elapsed: " << ms << " (ms)" << std::endl; std::cout << (bytes * 1e-6 / ms) << " (GB/s)" << std::endl; // reference y_buf.FromDevice(y_host_dev.mData.data()); bool pass = ck_tile::check_err(y_host_dev, x_host); - EXPECT_TRUE(pass); } }; +class TestCkTileMemoryCopyF6x16Async : public TestCkTileMemoryCopy +{ +}; + +class TestCkTileMemoryCopyF6x16 : public TestCkTileMemoryCopy +{ +}; + class TestCkTileMemoryCopyHalfAsync : public TestCkTileMemoryCopy { }; @@ -116,6 +141,18 @@ class TestCkTileMemoryCopyFP8Async : public TestCkTileMemoryCopy { }; +TEST_P(TestCkTileMemoryCopyF6x16, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +TEST_P(TestCkTileMemoryCopyF6x16Async, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + TEST_P(TestCkTileMemoryCopyHalfAsync, TestCorrectness) { auto [M, N, warp_id] = GetParam(); @@ -140,6 +177,20 @@ TEST_P(TestCkTileMemoryCopyFP8Async, TestCorrectness) this->Run({M, N, warp_id}); } +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyF6x16, + ::testing::Values(std::tuple{32, 128, 0}, + std::tuple{64, 256, 0}, + std::tuple{32, 128, 1}, + std::tuple{64, 256, 1})); + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyF6x16Async, + ::testing::Values(std::tuple{32, 128, 0}, + std::tuple{64, 256, 0}, + std::tuple{32, 128, 1}, + std::tuple{64, 256, 1})); + INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, TestCkTileMemoryCopyHalfAsync, ::testing::Values(std::tuple{64, 8, 0}, diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp index 847763881b..2ce4982a04 100644 --- a/test/ck_tile/memory_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -51,12 +51,15 @@ struct TileCopyShape "Inconsistent wave group size!"); }; -template +template struct TileCopyProblem { using XDataType = remove_cvref_t; using BlockShape = remove_cvref_t; static constexpr bool AsyncCopy = AsyncCopy_; + // 0: copy 1, 2, 4 bytes data type + // 1: copy dwordx3 bytes data type + static constexpr int CpyCfg = CpyCfg_; }; template @@ -67,6 +70,7 @@ struct TileCopy static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr bool AsyncCopy = Problem::AsyncCopy; + static constexpr int CpyCfg = Problem::CpyCfg; template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() @@ -98,8 +102,40 @@ struct TileCopy return make_static_tile_distribution(outer_encoding); } + template + // CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution() + CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest + // changing with given vector size. + constexpr index_t X1 = + S::Block_N; // no. of elements along N dimensions to be read by each thread. + + constexpr index_t X2 = 12; // l/w dwordx3 bytes + + constexpr index_t Y0 = + S::WaveNum / S::WaveGroups; // number of active warps working in this thread block. + constexpr index_t Y2 = + warp_size / X0; // number of threads in a warp needed along M dimension. + constexpr index_t Y1 = + S::Warp_M / + Y2; // number of iterations each warp needs to perform to cover the entire tile window. + constexpr auto outer_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, // Y2==16,X0==4 + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}; + + return make_static_tile_distribution(outer_encoding); + } + CK_TILE_DEVICE void - operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + run_normal_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const { using S = typename Problem::BlockShape; @@ -170,6 +206,124 @@ struct TileCopy move_tile_window(y_block_window, {0, S::Block_N}); } } -}; + CK_TILE_DEVICE void + run_dwordx3_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + using S = typename Problem::BlockShape; + constexpr index_t X0 = S::ThreadPerWarp_N; + constexpr index_t X1 = S::Block_N; + constexpr index_t X2 = 12; // l/w dwordx3 bytes + + // LDS buffer + constexpr int dim1_stride = + AsyncCopy ? 16 : 12; // async_load dwordx3 will write 3 bytes & skip 1 bytes in lds. + constexpr int repeat_num = X1 / (X0 * X2); + __shared__ int8_t x_lds[repeat_num * S::Block_M * X0 * dim1_stride]; + + constexpr auto block_dims = make_tuple(number{}, number{}); + constexpr auto block_dims_ = make_tuple(number{}, + number{}, + number{}, + number{}); + constexpr auto block_strides = make_tuple(number{}, + number{}, + number{}, + number<1>{}); + + const auto x_lds_desc_ = + make_naive_tensor_descriptor(block_dims_, block_strides, number<12>{}, number<1>{}); + const auto x_lds_desc = transform_tensor_descriptor( + x_lds_desc_, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number<2>{}, number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto x_lds_view = + make_tensor_view(reinterpret_cast(x_lds), x_lds_desc); + + auto x_block_lds_write_window = make_tile_window(x_lds_view, block_dims, {0, 0}); + + auto x_block_lds_read_window = make_tile_window( + x_lds_view, block_dims, {0, 0}, MakeDwordx3DRAMDistribution()); + + const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M); + // Input tensor + const auto x_m_n = + make_naive_tensor_view(reinterpret_cast(p_x), + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); + auto x_block_window = + make_tile_window(x_m_n, block_dims, {iM, 0}, MakeDwordx3DRAMDistribution()); + + // Output tensor + const auto y_m = + make_naive_tensor_view(reinterpret_cast(p_y), + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); + auto y_block_window = make_tile_window(y_m, block_dims, {iM, 0}); + + const index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + const index_t my_id = __builtin_amdgcn_readfirstlane(get_warp_id()); + constexpr index_t async_copy_fence_cnt = 0; + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + if(my_id == warp_id) + { + if constexpr(AsyncCopy) + { + async_load_tile(x_block_lds_write_window, x_block_window); + // We don't have prefetch here, wait the data back immediately. + // Wait all asyncload insts complete. + // Wait all waves synced + s_waitcnt_barrier(); + auto lds_tile = load_tile(x_block_lds_read_window); + // store from registers to DRAM + store_tile(y_block_window, lds_tile); + } + else + { + // load from DRAM to registers + auto dram_tile = load_tile(x_block_window); + // store in lds + store_tile(x_block_lds_write_window, dram_tile); + // Wait all lds write insts complete + // Wait all waves synced + block_sync_lds(); + // read from lds to registers + auto lds_tile = load_tile(x_block_lds_read_window); + // store from registers to DRAM + store_tile(y_block_window, lds_tile); + } + } + + move_tile_window(x_block_window, {0, S::Block_N}); + move_tile_window(y_block_window, {0, S::Block_N}); + } + } + + CK_TILE_DEVICE void + operator()(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + if constexpr(CpyCfg == 1) + { + run_dwordx3_cpy(p_x, p_y, M, N, warp_id); + } + else if constexpr(CpyCfg == 0) + { + run_normal_cpy(p_x, p_y, M, N, warp_id); + } + else + { + static_assert(false, "unsupported copy config type."); + } + } +}; } // namespace ck_tile From 069500464de6a55b80e8341c79239b13ac8ef379 Mon Sep 17 00:00:00 2001 From: Jan Patrick Lehr Date: Mon, 2 Feb 2026 18:39:48 +0100 Subject: [PATCH 2/9] [Compiler] Addressing new compiler warnings (#3640) * [Compiler] Addressing new compiler warnings Clang enables new lifetime warnings in production and we see build errors due to this with the staging compiler. The attributes added in this PR are suggested by the compiler. However, I'm not very familiar with the code base, so the changes may be incorrect. * Update some more instances * Adds file-level ignores via clang diagnostic pragma The number of instances was large, so I decided to use file-level scope to disable the warning via pragma clang diagnostic ignored. It also showed this warning coming from the gtest dependency. For that, I did add the respective command line flag to the CMake variables. I don't know if this is acceptable or not. * This adds the remaining instances For a build on gfx90a. * fix clang format * Adding couple more instances from gfx1200 build * Fixed another few instances --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng --- cmake/gtest.cmake | 2 + example/ck_tile/01_fmha/bias.hpp | 2 +- example/ck_tile/01_fmha/mask.hpp | 2 +- example/ck_tile/01_fmha/quant.hpp | 4 ++ include/ck/host_utility/io.hpp | 5 ++- .../library/utility/convolution_parameter.hpp | 3 +- include/ck/library/utility/host_tensor.hpp | 12 ++++-- include/ck/tensor/static_tensor.hpp | 3 ++ .../multi_index_transform.hpp | 39 +++++++++++++++---- .../ck/tensor_description/tensor_adaptor.hpp | 5 ++- .../tensor_description/tensor_descriptor.hpp | 17 ++++++-- .../blockwise_gemm_pipeline_wmmaops_base.hpp | 3 ++ .../block/blockwise_gemm_pipeline_xdlops.hpp | 3 ++ .../blockwise_gemm_pipeline_xdlops_base.hpp | 4 ++ .../gpu/block/blockwise_gemm_wmma.hpp | 4 ++ .../gpu/block/blockwise_gemm_xdlops.hpp | 3 ++ .../blockwise_gemm_xdlops_skip_b_lds.hpp | 4 ++ .../gpu/device/tensor_layout.hpp | 2 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 4 ++ .../ck/utility/amd_wave_read_first_lane.hpp | 3 +- include/ck/utility/dtype_vector.hpp | 21 +++++----- include/ck/utility/env.hpp | 5 ++- include/ck/utility/pipeline_enum.hpp | 3 +- include/ck/utility/scheduler_enum.hpp | 3 +- include/ck/utility/static_buffer.hpp | 5 ++- include/ck/utility/tuple.hpp | 7 ++-- include/ck/wrapper/layout.hpp | 4 ++ include/ck/wrapper/tensor.hpp | 4 ++ .../core/algorithm/coordinate_transform.hpp | 4 ++ include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 4 ++ include/ck_tile/core/container/map.hpp | 4 ++ include/ck_tile/core/container/tuple.hpp | 11 ++++-- include/ck_tile/core/numeric/e8m0.hpp | 4 ++ include/ck_tile/core/numeric/pk_fp4.hpp | 4 ++ .../core/tensor/static_distributed_tensor.hpp | 4 ++ .../ck_tile/core/tensor/tensor_adaptor.hpp | 4 ++ .../core/tensor/tensor_adaptor_coordinate.hpp | 4 ++ include/ck_tile/core/tensor/tensor_view.hpp | 4 ++ .../ck_tile/core/tensor/tile_distribution.hpp | 4 ++ include/ck_tile/core/utility/env.hpp | 4 ++ include/ck_tile/core/utility/functional.hpp | 3 ++ include/ck_tile/host/arg_parser.hpp | 4 ++ include/ck_tile/host/host_tensor.hpp | 4 ++ .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 ++- profiler/src/profiler_operation_registry.hpp | 4 ++ .../position_embedding/position_embedding.cpp | 4 ++ .../gemm_multi_d/gemm_multi_d_benchmark.hpp | 4 ++ .../gemm_preshuffle_benchmark.hpp | 4 ++ .../gemm/gemm_universal/gemm_benchmark.hpp | 3 ++ .../gemm_streamk/gemm_streamk_benchmark.hpp | 4 ++ 50 files changed, 228 insertions(+), 43 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 993330f989..51e0359ab6 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,6 +68,8 @@ set(GTEST_CXX_FLAGS -Wno-deprecated -Wno-unsafe-buffer-usage -Wno-float-equal + -Wno-lifetime-safety-intra-tu-suggestions + -Wno-lifetime-safety-cross-tu-suggestions ) if(WIN32) diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp index 33f398cc2a..b526204384 100644 --- a/example/ck_tile/01_fmha/bias.hpp +++ b/example/ck_tile/01_fmha/bias.hpp @@ -106,7 +106,7 @@ struct bias_info return info; } - friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const bias_info& bi) { bi.serialize(os); return os; diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index f85b811116..c780bf7b6b 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -191,7 +191,7 @@ struct mask_info return area; } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi) { mi.serialize(os); return os; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index feb28cba24..da588910b2 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -8,6 +8,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // keep sync with BlockAttentionQuantScaleEnum enum class quant_scale_enum { @@ -58,3 +61,4 @@ struct quant_scale_info return os; } }; +#pragma clang diagnostic pop diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index db45199b17..22d744ff15 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -13,7 +13,7 @@ namespace ck { template -std::ostream& operator<<(std::ostream& os, const std::vector& v) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const std::vector& v) { std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); return os; @@ -27,7 +27,8 @@ std::ostream& operator<<(std::ostream& os, const std::array& v) } template -std::ostream& operator<<(std::ostream& os, const TensorDescriptor& desc) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const TensorDescriptor& desc) { constexpr index_t nDim = remove_cvref_t::GetNumOfDimension(); diff --git a/include/ck/library/utility/convolution_parameter.hpp b/include/ck/library/utility/convolution_parameter.hpp index 354b112040..a25002409b 100644 --- a/include/ck/library/utility/convolution_parameter.hpp +++ b/include/ck/library/utility/convolution_parameter.hpp @@ -110,4 +110,5 @@ ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]) } // namespace utils } // namespace ck -std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p); +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::utils::conv::ConvParam& p); diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 1dda0a4863..2e95ee8cf3 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -23,10 +23,14 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" +#pragma clang diagnostic ignored "-Wlifetime-safety-cross-tu-suggestions" + namespace ck { template -std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +std::ostream& LogRange([[clang::lifetimebound]] std::ostream& os, Range&& range, std::string delim) { bool first = true; for(auto&& v : range) @@ -580,8 +584,9 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } - friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); - friend std::ostream& operator<<(std::ostream& os, ChosenLayout tag); + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const HostTensorDescriptor& desc); + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, ChosenLayout tag); private: std::vector mLens; @@ -1171,3 +1176,4 @@ struct Tensor }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index 529745e3b9..c3f3bd0c91 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -4,6 +4,8 @@ #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // StaticTensor for Scalar @@ -270,4 +272,5 @@ __host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_elem } } // namespace ck +#pragma clang diagnostic pop #endif diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 19a4748732..5a6c335b2c 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -6,6 +6,9 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/multi_index.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { template @@ -29,7 +32,10 @@ struct PassThrough __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -305,7 +311,10 @@ struct RightPad __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -403,7 +412,10 @@ struct Embed __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1074,7 +1086,10 @@ struct Merge_v2_magic_division __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1366,7 +1381,10 @@ struct Merge_v3_division_mod __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1480,7 +1498,10 @@ struct UnMerge __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1640,7 +1661,10 @@ struct ConvBwdDataImplicitGemmOutTransform __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const @@ -2236,3 +2260,4 @@ struct Xor } }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 79c5881d48..ee8c7ed71b 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -23,7 +23,10 @@ struct TensorAdaptor { __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } - __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + __host__ __device__ constexpr const auto& GetTransforms() const [[clang::lifetimebound]] + { + return transforms_; + } __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() { diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 2437132d11..a237c4219d 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -7,6 +7,8 @@ #include "ck/utility/sequence_helper.hpp" #include "ck/tensor_description/multi_index_transform.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { template @@ -179,7 +181,10 @@ struct TensorDescriptor } // TODO make these private - __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + __host__ __device__ constexpr const auto& GetTransforms() const [[clang::lifetimebound]] + { + return transforms_; + } __host__ __device__ static constexpr auto GetLowerDimensionIdss() { @@ -253,9 +258,12 @@ struct TensorCoordinate __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } // TODO make these private - __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + __host__ __device__ constexpr const auto& GetHiddenIndex() const [[clang::lifetimebound]] + { + return idx_hidden_; + } - __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + __host__ __device__ auto& GetHiddenIndex() [[clang::lifetimebound]] { return idx_hidden_; } __host__ __device__ constexpr auto GetVisibleIndex() const { @@ -284,7 +292,7 @@ struct TensorCoordinateStep __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } // TODO make these private - __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const [[clang::lifetimebound]] { return idx_diff_visible_; } @@ -613,3 +621,4 @@ using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index f831c0f6cf..e41cf8c82d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -10,6 +10,8 @@ #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { template @@ -1031,3 +1033,4 @@ struct BlockwiseGemmXdlops_v2 }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 1dba7f67a1..65a326e3e7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -8,6 +8,9 @@ #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { template ::value, bool>::type = false> -std::ostream& operator<<(std::ostream& os, const Layout&) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const Layout&) { os << Layout::name; return os; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 8c316bc71d..6060889c10 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -17,6 +17,9 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { // Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to @@ -1132,3 +1135,4 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight }; // namespace ck } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 44259f0601..4b64b76cc7 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -44,7 +44,8 @@ struct get_carrier<3> // replacement of host std::copy_n() template - __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + __device__ static OutputIterator + copy_n(InputIterator from, Size size, [[clang::lifetimebound]] OutputIterator to) { if(0 < size) { diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ebdbbb107d..204b199629 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck/utility/data_type.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // vector_type @@ -116,7 +118,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -136,7 +138,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -248,7 +250,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -272,7 +274,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -583,7 +585,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, @@ -754,7 +756,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || @@ -1427,7 +1429,7 @@ struct non_native_vector_base< } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); @@ -1627,7 +1629,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, @@ -1797,7 +1799,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || @@ -2284,3 +2286,4 @@ using pk_i4x4_t = typename vector_type::type; using pk_i4x8_t = typename vector_type::type; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 0cb0b4caf8..4cabd89e33 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -9,6 +9,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { namespace internal { template @@ -188,5 +191,5 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) // environment variable to enable logging: // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - +#pragma clang diagnostic pop #endif diff --git a/include/ck/utility/pipeline_enum.hpp b/include/ck/utility/pipeline_enum.hpp index 4421386f59..a224011a04 100644 --- a/include/ck/utility/pipeline_enum.hpp +++ b/include/ck/utility/pipeline_enum.hpp @@ -25,7 +25,8 @@ enum struct PipelineVersion } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::PipelineVersion& p) { switch(p) { diff --git a/include/ck/utility/scheduler_enum.hpp b/include/ck/utility/scheduler_enum.hpp index 0c4bfabaf3..67c5c3b50a 100644 --- a/include/ck/utility/scheduler_enum.hpp +++ b/include/ck/utility/scheduler_enum.hpp @@ -70,7 +70,8 @@ enum struct TailNumber } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::LoopScheduler& s) { switch(s) { diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index d49817eb8f..7e47da5bf8 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -5,6 +5,8 @@ #include "statically_indexed_array.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // static buffer for scalar @@ -104,7 +106,7 @@ struct StaticBufferTupleOfVector // Set S // i is offset of S template - __host__ __device__ constexpr S& operator()(Number i) + __host__ __device__ constexpr S& operator()(Number i) [[clang::lifetimebound]] { constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; @@ -195,3 +197,4 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber) } } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 1657595030..16cd35e1d6 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -51,7 +51,7 @@ get_tuple_element_data_reference(const TupleElementKeyData& x) // for write access of tuple element template __host__ __device__ constexpr Data& -get_tuple_element_data_reference(TupleElementKeyData& x) +get_tuple_element_data_reference([[clang::lifetimebound]] TupleElementKeyData& x) { return x.mData; } @@ -106,6 +106,7 @@ struct TupleImpl, Xs...> : TupleElementKeyData __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) + [[clang::lifetimebound]] { return get_tuple_element_data_reference>(*this); } @@ -147,7 +148,7 @@ struct Tuple : detail::TupleImpl - __host__ __device__ constexpr auto& At(Number) + __host__ __device__ constexpr auto& At(Number) [[clang::lifetimebound]] { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); @@ -162,7 +163,7 @@ struct Tuple : detail::TupleImpl - __host__ __device__ constexpr auto& operator()(Number i) + __host__ __device__ constexpr auto& operator()(Number i) [[clang::lifetimebound]] { return At(i); } diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 334d5851db..6d99f4e5e3 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -5,6 +5,9 @@ #include "ck/wrapper/utils/layout_utils.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { @@ -482,3 +485,4 @@ struct Layout } // namespace wrapper } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 9f8278a357..ed7f2fa23d 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -7,6 +7,9 @@ #include "utils/tensor_partition.hpp" #include "utils/layout_utils.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { @@ -441,3 +444,4 @@ struct Tensor } // namespace wrapper } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 732799cef8..30c93b8f00 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -11,6 +11,9 @@ #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/print.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { enum struct coord_transform_enum @@ -1776,3 +1779,4 @@ make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingA } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 4c9ef7d6ba..1eef5819bc 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -7,6 +7,9 @@ #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/ignore.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile::core::arch::mma { /** @@ -112,6 +115,7 @@ struct amdgcn_mma }; } // namespace ck_tile::core::arch::mma +#pragma clang diagnostic pop // Include the implementations #include "wmma/wmma.hpp" diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index d342235b38..8c861ceeb6 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -8,6 +8,9 @@ #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/tuple.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { // naive map @@ -157,3 +160,4 @@ CK_TILE_HOST_DEVICE static void print(const map& m) } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 7f8176d5ec..11e7b1e52f 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -13,6 +13,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #ifndef CK_TILE_TUPLE_IMPL #define CK_TILE_TUPLE_IMPL 1 #endif @@ -98,13 +101,14 @@ CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object&) } template -CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object& x) +CK_TILE_HOST_DEVICE constexpr const T& +getv([[clang::lifetimebound]] const tuple_object& x) { return x.element; } template -CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object& x) +CK_TILE_HOST_DEVICE constexpr T& getv([[clang::lifetimebound]] tuple_object& x) { return x.element; } @@ -292,7 +296,7 @@ struct tuple : impl::tuple_base, T...> //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast&>(*this).at(i); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } - + // template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(i) = x; } template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } @@ -864,3 +868,4 @@ struct tuple_element> } \ }() #endif +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp index 41aeb8ffab..ee12524283 100644 --- a/include/ck_tile/core/numeric/e8m0.hpp +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -6,6 +6,9 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /** @@ -100,3 +103,4 @@ CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index d74db6b336..5822e3b9bc 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -9,6 +9,9 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #if defined(__gfx950__) #define CK_TILE_FP4_CVT_DEVICE 1 #else @@ -517,3 +520,4 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const #endif } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 10c7587bcb..bdd81dae07 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/container/thread_buffer.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -266,3 +269,4 @@ inline constexpr bool is_similiar_distributed_tensor_v = } // namespace detail } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 78160b800d..e6cdb66ef9 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -12,6 +12,9 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { // Transforms: Tuple @@ -950,3 +953,4 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. remove_cvref_t, \ remove_cvref_t>{trans}; \ }() +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp index 2ea76a3814..6d33bde83e 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/print.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -367,3 +370,4 @@ CK_TILE_HOST_DEVICE void print(const tensor_adaptor_coordinate& coord) detail::CK_PRINT_X_<>{}(coord); } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 837f2b87a6..833a7f4413 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /* @@ -582,3 +585,4 @@ pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index f9c2aba502..aa5714e5c2 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -15,6 +15,9 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -731,3 +734,4 @@ CK_TILE_HOST_DEVICE void print(const tile_distribution #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -206,3 +209,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) // environment variable to enable logging: // export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index aa4bfa3f15..ae79d575a8 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -10,6 +10,8 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck_tile { namespace detail { @@ -270,3 +272,4 @@ constexpr auto conditional_expr(X&& x, Y&& y) } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp index 8c45d2b175..fee7f7779b 100644 --- a/include/ck_tile/host/arg_parser.hpp +++ b/include/ck_tile/host/arg_parser.hpp @@ -13,6 +13,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /* * a host side utility, arg parser for, either @@ -234,3 +237,4 @@ class ArgParser std::vector keys; }; } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index d26686ec37..ddeb3ad781 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -17,6 +17,9 @@ #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/ranges.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -859,3 +862,4 @@ auto get_default_stride(std::size_t row, return stride; } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 957cf7ab8f..987704e433 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -41,7 +41,8 @@ enum struct TailNumber } // namespace ck_tile -inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck_tile::GemmPipelineScheduler& s) { switch(s) { @@ -53,7 +54,8 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch return os; } -inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck_tile::TailNumber& s) { switch(s) { diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 28674554a1..fd698ee340 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -9,6 +9,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + class ProfilerOperationRegistry final { ProfilerOperationRegistry() = default; @@ -83,3 +86,4 @@ class ProfilerOperationRegistry final ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation) \ _Pragma("clang diagnostic pop") // clang-format on +#pragma clang diagnostic pop diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp index 134d2e5f37..689a7a799a 100644 --- a/test/position_embedding/position_embedding.cpp +++ b/test/position_embedding/position_embedding.cpp @@ -9,6 +9,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #ifndef TEST_ALIBI_VERBOSE #define TEST_ALIBI_VERBOSE 0 #endif @@ -213,3 +216,4 @@ int main() // clang-format on return rtn ? 0 : -1; } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp index f8c196e32a..b0d8445c16 100644 --- a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -13,6 +13,9 @@ #include "ck_tile/host.hpp" #include "gemm_multi_d_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-seggestions" + // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts @@ -230,3 +233,4 @@ void gemm_multi_d_host_reference(int verify, a_m_k, b_k_n, {d0_m_n, d1_m_n}, c_m_n_host_result); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 748fe581d3..41ccc4a01b 100644 --- a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -7,6 +7,9 @@ #include "ck_tile/host.hpp" #include "gemm_preshuffle_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + //[TODO] Move parts of this File to commons enum class Metric { @@ -234,3 +237,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp index 7c8df32ad8..11aef4c251 100644 --- a/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp @@ -13,6 +13,8 @@ #include "ck_tile/host.hpp" #include "gemm_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts @@ -240,3 +242,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp index 45beb0acce..d877f174b2 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -17,6 +17,9 @@ // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + enum class Metric { LATENCY = 0, @@ -199,3 +202,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } } +#pragma clang diagnostic pop From 301eb5cf083a03382f7dc69b3277038658a10b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= <153429852+zsotakal@users.noreply.github.com> Date: Mon, 2 Feb 2026 22:58:11 +0100 Subject: [PATCH 3/9] Implement device grouped gemm fixed nk multi abd for rdna4 (#3619) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * device struct implementation * added xdl grouped multi abd fixed nk testing * wmma implementation fixed * avoid unnecessary device mem allocation and code cleanups * cleanup instances definitions * wmma examples added * code cleanups * fix clang format * typo and compilation fixes related to reference gemm * fix compilation error due to std::remove_cvref_t * added missing hip_check_error includes * correction to example instances * review commentes addressed * removed split-k from testing * code formatting --------- Co-authored-by: Zoltán Lakatos Co-authored-by: illsilin_amdeng --- ...grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp | 2 + .../59_grouped_gemm_multi_ABD/CMakeLists.txt | 8 + ...m_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | 400 ++++++++ ...gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 396 ++++++++ ...mm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 27 +- ..._gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp | 33 +- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 899 ++++++++++++++++++ ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 27 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 + include/ck/utility/tuple_helper.hpp | 9 + .../cpu/reference_gemm_multi_abd.hpp | 194 ++++ .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 295 +++++- .../CMakeLists.txt | 6 +- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 144 +++ ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 144 +++ ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 144 +++ ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 2 + .../profiler/profile_gemm_multi_abd_impl.hpp | 88 +- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 534 +++++++++++ test/grouped_gemm/CMakeLists.txt | 6 + .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 256 +++++ 24 files changed, 3517 insertions(+), 120 deletions(-) create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp create mode 100644 test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp index 0766373465..e6e2137bea 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp @@ -15,6 +15,8 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/host_utility/hip_check_error.hpp" + using ::ck::hip_check_error; template diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index 4155e0a344..d7ff58705c 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -8,3 +8,11 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) + +add_custom_target(example_grouped_gemm_wmma_multi_abd) + +add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp) +add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8) \ No newline at end of file diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp new file mode 100644 index 0000000000..4eab6cfce2 --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp @@ -0,0 +1,400 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; + + std::vector> a0_tensors; + std::vector> b_tensors; + std::vector> b0_tensors; + std::vector> b1_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a0_tensors.reserve(group_count); + b_tensors.reserve(group_count); + b0_tensors.reserve(group_count); + b1_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b1_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); + + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + + sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(&argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); + } + } + + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + c_host_tensors[i], + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000..c494e45bfb --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp @@ -0,0 +1,396 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/utility/scheduler_enum.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; +using Scale = ck::tensor_operation::element_wise::Scale; +using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; + +using A0DataType = F16; +using A1DataType = F32; +using AsDataType = ck::Tuple; +using B0DataType = F16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using A1Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; + + std::vector> a0_tensors; + std::vector> a1_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a0_tensors.reserve(group_count); + a1_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + a1_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + a1_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 2; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + a1_tensors_device.emplace_back(std::make_unique( + sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + {1, 1}, + {problem_size.stride_Bs[i]}, + {0}, + 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer(), + a1_tensors_device[i]->GetDeviceBuffer()}, + std::array{b_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i], + problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i]}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + constexpr float scale = 1.f; + auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(&argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); + } + } + + c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), + e_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + e_host_tensors[i], + PassThrough{}, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(64); + problem_size.Ks.push_back(64); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 28b3fa9213..dfb20777bc 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -20,6 +20,8 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/host_utility/hip_check_error.hpp" + using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -220,8 +222,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back( - std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -232,21 +234,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back( - std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); - - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), - a0_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(A0DataType)); - - b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), - b0_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(B0DataType)); - - b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), - b1_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(B1DataType)); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -398,7 +391,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 032842b9eb..82c2e17308 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -20,6 +20,8 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/host_utility/hip_check_error.hpp" + using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -47,9 +49,9 @@ using B0DataType = F16; using BsDataType = ck::Tuple; using AccDataType = F32; using CShuffleDataType = F32; -using D0DataType = F32; +using D0DataType = F16; using DsDataType = ck::Tuple; -using EDataType = F32; +using EDataType = F16; using A0Layout = Row; using A1Layout = Row; @@ -210,11 +212,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back( - std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - a1_tensors_device.emplace_back( - std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); + a1_tensors_device.emplace_back(std::make_unique( + sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); b_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -222,19 +224,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back( - std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), - a0_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(A0DataType)); - - a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), - a1_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(A1DataType)); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), - b_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(B0DataType)); + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -394,7 +389,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp new file mode 100644 index 0000000000..10e604de60 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -0,0 +1,899 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if defined(__gfx11__) || defined(__gfx12__) + __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>()]; + + const index_t KBatch = 1; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + auto karg = gemm_desc_ptr[group_id]; + + if(karg.M == 0 || karg.N == 0 || karg.K == 0) + return; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) +#endif + { + + typename GridwiseGemm::Problem problem(karg.M, + karg.N, + karg.K, + karg.StrideAs, + karg.StrideBs, + karg.StrideDs, + karg.StrideE, + KBatch); + + const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(karg.p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + static_cast(karg.p_e_grid), + p_shared, + problem, + block_2_etile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } + } +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = grid_size_grp; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + : public DeviceGroupedGemmMultiABDFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // Note: Pass multiple layout but then using only the first one + // This is to replicate xdl functionality but it should be extended + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + typename uniform_sequence_gen::type, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + // TODO: Block to tile mappings could potentially moved out to avoid code duplications between + // different device implementations. + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1 + using KernelArgument = typename GridwiseGemm::Argument; + + using GemmTransKernelArg = + GroupedGemmMultiABDKernelArgument; + + static constexpr bool CalculateHasMainKBlockLoop(const GemmTransKernelArg& karg, + index_t k_batch) + { + index_t k_grain = k_batch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / k_batch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + // Client is expected to manually copy the kernel arguments to the device therefore there is + // no point in setting tensor device pointers for the argument structure. + Argument(std::vector>&, + std::vector>&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op, + index_t kbatch) + : group_count_{ck::type_convert(gemm_descs.size())}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + grouped_gemm_kernel_args_dev{nullptr}, + gemm_kernel_host_args_{nullptr}, + grid_size_{0}, + k_batch_{kbatch} + { + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t fixed_N = gemm_descs[0].N_; + const index_t fixed_K = gemm_descs[0].K_; + + for(std::size_t g = 0; g < gemm_descs.size(); g++) + { + const index_t M = gemm_descs[g].M_; + const index_t N = gemm_descs[g].N_; + const index_t K = gemm_descs[g].K_; + + if(M != sum_of_m || N != fixed_N || K != fixed_K) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + // pointer + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + + static_for<0, NumATensor, 1>{}([&](auto i) { p_as_grid[i] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { p_bs_grid[i] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid[i] = nullptr; }); + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + + const index_t StrideE = gemm_descs[g].stride_C_; + + if(gemm_descs[g].stride_As_.size() != NumATensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); + } + + static_for<0, NumATensor, 1>{}( + [&](auto j) { StrideAs[j] = gemm_descs[g].stride_As_[j]; }); + + if(gemm_descs[g].stride_Bs_.size() != NumBTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); + } + + static_for<0, NumBTensor, 1>{}( + [&](auto j) { StrideBs[j] = gemm_descs[g].stride_Bs_[j]; }); + + if(gemm_descs[g].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + static_for<0, NumDTensor, 1>{}( + [&](auto j) { StrideDs[j] = gemm_descs[g].stride_Ds_[j]; }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, AverM, N, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + const index_t block_start = grid_size_; + + grid_size_ += grid_size_grp_; + + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = GemmTransKernelArg({p_as_grid, + p_bs_grid, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE}); + + gemm_desc_kernel_arg_.emplace_back(std::move(karg)); + + group_id++; + } + } + + void UpdateKBatch(index_t) {} + + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + void* gemm_kernel_host_args_; + index_t grid_size_; + index_t grid_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); + } + + if(arg.k_batch_ != 1) + { + throw std::runtime_error("Split K functionality is not supported for wmma multi " + "abd fixed nk implementation."); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto e_global_memory_operation_) { + const auto kernel = kernel_grouped_gemm_wmma_fixed_nk; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto Set = InMemoryDataOperationEnum::Set; + ave_time = launch_kernel(integral_constant{}); + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return RunImp(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + for(index_t i = 0; i < arg.group_count_; i++) + { + if(CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i], arg.k_batch_) != true) + { + supported = false; + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Wmma_Fixed_Nk" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetElementwiseOps(Argument& arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + arg.a_element_op_ = a_element_op; + arg.b_element_op_ = b_element_op; + arg.c_element_op_ = c_element_op; + } + + // polymorphic + void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) const override + { + + SetElementwiseOps( + *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * + sizeof(GroupedGemmMultiABDKernelArgument); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } + + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_desc_kernel_arg_.begin(), + pArg_->gemm_desc_kernel_arg_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index fb4e01b961..36e66017c6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -605,7 +605,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK if(arg.grouped_gemm_kernel_args_dev == nullptr) { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); } float ave_time = 0; @@ -688,6 +688,11 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + // Split-K autodeduction is not supported if(arg.k_batch_ < 1) { @@ -720,6 +725,26 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } } + for(index_t i = 0; i < arg.group_count_; i++) + { + if(get_warp_size() == 64) + { + if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + true) + { + supported = false; + } + } + else + { + if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + true) + { + supported = false; + } + } + } + return supported; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 7653724b21..311a1c0bf4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -696,7 +696,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK& ty); } +template +auto concat_tuple_of_reference(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuple& ty) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp new file mode 100644 index 0000000000..2d766e621b --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp @@ -0,0 +1,194 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/utility/functional4.hpp" +#include "ck/utility/tuple_helper.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmMultiABD : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const AsTensorTuple& as_m_k, + const BsTensorTuple& bs_k_n, + const DsTensorTuple& ds_m_n, + Tensor& e_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : as_m_k_{as_m_k}, + bs_k_n_{bs_k_n}, + ds_m_n_{ds_m_n}, + e_m_n_{e_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const AsTensorTuple& as_m_k_; + const BsTensorTuple& bs_k_n_; + const DsTensorTuple& ds_m_n_; + Tensor& e_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmMultiABD::Argument; + + float Run(const Argument& arg) + { + static constexpr index_t NumATensor = AsTensorTuple::Size(); + static constexpr index_t NumBTensor = BsTensorTuple::Size(); + static constexpr index_t NumDTensor = DsTensorTuple::Size(); + + const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0]; + const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1]; + const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1]; + + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, + Number{}); + auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); + unpack(arg.a_element_op_, data_refs); + } + } + + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, + Number{}); + auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); + unpack(arg.b_element_op_, data_refs); + } + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, + Number{}); + auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); + unpack(arg.cde_element_op_, data_refs); + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const AsTensorTuple& as_m_k, + const BsTensorTuple& bs_k_n, + const DsTensorTuple& ds_m_n, + Tensor& e_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmMultiABD" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 6d97ec3a05..0879bea4ea 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" namespace ck { namespace tensor_operation { @@ -21,6 +20,7 @@ using Multiply = ck::tensor_operation::element_wise::Multiply; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +#if defined(CK_USE_XDL) // RRR void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( std::vector, @@ -179,6 +179,167 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan PassThrough, Multiply, PassThrough>>>& instances); +#endif + +#if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// RCR +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// CRR +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); +#endif // CK_USE // GEMM + Add + Gelu template > op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -246,6 +408,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -289,6 +483,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -317,6 +512,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -360,6 +587,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -388,6 +616,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -431,6 +691,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -459,6 +720,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index 9d9a0e691c..fc60f48727 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -1,13 +1,17 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp + + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..a29f8513d8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple, + Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple, + Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple<>, + Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple<>, + Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..2eaaaf009a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple, + Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple, + Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple<>, + Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple<>, + Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..3320b4afa6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple, + Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple, + Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple<>, + Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple<>, + Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index 23e3b7f511..6e72d379d0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -61,6 +61,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that +// portion of the instances are failing. As a workaround these have been commented out. template , S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 0560f159fc..5eedb8b5ee 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -61,6 +61,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that +// portion of the instances are failing. As a workaround these have been commented out. template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 95365c82e7..7d1fcb5552 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -61,6 +61,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that +// portion of the instances are failing. As a workaround these have been commented out. template -auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) -{ - return ck::unpack2( - [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, - tx, - ty); -} - template c_m_n({M, N}); - using AComputeType = typename std::conditional<(NumATensor > 1), EDataType, remove_cvref_t>>::type; - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(a_element_op, data_refs); - } - } - using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(b_element_op, data_refs); - } - } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + auto ref_argument = ref_gemm.MakeArgument( + as_m_k, bs_k_n, ds_m_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - // compulsory - auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); - // optional (if multiple Ds) - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(cde_element_op, data_refs); - } - } } std::array as_device_buf; diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp new file mode 100644 index 0000000000..eea72b324d --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -0,0 +1,534 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" + +namespace ck { +namespace profiler { + +template +auto reserveVector(std::size_t size) +{ + std::vector vec; + vec.reserve(size); + return vec; +} + +template +bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideE, + const std::vector& kbatch_list = {1}, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + const std::size_t group_count = Ms.size(); + const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0); + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + if(group_count != Ns.size() || group_count != Ks.size() || group_count != StrideAs.size() || + group_count != StrideBs.size() || (NumDTensor > 0 && group_count != StrideDs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideAs/Bs/Ds/E size\n"); + } + + auto generateInputTupleA = [&](std::size_t g) { + if constexpr(NumATensor == 0) + { + static_assert("Gemm problem should have at least 1 A tensor."); + } + else + { + using ALayout = remove_cvref_t{}, AsLayout>>; + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ms[g], Ks[g], StrideAs[g], ALayout{})); + }, + Number{}); + } + }; + auto generateInputTupleB = [&](std::size_t g) { + if constexpr(NumBTensor == 0) + { + static_assert("Gemm problem should have at least 1 B tensor."); + } + else + { + using BLayout = remove_cvref_t{}, BsLayout>>; + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ks[g], Ns[g], StrideBs[g], BLayout{})); + }, + Number{}); + } + }; + auto generateInputTupleD = [&](std::size_t g) { + if constexpr(NumDTensor == 0) + { + return ck::Tuple<>(); + } + else + { + using DLayout = remove_cvref_t{}, DsLayout>>; + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ms[g], Ns[g], StrideDs[g], DLayout{})); + }, + Number{}); + } + }; + + using AsTensorTuple = decltype(generateInputTupleA(0)); + using BsTensorTuple = decltype(generateInputTupleB(0)); + using DsTensorTuple = decltype(generateInputTupleD(0)); + + auto g_as_m_k = reserveVector(group_count); + auto g_bs_k_n = reserveVector(group_count); + auto g_ds_m_n = reserveVector(group_count); + auto g_e_m_n_host_results = reserveVector>(group_count); + auto g_e_m_n_device_results = reserveVector>(group_count); + + for(std::size_t g = 0; g < group_count; g++) + { + auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); + auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); + auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); + + g_e_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); + g_e_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << g << std::endl; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; + }); + std::cout << "e_m_n: " << g_e_m_n_device_results[g].mDesc << std::endl; + } + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + break; + default: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + }); + } + } + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector> g_as_device_buf(group_count); + std::vector> g_bs_device_buf(group_count); + std::vector> g_ds_device_buf(group_count); + std::vector g_e_device_buf(group_count); + + std::vector> g_as_device_view(group_count); + std::vector> g_bs_device_view(group_count); + std::vector> g_ds_device_view(group_count); + std::vector g_e_device_view(group_count); + + auto g_gemm_descs = reserveVector(group_count); + + auto grouped_gemm_kernel_args_host = + reserveVector>( + group_count); + + for(std::size_t g = 0; g < group_count; g++) + { + std::array as_stride; + std::array bs_stride; + std::array ds_stride; + + auto& as_m_k = g_as_m_k[g]; + auto& as_device_buf = g_as_device_buf[g]; + auto& as_device_view = g_as_device_view[g]; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_device_buf[i] = std::make_unique(sizeof(ADataType) * Ms[g] * Ks[g]); + as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); + as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); + as_stride[i] = StrideAs[g]; + }); + + auto& bs_k_n = g_bs_k_n[g]; + auto& bs_device_buf = g_bs_device_buf[g]; + auto& bs_device_view = g_bs_device_view[g]; + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_device_buf[i] = std::make_unique(sizeof(BDataType) * Ks[g] * Ns[g]); + bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data()); + bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer(); + bs_stride[i] = StrideBs[g]; + }); + + auto& ds_m_n = g_ds_m_n[g]; + auto& ds_device_buf = g_ds_device_buf[g]; + auto& ds_device_view = g_ds_device_view[g]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_device_buf[i] = std::make_unique(sizeof(DDataType) * Ms[g] * Ns[g]); + ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); + ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); + ds_stride[i] = StrideDs[g]; + }); + + g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); + g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); + + g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ + sum_of_m, + Ns[g], + Ks[g], + std::vector(as_stride.begin(), as_stride.end()), + std::vector(bs_stride.begin(), bs_stride.end()), + std::vector(ds_stride.begin(), ds_stride.end()), + StrideE[g]}); + + tensor_operation::device:: + GroupedGemmMultiABDKernelArgument + kernelArg{as_device_view, + bs_device_view, + ds_device_view, + g_e_device_view[g], + Ms[g], + Ns[g], + Ks[g], + as_stride, + bs_stride, + ds_stride, + StrideE[g]}; + + grouped_gemm_kernel_args_host.push_back(std::move(kernelArg)); + } + + using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; + + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + if(do_verification) + { + using AComputeType = + typename std::conditional<(NumATensor > 1), + EDataType, + remove_cvref_t>>::type; + + using BComputeType = + typename std::conditional<(NumBTensor > 1), + EDataType, + remove_cvref_t>>::type; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + for(std::size_t i = 0; i < group_count; i++) + { + auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i], + g_bs_k_n[i], + g_ds_m_n[i], + g_e_m_n_host_results[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs); + + if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Gemm incompatible with runtime set parameters. Skipping..." + << std::endl; + } + + continue; + } + + DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem grouped_gemm_kernel_args_dev( + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_host.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + for(const auto kbatch_curr : kbatch_list) + { + gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + for(std::size_t g = 0; g < group_count; g++) + { + g_e_device_buf[g]->SetZero(); + } + + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + if(do_verification) + { + bool instance_pass = true; + for(std::size_t g = 0; g < group_count; g++) + { + g_e_device_buf[g]->FromDevice( + g_e_m_n_device_results[g].mData.data(), + g_e_m_n_device_results[g].mDesc.GetElementSize() * sizeof(EDataType)); + + instance_pass = + instance_pass && ck::utils::check_err(g_e_m_n_device_results[g], + g_e_m_n_host_results[g]); + + if(do_log) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "a[" << g << "]: ", g_as_m_k[g](i).mData, ",") + << std::endl; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "b[" << g << "]: ", g_bs_k_n[g](i).mData, ",") + << std::endl; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "d[" << g << "]: ", g_ds_m_n[g](i).mData, ",") + << std::endl; + }); + LogRangeAsType( + std::cout << "e_device: ", g_e_m_n_device_results[g].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e_host : ", g_e_m_n_host_results[g].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t g = 0; g < group_count; g++) + { + flop += std::size_t(2) * Ms[g] * Ns[g] * Ks[g]; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + num_btype += sizeof(ADataType) * Ms[g] * Ks[g]; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + num_btype += sizeof(BDataType) * Ks[g] * Ns[g]; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + num_btype += sizeof(DDataType) * Ms[g] * Ns[g]; + }); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 450950cbd6..bc79c85e59 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,6 +18,12 @@ if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() + + add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp new file mode 100644 index 0000000000..610e7f2b77 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp @@ -0,0 +1,256 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/ck.hpp" +#include "ck/utility/type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" + +#include "gtest/gtest.h" + +static ck::index_t param_mask = 0xffffff; +static ck::index_t instance_index = -1; + +using FP32 = float; +using FP16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using Add = ck::tensor_operation::element_wise::Add; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu> +>; +// clang-format on + +template +class TestGroupedGemmMultiABDFixedNK : public testing::Test +{ + protected: + using AsDataType = std::tuple_element_t<0, Tuple>; + using BsDataType = std::tuple_element_t<1, Tuple>; + using DsDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using AccDataType = float; + using AsLayout = std::tuple_element_t<4, Tuple>; + using BsLayout = std::tuple_element_t<5, Tuple>; + using DsLayout = std::tuple_element_t<6, Tuple>; + using ELayout = std::tuple_element_t<7, Tuple>; + using AElementOp = PassThrough; + using BElementOp = Multiply; + using CDEElementOp = std::tuple_element_t<8, Tuple>; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // integer value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + static constexpr int n_warmup_ = 0; + static constexpr int n_iter_ = 1; + + std::vector k_batches_ = {1}; + + private: + template + void SetStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if(std::is_same_v) + { + for(const auto c : cols) + { + strides.emplace_back(c); + } + } + else if(std::is_same_v) + { + for(const auto r : rows) + { + strides.emplace_back(r); + } + } + } + + template + void SetTupleStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if constexpr(Layouts::Size() > 0) + { + // As of now multi ABD implementation supports only tensors with matching layouts. + using Layout = ck::remove_cvref_t{}, Layouts>>; + SetStrides(strides, rows, cols); + } + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs = {}, + const std::vector& StrideBs = {}, + const std::vector& StrideDs = {}, + const std::vector& StrideE = {}) + { + std::vector stride_as = StrideAs; + std::vector stride_bs = StrideBs; + std::vector stride_ds = StrideDs; + std::vector stride_e = StrideE; + + if(stride_as.empty()) + { + SetTupleStrides(stride_as, Ms, Ks); + } + if(stride_bs.empty()) + { + SetTupleStrides(stride_bs, Ks, Ns); + } + if(stride_ds.empty()) + { + SetTupleStrides(stride_ds, Ms, Ns); + } + if(stride_e.empty()) + { + SetStrides(stride_e, Ms, Ns); + } + + RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e); + } + + void RunSingle(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideE) + { + bool pass = + ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideE, + k_batches_, + n_warmup_, + n_iter_); + EXPECT_TRUE(pass); + } +}; + +TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes); + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) +{ + const std::vector Ms{3, 4}; + constexpr int N = 8; + constexpr int K = 64; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) +{ + const std::vector Ms{3, 5, 16, 7, 8}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) + { + // Run with default arguments. + } + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} From 3e777217551c82a47eb9540791fb5542f2704e63 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 3 Feb 2026 02:41:53 +0400 Subject: [PATCH 4/9] feat: add split_k support for block scale gemm bquant mode. (#3653) * WIP: add splitk to bquant * feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types * chore: remove temporary test script * fix: incorrect tile window length for splitted bq tensor window * chore: improve comments * test: add unit tests to cover bquant splitk functionality * fix: conflict resolution by renaming variables --- .../gemm_bquant_quantgrouped_bf8.cpp | 2 +- .../gemm_bquant_quantgrouped_bf8i4.cpp | 2 +- .../gemm_bquant_quantgrouped_fp8.cpp | 2 +- .../gemm_bquant_quantgrouped_fp8i4.cpp | 2 +- .../run_gemm_quant_example.inc | 183 +----------------- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 130 +++++++++++-- .../kernel/grouped_gemm_quant_kernel.hpp | 8 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 11 ++ .../test_gemm_quant_bquant_splitk_decode.cpp | 61 ++++++ .../test_gemm_quant_bquant_splitk_prefill.cpp | 64 ++++++ .../test_gemm_quant_fixtures.hpp | 16 +- 11 files changed, 273 insertions(+), 208 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index a95c0346cf..1520f2c591 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index d2b95d3263..a93fe15a1b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index a8c13c1b3d..39747ff0bc 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 6576b22c03..ed18cd8890 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 540d5725dd..508f3ac8ec 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -215,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - + // Split-K validation is handled by Kernel::IsSupportedArgument + // Split-K is only supported for BQuantGrouped without preshuffle if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); @@ -661,182 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 3) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else - { - ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - } - } - else if(init_method == 4) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - } - ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v || - std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } - else if(init_method == 5) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); - } - // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) - for(ck_tile::index_t row = 0; - row < static_cast(aq_tensor_ptr->get_length(0)); - ++row) - { - for(ck_tile::index_t col = 0; - col < static_cast(aq_tensor_ptr->get_length(1)); - ++col) - { - (*aq_tensor_ptr)(row, col) = static_cast(col + 1); - } - } - // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v || - std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 21bd691b49..db86fdbeac 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -380,9 +380,18 @@ struct QuantGemmKernel __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); - const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); + constexpr auto K1 = + GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block + const index_t K_t = amd_wave_read_first_lane( + kargs.k_batch * K1); // amount of K elements consumed if every split-K batch + // performs exactly one "unit" (K1) + const index_t KRead = amd_wave_read_first_lane( + (kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch + // offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4) + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + const index_t b_k_offset_elements = + amd_wave_read_first_lane(k_id * KRead / BPackedSize); if constexpr(std::is_same_v) { @@ -395,11 +404,11 @@ struct QuantGemmKernel if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements); } if(k_id < static_cast(kargs.k_batch - 1)) @@ -410,10 +419,47 @@ struct QuantGemmKernel { splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1)); } + + // Compute BQ offset for BQuantGrouped mode (non-preshuffle only) + // Note: With the alignment validation in IsSupportedArgument, KRead is always + // a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned. + if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant) + { + using BQuantGroupSize = remove_cvref_t; + // Compute the K offset for this batch (in terms of K elements) + const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); + // Convert K offset to BQ group offset (logical offset in K/kK dimension) + bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK); + + // BQ tensor layout: + // RowMajor: [K/kK, N/kN] with stride [N/kN, 1] + // ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1] + if constexpr(std::is_same_v) + { + // For RowMajor BQ, K is the row dimension + // offset = bq_group_offset * stride_BQ + const index_t stride_bq = + amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN)); + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq); + } + else if constexpr(std::is_same_v) + { + // For ColumnMajor BQ, K is the column dimension + // offset = bq_group_offset + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); + } + } + else + { + bq_group_offset = 0; + bq_k_split_offset = 0; + } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension) + index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride) index_t splitted_k; }; @@ -805,10 +851,13 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, const QuantGemmKernelArgs& kargs, + const index_t bq_group_offset, const index_t i_m, const index_t i_n) { // Step 1: Create tensor view for BQ + // Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset). + // The dimension should use the remaining K-groups from this offset position. const auto& bq_tensor_view = [&]() { if constexpr(kQuantType == QuantType::RowColQuant) { @@ -850,11 +899,12 @@ struct QuantGemmKernel "ABQuantGrouped requires ColumnMajor BQ layout"); } + using BQuantGroupSize = remove_cvref_t; if constexpr(std::is_same_v) { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), + make_tuple(kargs.QK_B - bq_group_offset, integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), number{}, @@ -865,8 +915,8 @@ struct QuantGemmKernel return make_naive_tensor_view( bq_ptr, make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), - integer_divide_ceil(kargs.K, BQuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1), + kargs.QK_B - bq_group_offset), + make_tuple(kargs.QK_B, 1), number{}, number<1>{}); } @@ -1047,13 +1097,61 @@ struct QuantGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { + // Split-K is supported for BQuantGrouped mode without preshuffle if(kargs.k_batch != 1) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + constexpr bool is_bquant_non_preshuffle = + (kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant; + if constexpr(!is_bquant_non_preshuffle) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 ! " + "Split-K only supported for BQuantGrouped without preshuffle."); + } + return false; + } + else + { + using BQuantGroupSize = remove_cvref_t; + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Constraint 1: KRead must align with B packing requirements. + // For packed data types, multiple K elements are stored in each storage unit. + // Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch. + // If KRead is not divisible by BPackedSize, this division produces a fractional + // offset, making it impossible to start reading from a valid storage unit boundary. + if(KRead % BPackedSize != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!"); + } + return false; + } + + // Constraint 2: KRead must align with quantization group boundaries. + // Each split-K batch reads KRead consecutive K elements. If KRead is not + // a multiple of BQuantGroupSize::kK, the batch will span partial quantization + // groups, requiring split access to a quantization scale. This violates the + // atomic processing requirement where each batch must work with complete groups. + if(KRead % BQuantGroupSize::kK != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Split-K batch size must be aligned with quantization group " + "size! KRead=" + + std::to_string(KRead) + + " is not divisible by BQuantGroupSize::kK=" + + std::to_string(BQuantGroupSize::kK)); + } + return false; + } } - return false; } if constexpr(std::is_same_v) @@ -1215,7 +1313,10 @@ struct QuantGemmKernel const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + // Note: Pass bq_group_offset so the tensor view dimension reflects + // the remaining K-groups from the split-K offset position. + const auto& bq_block_window = MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -1343,8 +1444,9 @@ struct QuantGemmKernel const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); - const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); - CDataType* c_ptr = static_cast(kargs.c_ptr); + const BQDataType* bq_ptr = + static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index c9e725f5fd..8b77b01e2f 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8e005d588e..2b19053f41 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant split-K tests (no preshuffle) + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode + test_gemm_quant_bquant_splitk_decode.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill + test_gemm_quant_bquant_splitk_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant tests (with PreshuffleB) - split into 5 files add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d test_gemm_quant_bquant_preshuffle_decode_1d.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp new file mode 100644 index 0000000000..ea1a8a1fbb --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp @@ -0,0 +1,61 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKDecodeTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Decode +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + this->run_test_with_validation(32, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + this->run_test_with_validation(32, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + this->run_test_with_validation(32, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=2560 for split_k=5: 2560/5=512=4×128 ✓ + // Also K must be divisible by K_Tile(256)*split_k(5)=1280 + this->run_test_with_validation(32, 128, 2560, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp new file mode 100644 index 0000000000..f4f93dbbb6 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKPrefillTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Prefill +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(2)=256 + this->run_test_with_validation(128, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + // K must be divisible by K_Tile(128)*split_k(3)=384 + this->run_test_with_validation(128, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(4)=512 + this->run_test_with_validation(128, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=1920 for split_k=5: 1920/5=384=3×128 ✓ + // K must be divisible by K_Tile(128)*split_k(5)=640 + this->run_test_with_validation(128, 128, 1920, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a8..ca21bc69b7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; @@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBasetemplate calculate_rtol_atol( - K, 1, max_accumulated_value); + K, k_batch, max_accumulated_value); // Validate results bool pass = ck_tile::check_err(c_m_n_dev_result, @@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{})); EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N - << ", K=" << K; + << ", K=" << K << ", k_batch=" << k_batch; if(!pass) { From f2b9b3a3a65478ff84b5829c5ce2b2d2ab095905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 3 Feb 2026 03:07:33 +0100 Subject: [PATCH 5/9] Fix path to ck tile conv fwd instance generator (#3699) * Fix path to ck tile conv fwd instance generator * fixes --- Jenkinsfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index ca7c4f1d93..80721ea6d3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -581,7 +581,7 @@ def cmake_build(Map conf=[:]){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" } - if ((params.RUN_BUILDER_TESTS || params.RUN_FULL_CONV_TILE_TESTS) && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { + if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args } setup_cmd = conf.get( @@ -1428,8 +1428,8 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ python3 ../experimental/builder/src/generate_instances.py --mode=profiler && \ - ../script/cmake-ck-dev.sh ../ gfx90a && \ + execute_args = """ python3 ../experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler && \ + cmake .. --preset dev-gfx90a -D CK_EXPERIMENTAL_BUILDER=ON && \ make -j64 test_grouped_convnd_fwd_tile && \ ./bin/test_grouped_convnd_fwd_tile""" } From 8b56ffb6aea4dd5e3c531912ee6b2258398606ee Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 2 Feb 2026 18:25:56 -0800 Subject: [PATCH 6/9] Fix one more lifetimebound error. (#3703) * fix staging compiler errors * fix clang format --- include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp index 260ebcf4cc..35d987a79a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -63,7 +63,10 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 true> c_thread_buf_; - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + __host__ __device__ constexpr auto& GetCThreadBuffer() [[clang::lifetimebound]] + { + return c_thread_buf_; + } __device__ static auto GetWaveIdx() { From 3f04d27b687365332d2f1654f169444cab192927 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 3 Feb 2026 02:54:18 -0800 Subject: [PATCH 7/9] Remove concrete performance numbers from BUILD_TIME_OPTIMIZATION.md (#3702) Replace specific benchmark numbers with qualitative descriptions since measurements vary across environments and may become outdated. Co-authored-by: Claude --- include/ck/BUILD_TIME_OPTIMIZATION.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck/BUILD_TIME_OPTIMIZATION.md b/include/ck/BUILD_TIME_OPTIMIZATION.md index 94b292b878..045d4bb929 100644 --- a/include/ck/BUILD_TIME_OPTIMIZATION.md +++ b/include/ck/BUILD_TIME_OPTIMIZATION.md @@ -105,7 +105,7 @@ struct generate_identity_sequence generate_tuple(generate_identity_sequence{}, Number{}); ``` -This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction). +This significantly reduces template instantiations for `transform_tensor_descriptor`. **Example: container_concat** @@ -135,7 +135,7 @@ __host__ __device__ constexpr auto container_concat(const Tuple& tx, const } ``` -This reduced `container_concat` instantiations from 186 to 93 (50% reduction). +This reduces `container_concat` template instantiations. **Example: make_uniform_tuple** @@ -192,7 +192,7 @@ __host__ __device__ constexpr index_t find_source_index(Sequence) } ``` -This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%. +This significantly reduces `sequence_map_inverse` instantiations and compile time. ### 4. Use Fold Expressions for Accumulation @@ -222,4 +222,4 @@ __host__ __device__ constexpr auto compute_element_space_size( } ``` -This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%. +This reduces `calculate_element_space_size` instantiations and compile time. From 8cbd09c84a3010b4b3dbe2604875772363e2396b Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:12:15 -0700 Subject: [PATCH 8/9] [CK_TILE] Stream-K Tile Engine Test Config File Generation (#3662) * Stream-K smoke test config file generation This change converts the stream-k smoke tests to use tile engine. Since the m, n, and k values dependent on the CU count of a device, the configs are generated during the Configuration Phase. * Compute GEMM reference on GPU * Remove redundant Stream-K tests Removing redundant tests that are now run via tile engine. * Fix relative and absolute tolerance calculation This change updates the Stream-K tile engine interface to ensure that num_wgs_per_tile is propaged and passed into the compare_results function to calculate the rel and abs tolerance. Before, split-k was used, which is incorrect for Stream-K since the split-k value is always 1. * Cleanup imports, types, and other misc items This commit makes the following changes: - Uses Typing module for nested type hints - Uses quotes around cu_count_arg argument in generate_configs.cmake in if statements - Adds explicit include for tuple in test_gemm_streamk_simple.cpp - Adds a type for the tiles argument in argparser to check argument validity * Use CU count as return value for better parsing * Add reduction tests for bf16, fp8, and bf8 --- test/ck_tile/gemm_streamk/CMakeLists.txt | 14 - .../test_gemm_streamk_bf16_nonpersistent.cpp | 17 -- .../test_gemm_streamk_bf16_persistent.cpp | 17 -- .../test_gemm_streamk_bf8_nonpersistent.cpp | 17 -- .../test_gemm_streamk_bf8_persistent.cpp | 17 -- .../test_gemm_streamk_fp16_nonpersistent.cpp | 17 -- .../test_gemm_streamk_fp16_persistent.cpp | 17 -- .../test_gemm_streamk_fp16_reduction.cpp | 17 -- .../test_gemm_streamk_fp8_nonpersistent.cpp | 17 -- .../test_gemm_streamk_fp8_persistent.cpp | 17 -- .../test_gemm_streamk_reduction_cases.inc | 88 ------ .../test_gemm_streamk_smoke_cases.inc | 47 --- .../gemm_streamk/test_gemm_streamk_types.hpp | 8 - .../gemm_streamk_tile_engine/CMakeLists.txt | 54 ++-- .../gemm_streamk_tile_engine/README.md | 20 +- .../configs/simple_test_config.json | 35 --- .../gemm_streamk_tile_engine/cu_count.cpp | 44 +++ .../generate_configs.cmake | 103 +++++++ .../generate_configs.py | 277 ++++++++++++++++++ .../test_gemm_streamk_simple.cpp | 51 ++-- .../gemm_streamk_instance_builder.py | 11 +- .../gemm_streamk/gemm_streamk_profiler.hpp | 23 +- 22 files changed, 522 insertions(+), 406 deletions(-) delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp delete mode 100644 test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc delete mode 100644 test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc delete mode 100644 test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json create mode 100644 test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp create mode 100644 test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake create mode 100644 test/ck_tile/gemm_streamk_tile_engine/generate_configs.py diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 1390e5ee07..2fac12ebe4 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,19 +23,6 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - add_gtest_executable(test_ck_tile_streamk_reduction - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp - test_gemm_streamk_util.cpp) - add_gtest_executable(test_ck_tile_streamk_smoke - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp - test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_extended ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp @@ -46,7 +33,6 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp test_gemm_streamk_util.cpp) - target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target") diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp deleted file mode 100644 index 95117b6f0d..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf16NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp deleted file mode 100644 index 5e0705ab29..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf16Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp deleted file mode 100644 index 21e447af29..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp deleted file mode 100644 index 62b7767a69..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp deleted file mode 100644 index fc18b9ebf7..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp deleted file mode 100644 index 8756da4ad8..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp deleted file mode 100644 index bcd4583da2..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction - -TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction); - -#include "test_gemm_streamk_reduction_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp deleted file mode 100644 index 58dca5ca1d..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp deleted file mode 100644 index 1d1e1e31ec..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc deleted file mode 100644 index 66c3e3b5e9..0000000000 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu; - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu; - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 4; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 4; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 3; - ck_tile::index_t N = N_Tile * 7; - ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 3; - ck_tile::index_t N = N_Tile * 7; - ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); -} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc deleted file mode 100644 index 4bd6e9d973..0000000000 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase) -{ - ck_tile::index_t M = 256; - ck_tile::index_t N = 256; - ck_tile::index_t K = 256; - - this->Run(M, N, K); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - // For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This - // assumes tile sizes are large enough such that occupancy is 1. - ck_tile::index_t M = M_Tile * num_cu; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile; - - this->Run(M, N, K); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - // For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along - // the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu - // macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy - // is 1. - ck_tile::index_t M = M_Tile * 2; - ck_tile::index_t N = N_Tile * 2; - ck_tile::index_t K = K_Tile * num_cu; - - this->Run(M, N, K); -} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index ece313b8aa..efb7416580 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -33,14 +33,6 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types< std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> >; -using KernelTypesStreamKFp16Reduction = ::testing::Types< -// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>; - using KernelTypesStreamKBf16Persistent = ::testing::Types< std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, diff --git a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt index 664866d458..8f9bd39886 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt @@ -1,6 +1,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +include(generate_configs.cmake) + # ============================================================================ # GEMM Tile Engine Unit Tests # @@ -87,7 +89,7 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) endif() - message(STATUS " Created test target: ${target_name}") + message(DEBUG " Created test target: ${target_name}") endfunction() # ============================================================================ @@ -101,12 +103,12 @@ endfunction() # layout - Matrix layout (rcr, rrr, ccr, crr) # config_name - Configuration file name without .json extension # ============================================================================ -function(build_gemm_test_targets datatype layout config_name) +function(build_gemm_test_targets datatype layout config_name configs_dir_path) set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") # Locate and validate configuration file set(config_filename "${config_name}.json") - set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") + set(json_blob "${configs_dir_path}/${config_filename}") if(NOT EXISTS ${json_blob}) message(WARNING "Test config file not found: ${json_blob}") @@ -137,11 +139,11 @@ function(build_gemm_test_targets datatype layout config_name) # Verify kernel list file was generated if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) - message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") + message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") return() endif() - message(STATUS "Building tests for ${datatype}_${layout}_${config_name}") + message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}") # STEP 2a: Extract test parameters from config set(test_params_file "${working_path}/test_params.hpp") @@ -230,7 +232,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") # GPU architecture filtering - only build tests for supported architectures set(GEMM_TEST_GPU_TARGETS "") -set(DESIRED_TARGETS "gfx90a;gfx942") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) @@ -241,7 +243,7 @@ endforeach() # Early exit if no compatible GPU architectures are available if(NOT GEMM_TEST_GPU_TARGETS) - message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") return() endif() @@ -282,25 +284,35 @@ set(TEST_LAYOUTS "rcr;rrr;ccr;crr") # Test Target Generation - Datatype-Specific Categories # ============================================================================ -# 1. SIMPLE TEST: Test for basic functionality with data types (fp16, bf16) -# These data types can use larger warp tiles due to smaller memory footprint -set(SIMPLE_TEST_CONFIG "simple_test_config") -set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json") -set(SIMPLE_DATATYPES "fp16;bf16") +# 1. SMOKE TESTS: Test for basic functionality with data types (fp8, bf8, fp16, bf16) +set(SMALL_DATATYPES "fp16;bf16;fp8;bf8") +set(SIXTEEN_BIT_DATATYPES "fp16;bf16") +set(EIGHT_BIT_DATATYPES "fp8;bf8") +set(LARGE_TILES "256,256,32") +set(SMALL_TILES "128,128,32") +set(CONFIG_LIST "") +set(GENERATED_CONFIG_PATH ${CMAKE_CURRENT_BINARY_DIR}/configs) +get_cu_count(CU_COUNT) -if(EXISTS ${SIMPLE_TEST_CONFIG_FILE}) - message(STATUS "Processing simple test config: ${SIMPLE_TEST_CONFIG} (fp16, bf16)") - foreach(datatype IN LISTS SIMPLE_DATATYPES) - # fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) +message(STATUS "Generating and processing configs for Stream-K tests") +foreach(datatype IN LISTS SMALL_DATATYPES) + + if(datatype IN_LIST SIXTEEN_BIT_DATATYPES) + generate_test_configs(${CU_COUNT} ${LARGE_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH}) + else() + generate_test_configs(${CU_COUNT} ${SMALL_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH}) + endif() + + foreach(config IN LISTS CONFIG_LIST) + # testing all layouts (rcr, rrr, ccr, crr) foreach(layout IN LISTS TEST_LAYOUTS) - build_gemm_test_targets("${datatype}" "${layout}" "${SIMPLE_TEST_CONFIG}") + build_gemm_test_targets("${datatype}" "${layout}" "${config}" "${GENERATED_CONFIG_PATH}") endforeach() endforeach() -else() - message(WARNING "Simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}") -endif() +endforeach() + # ============================================================================ message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:") -message(STATUS " - Simple test: fp16/bf16 (all layouts)") +message(STATUS " - Smoke tests: fp16/bf16/fp8/bf8 (all layouts)") diff --git a/test/ck_tile/gemm_streamk_tile_engine/README.md b/test/ck_tile/gemm_streamk_tile_engine/README.md index 4655673852..965342536b 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/README.md +++ b/test/ck_tile/gemm_streamk_tile_engine/README.md @@ -34,17 +34,25 @@ Each test configuration can specify optimized problem sizes in its JSON file: The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. ## Test Configurations +Test configs are generated during the Generation Phase. They are stored under the build directory at test/ck_tile/gemm_streamk_tile_engine/configs. The Compute Unit (CU) count of the device is required to generate the configs. If the Generation Phase occurs on a machine without a GPU or does not contain same GPU architecture on which you will run the tests, you can manually set the CU count using the `CU_COUNT` option: +```bash +# Assuming you are at the root of the repo +cd build +../script/cmake-ck-dev.sh .. gfx90a -G Ninja -DCU_COUNT=100 +``` +You can reference the public whitepaper for your specific GPU to get the appropriate CU count. +If no `CU_COUNT` option is given and no HIP device is found, then the default value of 100 CUs will be used to determine the problem sizes tested. -### 1. **Simple Test** (`simple_test_config.json`) -- **Purpose**: Basic functionality validation for fp16/bf16 data types -- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 +### 1. **Smoke Tests** +- **Purpose**: Basic functionality validation for fp16/bf16/fp8/bf8 data types +- **Config**: 256x256x32 (for bf16/fp16) or 128x128x32 (for bf8/fp8), warp 2x2x1, warp_tile 32x32x16 - **Traits**: compv3 pipeline only -- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp16, bf16 +- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) ## Data Type Support -- ✅ **fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) +- ✅ **fp16, bf16, fp8, bf8**: Fully supported - all layouts (rcr, rrr, ccr, crr) - ❌ **fp64**: Not supported (hardware MFMA limitation) -- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) +- ⏳ **fp32, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) ## Test Result Behavior diff --git a/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json deleted file mode 100644 index 1cfeef7570..0000000000 --- a/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "problem": { - "description": "Basic functionality validation with moderate problem sizes" - }, - "test_params": { - "problem_sizes": [ - {"m": 256, "n": 256, "k": 128, "split_k": 1}, - {"m": 512, "n": 256, "k": 256, "split_k": 1}, - {"m": 256, "n": 512, "k": 256, "split_k": 1} - ] - }, - "tile_config": { - "tile_m": {"values": [128]}, - "tile_n": {"values": [128]}, - "tile_k": {"values": [64]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [16]}, - "warp_tile_n": {"values": [16]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["compv3"]}, - "epilogue": {"values": ["default"]}, - "scheduler": {"values": ["intrawave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false, true]}, - "reduction_strategy": {"values": ["atomic"]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp b/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp new file mode 100644 index 0000000000..88a2d06901 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +/** + * @brief Determines whether a `hipError` is present in the given `error_status` + * @return true if the `error_status` has an error, otherwise false. + */ +bool has_error(const hipError_t& error_status) +{ + if(error_status != hipSuccess) + { + std::cerr << hipGetErrorString(error_status); + return true; + } + + return false; +} + +/** + * @brief Returns the number of Compute Units (CUs) on the given device. + * @return The number of CUs on the device. If an error occurs while querying the device, zero is + * returned. + */ +int get_cu_count() +{ + hipDevice_t dev; + hipDeviceProp_t dev_prop; + + const hipError_t device_status = hipGetDevice(&dev); + + if(has_error(device_status)) + return 0; + + const hipError_t prop_status = hipGetDeviceProperties(&dev_prop, dev); + if(has_error(prop_status)) + return 0; + + return dev_prop.multiProcessorCount; +} + +int main() { return get_cu_count(); } diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake new file mode 100644 index 0000000000..4f18b5dcbe --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake @@ -0,0 +1,103 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(CU_COUNT 0 CACHE STRING "Number of Compute Units on the device") + +# ============================================================================ +# get_cu_count +# +# Returns the CU count for the device. If the given cu_count_arg is a positive +# integer, then the nothing happens. Otherwise, we attempt to query the CU +# count from the device. If the query is unsucessful, the default value of 100 +# is returned. +# +# Parameters: +# cu_count_arg - The starting CU count +# ============================================================================ +function(get_cu_count cu_count_arg) + message(STATUS "Starting query for CU count needed for Stream-K test config generation") + + if(NOT "${${cu_count_arg}}" MATCHES "^[0-9]+$") + message(FATAL_ERROR "The CU count must be a non-negative integer. \ + The given value of ${${cu_count_arg}} is invalid.") + endif() + + if("${${cu_count_arg}}" STREQUAL "0") + + set(CPP_FILE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cu_count.cpp) + set(CPP_EXE_PATH ${CMAKE_CURRENT_BINARY_DIR}/cu_count) + + execute_process( + COMMAND ${CMAKE_HIP_COMPILER} -x hip ${CPP_FILE_PATH} -o ${CPP_EXE_PATH} + RESULT_VARIABLE compile_result + ) + + if (NOT compile_result EQUAL 0) + message(FATAL_ERROR "Compilation of ${CPP_FILE_PATH} failed.\n") + endif() + + execute_process( + COMMAND ${CPP_EXE_PATH} + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_VARIABLE standard_error + RESULT_VARIABLE queried_cu_count + ) + + if (standard_error) + message(STATUS "Error information from attempting to query HIP device and properties:\n" + "${standard_error}") + endif() + + + # Delete the generated cu_count executable + file(REMOVE "${CPP_EXE_PATH}") + + if(queried_cu_count EQUAL 0) + message(WARNING "Unable to query the number of Compute Units. \ + Please use the CU_COUNT CLI option to pass in the \ + number of Compute Units for your target device; otherwise, \ + the default value of 100 will be used.") + set(${cu_count_arg} 100 PARENT_SCOPE) + else() + set(${cu_count_arg} ${queried_cu_count} PARENT_SCOPE) + endif() + + endif() + +endfunction() + +# ============================================================================ +# generate_test_configs +# +# Generate config json files for Stream-K tests +# +# Parameters: +# cu_count_arg - The number of CUs on the device +# tile_sizes - A list of block tile sizes: tile_m,tile_n,tile_k +# datatype - The datatype for which the config is being generated +# config_list - The variable to which the list of config file names are written +# configs_path - Path to the configs directory to which config files are written +# ============================================================================ +function(generate_test_configs cu_count_arg tile_sizes datatype config_list configs_path) + message(STATUS "Generating Stream-K test config files for ${datatype}") + + file(MAKE_DIRECTORY ${configs_path}) + + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py + --cu_count ${cu_count_arg} + --configs_dir_path ${configs_path} + --tiles ${tile_sizes} + --datatype ${datatype} + OUTPUT_VARIABLE CONFIG_LIST + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE script_ret_val + ) + + if (NOT script_ret_val EQUAL 0) + message(FATAL_ERROR "Eror occured during execution of ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py") + endif() + + set(${config_list} ${CONFIG_LIST} PARENT_SCOPE) + +endfunction() diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py new file mode 100644 index 0000000000..ba075a2729 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +from enum import Enum +from typing import Dict, Tuple, List +import argparse +import json +import os +import sys +from dataclasses import dataclass, field, asdict + + +@dataclass +class TileConfig: + """Represents the Tile Config section of a Tile Engine config""" + + tile_m: List[int] = field(default_factory=list) + tile_n: List[int] = field(default_factory=list) + tile_k: List[int] = field(default_factory=list) + warp_m: List[int] = field(default_factory=lambda: [2]) + warp_n: List[int] = field(default_factory=lambda: [2]) + warp_k: List[int] = field(default_factory=lambda: [1]) + warp_tile_m: List[int] = field(default_factory=lambda: [32]) + warp_tile_n: List[int] = field(default_factory=lambda: [32]) + warp_tile_k: List[int] = field(default_factory=lambda: [16]) + + def to_dict(self) -> Dict: + return {k: {"values": v} for k, v in asdict(self).items()} + + +@dataclass +class TraitConfig: + """Represents the Trait Config section of a Tile Engine config""" + + pipeline: List[str] = field(default_factory=lambda: ["compv3"]) + epilogue: List[str] = field(default_factory=lambda: ["cshuffle"]) + scheduler: List[str] = field(default_factory=lambda: ["intrawave"]) + pad_m: List[bool] = field(default_factory=lambda: [False]) + pad_n: List[bool] = field(default_factory=lambda: [False]) + pad_k: List[bool] = field(default_factory=lambda: [False]) + persistent: List[bool] = field(default_factory=lambda: [True, False]) + reduction_strategy: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict: + return {k: {"values": v} for k, v in asdict(self).items()} + + +class TestVariant(Enum): + """Represents a Stream-K test variant""" + + def __init__( + self, + val: int, + reduction_strategy: List[str], + persistent: List[bool], + datatypes: List[str], + description: str, + ): + self._value_ = val + self.reduction_strategy = reduction_strategy + self.persistent = persistent + self.datatypes = datatypes + self.description = description + + ATOMIC_SMOKE = ( + 0, + ["atomic"], + [True, False], + ["fp16", "bf16", "fp8", "bf8"], + "Stream-K atomic smoke tests", + ) + REDUCTION_SMOKE = ( + 2, + ["reduction", "tree"], + [True, False], + ["fp16", "bf16", "fp8", "bf8"], + "Stream-K reduction smoke tests", + ) + EXTENDED = ( + 3, + ["atomic"], + [True, False], + ["fp16", "bf16", "fp8", "bf8"], + "Stream-K extended smoke tests", + ) + + def apply(self, trait_config: TraitConfig) -> None: + """Applies the current test variant's persistent and reduction strategy setting to the given trait_config""" + trait_config.persistent = self.persistent + trait_config.reduction_strategy = self.reduction_strategy + + +@dataclass +class ProblemSize: + """Represents a problem size in a Tile Engine config""" + + m: int + n: int + k: int + variant: TestVariant + split_k: int = 1 + + def to_dict(self) -> Dict: + return {"m": self.m, "n": self.n, "k": self.k, "split_k": self.split_k} + + +@dataclass +class Config: + """Represents a Tile Engine config""" + + description: str + problem_sizes: list[ProblemSize] = field(default_factory=list) + tile_config: TileConfig = field(default_factory=TileConfig) + trait_config: TraitConfig = field(default_factory=TraitConfig) + k_block_per_cu: int = 1 + permute_n: bool = False + + def add_problem_size(self, problem: ProblemSize) -> None: + """Adds the given problem to this config's problem_sizes""" + self.problem_sizes.append(problem) + + def to_dict(self) -> Dict: + config_dict = { + "problem": {"description": f"{self.description}"}, + "test_params": { + "problem_sizes": [ps.to_dict() for ps in self.problem_sizes] + }, + "tile_config": self.tile_config.to_dict(), + "trait_config": self.trait_config.to_dict(), + "k_block_per_cu": self.k_block_per_cu, + "permute_n": self.permute_n, + } + return config_dict + + def write_to_file(self, output_file: str) -> None: + """Writes this configs to the given output_file in a json format""" + with open(output_file, "w") as config_file: + json.dump(self.to_dict(), config_file, indent=4) + config_file.write("\n") + + +def create_problem_sizes( + tile_m: int, tile_n: int, tile_k: int, cu_count: int +) -> List[ProblemSize]: + """Creates and returns a list of problem sizes using the given arguments""" + problem_sizes = [ + ProblemSize(256, 256, 256, TestVariant.ATOMIC_SMOKE), + ProblemSize(tile_m * cu_count, tile_n, tile_k, TestVariant.ATOMIC_SMOKE), + ProblemSize( + tile_m * 2, tile_n * 2, cu_count * tile_k, TestVariant.ATOMIC_SMOKE + ), + ProblemSize(tile_m, tile_n, cu_count * tile_k, TestVariant.REDUCTION_SMOKE), + ProblemSize( + tile_m * 4, + tile_n, + tile_k * cu_count + (25 * tile_k), + TestVariant.REDUCTION_SMOKE, + ), + ProblemSize( + tile_m * 3, + tile_n * 7, + tile_k * cu_count + (30 * tile_k), + TestVariant.REDUCTION_SMOKE, + ), + # TODO: Add this test once we determine how to label tests as regresion with tile engine + # ProblemSize((tile_m * cu_count * 2) + (tile_m * 2), tile_n, 2048, TestVariant.EXTENDED) + ] + + return problem_sizes + + +def write_config_files( + problem_sizes: List[ProblemSize], + configs_dir_path: str, + datatype: str, + tile_sizes: Tuple[int, int, int], +) -> str: + """Writes the given problem_sizes to a config file and returns the names of the config files written to""" + config_names = [] + tile_m, tile_n, tile_k = tile_sizes + tile_config = TileConfig([tile_m], [tile_n], [tile_k]) + + # Create a config for each test variant + for variant in TestVariant: + problem_sizes_filtered = [ps for ps in problem_sizes if ps.variant == variant] + + if (datatype not in variant.datatypes) or len(problem_sizes_filtered) == 0: + continue + + trait_config = TraitConfig() + variant.apply(trait_config) + config_name = f"streamk_{variant.name.lower()}_tests_config_{datatype}" + config_names.append(config_name) + file_path = os.path.join(configs_dir_path, config_name + ".json") + config = Config( + variant.description, problem_sizes_filtered, tile_config, trait_config + ) + config.write_to_file(file_path) + + return config_names + + +def print_config_names(config_file_names: List[str]) -> None: + """Prints given config file names as a single semi-colon separated string""" + print(";".join(config_file_names)) + + +def create_config_files( + cu_count: int, configs_dir_path: str, tile_sizes: int, datatype: str +) -> None: + """Creates Stream-K test config files and prints the file names in a semi-colon-separated list""" + tile_m, tile_n, tile_k = tile_sizes + + problem_sizes = create_problem_sizes(tile_m, tile_n, tile_k, cu_count) + config_names = write_config_files( + problem_sizes, configs_dir_path, datatype, tile_sizes + ) + print_config_names(config_names) + + +def get_args() -> Tuple[int, str, Tuple[int, int, int], str]: + """Returns user provided arguments""" + + def tile_sizes_type(val: str): + sizes = None + parts = val.split(",") + if len(parts) != 3: + raise argparse.ArgumentTypeError( + "--tiles must contain exactly three comma-separated values (m,n,k), e.g. --tiles 256,256,32" + ) + try: + sizes = tuple(int(size) for size in parts) + except ValueError: + raise argparse.ArgumentTypeError( + "--tiles must contain exactly three comma-separated integers (m,n,k), e.g. --tiles 256,256,32" + ) + + return sizes + + parser = argparse.ArgumentParser(description="Create Stream-K test configs") + parser.add_argument( + "--cu_count", required=True, help="Number of Compute Units on the device" + ) + parser.add_argument( + "--configs_dir_path", + required=True, + help="Full path configs directory where config files will be written to", + ) + + parser.add_argument( + "--tiles", + required=True, + type=tile_sizes_type, + help="Block tile sizes for m, n, and k, respectively. Ex: --tiles 256,256,32", + ) + + parser.add_argument( + "--datatype", + choices=["fp16", "bf16", "fp8", "bf8"], + required=True, + help="The datatype for which the config is generated.", + ) + + args = parser.parse_args() + + return (int(args.cu_count), args.configs_dir_path, args.tiles, args.datatype) + + +def main(): + cu_count, configs_dir_path, tile_sizes, datatype = get_args() + create_config_files(cu_count, configs_dir_path, tile_sizes, datatype) + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp index 913e7d8531..1c06d33e77 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp +++ b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp @@ -12,6 +12,7 @@ #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -126,13 +127,18 @@ class StreamKGemmTileEngineTest : public ::testing::TestWithParam 0) << "Kernel name should not be empty"; + + std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; + std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << std::endl; + // Get tensor layouts from generated kernel const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; const CLayout layout_c = CLayout{}; - // Use split_k from test parameters - int split_k = split_k_; + // Calculate tensor strides int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a)); int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b)); int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c)); @@ -144,27 +150,42 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); - ck_tile::HostTensor c_m_n_host_result( + ck_tile::HostTensor c_m_n_dev_ref( ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); // Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine) ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + c_m_n_dev_ref.SetZero(); // Allocate GPU device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes()); // Copy data to device and zero output buffer a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); + ref_c_m_n_dev_buf.SetZero(); - // Calculate reference result on host for verification - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_result); + // Calculate reference result on device for verification + ADataType* a_m_k_dev_ref_ptr = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* b_k_n_dev_ref_ptr = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* c_m_n_dev_ref_ptr = static_cast(ref_c_m_n_dev_buf.GetDeviceBuffer()); + ck_tile:: + reference_gemm_gpu( + a_m_k_dev_ref_ptr, + b_k_n_dev_ref_ptr, + c_m_n_dev_ref_ptr, + m_, + n_, + k_, + stride_a_calc, + stride_b_calc, + stride_c_calc); + ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data()); // Create GEMM kernel arguments ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), @@ -188,9 +209,10 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) 1}; // rotating_count // Launch the generated kernel (no timing overhead for fastest execution) + std::tuple launch_result; try { - SelectedKernel::launch(args, stream_config); + launch_result = SelectedKernel::launch(args, stream_config); // Kernel launched successfully if no exception thrown } catch(const std::exception& e) @@ -211,22 +233,13 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); // Verify results using tile_engine's adaptive error thresholds + const ck_tile::index_t num_wgs_per_tile = get<1>(launch_result); bool verification_passed = compare_results( - KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result); + KERNEL_NAME, k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_dev_ref); EXPECT_TRUE(verification_passed) << "GEMM result verification failed"; } -TEST_P(StreamKGemmTileEngineTest, KernelInfo) -{ - // Simple test to verify kernel information is available - EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty"; - - std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; - std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_ - << std::endl; -} - // Use config-specific test parameters (included via compile flags) // CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file INSTANTIATE_TEST_SUITE_P(GemmVerification, diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 877c803d69..c8d6f86ccc 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -481,8 +481,9 @@ struct SelectedKernel {{ AccDataType, TileShape, GemmUniversalTraits>; - - static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ + + static std::tuple launch(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& stream) {{ constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return std::tuple{{time, num_wgs_per_tile}}; }} }}; """ diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp index d168030f97..2a7b07c698 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -22,25 +22,25 @@ class GemmProfiler // Overload for single kernel benchmarking void benchmark(GemmProblem& gemm_problem, - std::function kernel_func) + std::function( + const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func) { - // Create a vector with a single callable that returns both name and time - std::vector(ck_tile::StreamKHostArgs&, - const ck_tile::stream_config&)>> + // Create a vector with a single callable that returns name, time, and num_wgs_per_tile + std::vector( + ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>> callables; callables.push_back( [kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) { - float time = kernel_func(args, stream); - return std::make_tuple(std::string(KERNEL_NAME), time); + auto [time, num_wgs_per_tile] = kernel_func(args, stream); + return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile); }); benchmark(gemm_problem, callables); } void benchmark(GemmProblem& gemm_problem, - std::vector( + std::vector( ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables) { const ALayout layout_a = ALayout{}; @@ -160,9 +160,9 @@ class GemmProfiler ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - const std::tuple& kernel_run_result) + const std::tuple& kernel_run_result) { - auto [name, avg_time] = kernel_run_result; + auto [name, avg_time, num_wgs_per_tile] = kernel_run_result; auto dp_persistent = SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel"; @@ -196,8 +196,7 @@ class GemmProfiler c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool verified_correct = !setting_.verify_ || - compare( - name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result); + compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result); if(verified_correct) { From 569640dc70bb9175fb1b6664b9a2e0970b7dec78 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:52:14 -0800 Subject: [PATCH 9/9] Revert "Implement device grouped gemm fixed nk multi abd for rdna4 (#3619)" (#3705) This reverts commit 301eb5cf083a03382f7dc69b3277038658a10b2d. --- ...grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp | 2 - .../59_grouped_gemm_multi_ABD/CMakeLists.txt | 8 - ...m_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | 400 -------- ...gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 396 -------- ...mm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 27 +- ..._gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp | 33 +- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 899 ------------------ ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 27 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 - include/ck/utility/tuple_helper.hpp | 9 - .../cpu/reference_gemm_multi_abd.hpp | 194 ---- .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 295 +----- .../CMakeLists.txt | 6 +- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 144 --- ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 144 --- ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 144 --- ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 2 - .../profiler/profile_gemm_multi_abd_impl.hpp | 88 +- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 534 ----------- test/grouped_gemm/CMakeLists.txt | 6 - .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 256 ----- 24 files changed, 120 insertions(+), 3517 deletions(-) delete mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp delete mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp delete mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp delete mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp delete mode 100644 profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp delete mode 100644 test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp index e6e2137bea..0766373465 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp @@ -15,8 +15,6 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/host_utility/hip_check_error.hpp" - using ::ck::hip_check_error; template diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index d7ff58705c..4155e0a344 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -8,11 +8,3 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) - -add_custom_target(example_grouped_gemm_wmma_multi_abd) - -add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp) -add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16) - -add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp) -add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8) \ No newline at end of file diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp deleted file mode 100644 index 4eab6cfce2..0000000000 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/host_utility/hip_check_error.hpp" - -using ::ck::DeviceMem; -using ::ck::hip_check_error; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - -using BF16 = ck::bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Add = ck::tensor_operation::element_wise::Add; - -using A0DataType = BF16; -using AsDataType = ck::Tuple; -using B0DataType = I8; -using B1DataType = BF16; -using BsDataType = ck::Tuple; -using AccDataType = F32; -using CShuffleDataType = BF16; -using D0DataType = BF16; -using DsDataType = ck::Tuple; -using EDataType = BF16; - -using A0Layout = Row; -using AsLayout = ck::Tuple; -using B0Layout = Col; -using B1Layout = B0Layout; -using BsLayout = ck::Tuple; -using DsLayout = ck::Tuple; -using ELayout = Row; - -using Multiply = ck::tensor_operation::element_wise::Multiply; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; - -using AElementOp = PassThrough; -using BElementOp = Multiply; -using CDEElementOp = AddFastGelu; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK - // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; - -// clang-format on - -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; - - std::vector stride_As; - std::vector stride_Bs; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - int k_batch = 1; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - // GEMM shape - std::vector gemm_descs; - - gemm_descs.reserve(group_count); - - int sum_of_m = 0; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); - } - }; - - std::vector> a0_tensors; - std::vector> b_tensors; - std::vector> b0_tensors; - std::vector> b1_tensors; - std::vector> d0_tensors; - std::vector> c_host_tensors; - std::vector> c_device_tensors; - - a0_tensors.reserve(group_count); - b_tensors.reserve(group_count); - b0_tensors.reserve(group_count); - b1_tensors.reserve(group_count); - d0_tensors.reserve(group_count); - c_host_tensors.reserve(group_count); - c_device_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, - d0_tensors_device, c_tensors_device; - - a0_tensors_device.reserve(group_count); - b0_tensors_device.reserve(group_count); - b1_tensors_device.reserve(group_count); - d0_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - sum_of_m += problem_size.Ms[i]; - - a0_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); - - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); - b0_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); - b1_tensors.push_back(Tensor( - f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); - - d0_tensors.push_back(Tensor( - f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); - - c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - - std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc - << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc - << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + - sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + - sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + - sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + - sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); - break; - case 2: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - - d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - constexpr ck::index_t NumATensor = 1; - constexpr ck::index_t NumBTensor = 2; - constexpr ck::index_t NumDTensor = 1; - - using GroupedGemmKernelArgument = ck::tensor_operation::device:: - GroupedGemmMultiABDKernelArgument; - - std::vector grouped_gemm_kernel_args_; - grouped_gemm_kernel_args_.reserve(group_count); - - for(int i = 0; i < group_count; i++) - { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - - b0_tensors_device.emplace_back(std::make_unique( - sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - - b1_tensors_device.emplace_back(std::make_unique( - sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - - d0_tensors_device.emplace_back( - std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); - - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); - b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); - d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); - c_tensors_device[i]->SetZero(); - - gemm_descs.push_back( - {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); - - grouped_gemm_kernel_args_.push_back( - {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, - std::array{b0_tensors_device[i]->GetDeviceBuffer(), - b1_tensors_device[i]->GetDeviceBuffer()}, - std::array{d0_tensors_device[i]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - std::array{problem_size.stride_As[i]}, - std::array{problem_size.stride_Bs[i], 0}, - std::array{0}, - problem_size.stride_Cs[i]}); - } - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - std::vector> p_As = {}; - std::vector> p_Bs = {}; - std::vector> p_Ds = {}; - std::vector p_Cs = {}; - - // do GEMM - auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); - gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); - - DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), - grouped_gemm_kernel_args_.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); - gemm.SetKBatch(argument, config.k_batch); - - gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); - - invoker.Run(&argument, StreamConfig{nullptr, false}); - - if(config.time_kernel) - { - float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - for(int n = 0; n < problem_size.Ns[i]; ++n) - { - for(int k = 0; k < problem_size.Ks[i]; ++k) - { - b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); - } - } - - c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), - c_device_tensors[i].mDesc.GetElementSize() * - sizeof(EDataType)); - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], - b_tensors[i], - c_host_tensors[i], - PassThrough{}, - PassThrough{}, - PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < problem_size.Ms[i]; ++m) - { - for(int n = 0; n < problem_size.Ns[i]; ++n) - { - cde_element_op( - c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); - } - } - - pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); - } - } - - return pass; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - problem_size.group_count = 16; - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(32 + rand() % 32); - problem_size.Ns.push_back(1024); - problem_size.Ks.push_back(512); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - } - - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - - return !run_grouped_gemm(problem_size, config); -} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp deleted file mode 100644 index c494e45bfb..0000000000 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp +++ /dev/null @@ -1,396 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" - -#include "ck/utility/scheduler_enum.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/host_utility/hip_check_error.hpp" - -using ::ck::DeviceMem; -using ::ck::hip_check_error; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Add = ck::tensor_operation::element_wise::Add; -using Scale = ck::tensor_operation::element_wise::Scale; -using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; - -using A0DataType = F16; -using A1DataType = F32; -using AsDataType = ck::Tuple; -using B0DataType = F16; -using BsDataType = ck::Tuple; -using AccDataType = F32; -using CShuffleDataType = F32; -using D0DataType = F16; -using DsDataType = ck::Tuple; -using EDataType = F16; - -using A0Layout = Row; -using A1Layout = Row; -using AsLayout = ck::Tuple; -using B0Layout = Col; -using BsLayout = ck::Tuple; -using D0Layout = Row; -using DsLayout = ck::Tuple; -using ELayout = Row; - -using AElementOp = AddScale; -using BElementOp = PassThrough; -using CDEElementOp = Add; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK - // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; -// clang-format on - -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; - - std::vector stride_As; - std::vector stride_Bs; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - int k_batch = 1; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - // GEMM shape - std::vector gemm_descs; - - gemm_descs.reserve(group_count); - - int sum_of_m = 0; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); - } - }; - - std::vector> a0_tensors; - std::vector> a1_tensors; - std::vector> b_tensors; - std::vector> d0_tensors; - std::vector> e_host_tensors; - std::vector> e_device_tensors; - - a0_tensors.reserve(group_count); - a1_tensors.reserve(group_count); - b_tensors.reserve(group_count); - d0_tensors.reserve(group_count); - e_host_tensors.reserve(group_count); - e_device_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, - d0_tensors_device, c_tensors_device; - - a0_tensors_device.reserve(group_count); - a1_tensors_device.reserve(group_count); - b_tensors_device.reserve(group_count); - d0_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - sum_of_m += problem_size.Ms[i]; - a0_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); - a1_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); - d0_tensors.push_back(Tensor( - f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); - e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc - << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc - << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + - sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + - sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + - sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + - sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - case 2: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - - d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - constexpr ck::index_t NumATensor = 2; - constexpr ck::index_t NumBTensor = 1; - constexpr ck::index_t NumDTensor = 1; - - using GroupedGemmKernelArgument = ck::tensor_operation::device:: - GroupedGemmMultiABDKernelArgument; - - std::vector grouped_gemm_kernel_args_; - grouped_gemm_kernel_args_.reserve(group_count); - - for(int i = 0; i < group_count; i++) - { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - - a1_tensors_device.emplace_back(std::make_unique( - sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - - b_tensors_device.emplace_back(std::make_unique( - sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - - d0_tensors_device.emplace_back( - std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); - - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); - c_tensors_device[i]->SetZero(); - - gemm_descs.push_back({sum_of_m, - problem_size.Ns[i], - problem_size.Ks[i], - {1, 1}, - {problem_size.stride_Bs[i]}, - {0}, - 1}); - - grouped_gemm_kernel_args_.push_back( - {std::array{a0_tensors_device[i]->GetDeviceBuffer(), - a1_tensors_device[i]->GetDeviceBuffer()}, - std::array{b_tensors_device[i]->GetDeviceBuffer()}, - std::array{d0_tensors_device[i]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - std::array{problem_size.stride_As[i], - problem_size.stride_As[i]}, - std::array{problem_size.stride_Bs[i]}, - std::array{0}, - problem_size.stride_Cs[i]}); - } - - constexpr float scale = 1.f; - auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - std::vector> p_As = {}; - std::vector> p_Bs = {}; - std::vector> p_Ds = {}; - std::vector p_Cs = {}; - - // do GEMM - auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); - gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); - - DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), - grouped_gemm_kernel_args_.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); - gemm.SetKBatch(argument, config.k_batch); - - gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); - - invoker.Run(&argument, StreamConfig{nullptr, false}); - - if(config.time_kernel) - { - float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - for(int m = 0; m < problem_size.Ms[i]; ++m) - { - for(int k = 0; k < problem_size.Ks[i]; ++k) - { - a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); - } - } - - c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), - e_device_tensors[i].mDesc.GetElementSize() * - sizeof(EDataType)); - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], - b_tensors[i], - e_host_tensors[i], - PassThrough{}, - b_element_op, - PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < problem_size.Ms[i]; ++m) - { - for(int n = 0; n < problem_size.Ns[i]; ++n) - { - cde_element_op( - e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); - } - } - - pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); - } - } - - return pass; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - problem_size.group_count = 16; - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(32 + rand() % 32); - problem_size.Ns.push_back(64); - problem_size.Ks.push_back(64); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - } - - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - - return !run_grouped_gemm(problem_size, config); -} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index dfb20777bc..28b3fa9213 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -20,8 +20,6 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/host_utility/hip_check_error.hpp" - using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -222,8 +220,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -234,12 +232,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), + b0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), + b1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B1DataType)); - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); - b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -391,7 +398,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 82c2e17308..032842b9eb 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -20,8 +20,6 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/host_utility/hip_check_error.hpp" - using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -49,9 +47,9 @@ using B0DataType = F16; using BsDataType = ck::Tuple; using AccDataType = F32; using CShuffleDataType = F32; -using D0DataType = F16; +using D0DataType = F32; using DsDataType = ck::Tuple; -using EDataType = F16; +using EDataType = F32; using A0Layout = Row; using A1Layout = Row; @@ -212,11 +210,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); - a1_tensors_device.emplace_back(std::make_unique( - sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + a1_tensors_device.emplace_back( + std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); b_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -224,12 +222,19 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), + a1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A1DataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -389,7 +394,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp deleted file mode 100644 index 10e604de60..0000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ /dev/null @@ -1,899 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/utility/env.hpp" -#include "ck/host_utility/hip_check_error.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const index_t grid_size_grp, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) -{ -#if defined(__gfx11__) || defined(__gfx12__) - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; - - const index_t KBatch = 1; - - const index_t block_id = get_block_1d_id(); - - const auto gemm_desc_ptr = - reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - auto karg = gemm_desc_ptr[group_id]; - - if(karg.M == 0 || karg.N == 0 || karg.K == 0) - return; - -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) -#endif - { - - typename GridwiseGemm::Problem problem(karg.M, - karg.N, - karg.K, - karg.StrideAs, - karg.StrideBs, - karg.StrideDs, - karg.StrideE, - KBatch); - - const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); - constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); - constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); - - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t; - p_as_grid_(i) = static_cast(karg.p_as_grid[i]); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t; - p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t; - p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - while(id_local < local_grid_size) - { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_as_grid_, - p_bs_grid_, - p_ds_grid_, - static_cast(karg.p_e_grid), - p_shared, - problem, - block_2_etile_map, - a_element_op, - b_element_op, - cde_element_op, - epilogue_args); - - id_off += grid_size_grp; - id_local += grid_size_grp; - } - } -#else - ignore = gemm_descs_const; - ignore = group_count; - ignore = grid_size_grp; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; -#endif -} - -template -struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK - : public DeviceGroupedGemmMultiABDFixedNK -{ - using DeviceOp = DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK; - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - // Note: Pass multiple layout but then using only the first one - // This is to replicate xdl functionality but it should be extended - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< - ALayout, - BLayout, - DsLayout, - ELayout, - AsDataType, - BsDataType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - GemmSpec, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsExtraN, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - typename uniform_sequence_gen::type, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB, - false, - false>; - - // TODO: Block to tile mappings could potentially moved out to avoid code duplications between - // different device implementations. - - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); - - index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; - - static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1 - using KernelArgument = typename GridwiseGemm::Argument; - - using GemmTransKernelArg = - GroupedGemmMultiABDKernelArgument; - - static constexpr bool CalculateHasMainKBlockLoop(const GemmTransKernelArg& karg, - index_t k_batch) - { - index_t k_grain = k_batch * KPerBlock; - index_t K_split = (karg.K + k_grain - 1) / k_batch; - return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - } - - // Argument - struct Argument : public BaseArgument - { - - Argument(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector& gemm_descs, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) - : Argument(p_As, - p_Bs, - p_Ds, - p_Es, - gemm_descs, - a_element_op, - b_element_op, - c_element_op, - DefaultKBatch) - { - // TODO: use occupancy api to calculate appropriate batch size. - } - - // Client is expected to manually copy the kernel arguments to the device therefore there is - // no point in setting tensor device pointers for the argument structure. - Argument(std::vector>&, - std::vector>&, - std::vector>&, - std::vector&, - std::vector& gemm_descs, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op, - index_t kbatch) - : group_count_{ck::type_convert(gemm_descs.size())}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - grouped_gemm_kernel_args_dev{nullptr}, - gemm_kernel_host_args_{nullptr}, - grid_size_{0}, - k_batch_{kbatch} - { - gemm_desc_kernel_arg_.reserve(group_count_); - - index_t group_id = 0; - - sum_of_m = gemm_descs[0].M_; - const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); - const index_t fixed_N = gemm_descs[0].N_; - const index_t fixed_K = gemm_descs[0].K_; - - for(std::size_t g = 0; g < gemm_descs.size(); g++) - { - const index_t M = gemm_descs[g].M_; - const index_t N = gemm_descs[g].N_; - const index_t K = gemm_descs[g].K_; - - if(M != sum_of_m || N != fixed_N || K != fixed_K) - { - throw std::runtime_error("wrong! M/N/K is not identical"); - } - - a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); - b_mtx_nraw_kraw_.emplace_back(N, K); - - // pointer - std::array p_as_grid; - std::array p_bs_grid; - std::array p_ds_grid; - - static_for<0, NumATensor, 1>{}([&](auto i) { p_as_grid[i] = nullptr; }); - static_for<0, NumBTensor, 1>{}([&](auto i) { p_bs_grid[i] = nullptr; }); - static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid[i] = nullptr; }); - - std::array StrideAs; - std::array StrideBs; - std::array StrideDs; - - const index_t StrideE = gemm_descs[g].stride_C_; - - if(gemm_descs[g].stride_As_.size() != NumATensor) - { - throw std::runtime_error( - "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); - } - - static_for<0, NumATensor, 1>{}( - [&](auto j) { StrideAs[j] = gemm_descs[g].stride_As_[j]; }); - - if(gemm_descs[g].stride_Bs_.size() != NumBTensor) - { - throw std::runtime_error( - "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); - } - - static_for<0, NumBTensor, 1>{}( - [&](auto j) { StrideBs[j] = gemm_descs[g].stride_Bs_[j]; }); - - if(gemm_descs[g].stride_Ds_.size() != NumDTensor) - { - throw std::runtime_error( - "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); - } - - static_for<0, NumDTensor, 1>{}( - [&](auto j) { StrideDs[j] = gemm_descs[g].stride_Ds_[j]; }); - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeDEGridDescriptor_M_N( - AverM, AverM, N, N, StrideE); - - // block-to-e-tile map - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; - - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - - if(group_id * grid_size_grp_ != grid_size_) - { - throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); - } - - const index_t block_start = grid_size_; - - grid_size_ += grid_size_grp_; - - if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) - { - throw std::runtime_error("wrong! block_2_etile_map validation failed"); - } - - auto grouped_block_2_ctile_map = - GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); - - auto karg = GemmTransKernelArg({p_as_grid, - p_bs_grid, - p_ds_grid, - nullptr, - AverM, - N, - K, - StrideAs, - StrideBs, - StrideDs, - StrideE}); - - gemm_desc_kernel_arg_.emplace_back(std::move(karg)); - - group_id++; - } - } - - void UpdateKBatch(index_t) {} - - index_t group_count_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation c_element_op_; - - std::vector gemm_desc_kernel_arg_; - std::vector> a_mtx_mraw_kraw_; - std::vector> b_mtx_nraw_kraw_; - - const void* grouped_gemm_kernel_args_dev; - void* gemm_kernel_host_args_; - index_t grid_size_; - index_t grid_size_grp_; - index_t sum_of_m; - - index_t k_batch_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(arg.grouped_gemm_kernel_args_dev == nullptr) - { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); - } - - if(arg.k_batch_ != 1) - { - throw std::runtime_error("Split K functionality is not supported for wmma multi " - "abd fixed nk implementation."); - } - - float ave_time = 0; - - auto launch_kernel = [&](auto e_global_memory_operation_) { - const auto kernel = kernel_grouped_gemm_wmma_fixed_nk; - - return launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); - }; - - constexpr auto Set = InMemoryDataOperationEnum::Set; - ave_time = launch_kernel(integral_constant{}); - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return RunImp(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) - { - return false; - } - - bool supported = true; - - // If we use padding we do not support vector loads for dimensions not divisible by - // vector load size. - if constexpr(GemmSpec != GemmSpecialization::Default) - { - // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, - // thus we have to adapt it to the {M,K} or {N,K} layout. - const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; - const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; - - for(index_t i = 0; i < arg.group_count_; ++i) - { - const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); - const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); - - supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); - supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); - } - } - - for(index_t i = 0; i < arg.group_count_; i++) - { - if(CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i], arg.k_batch_) != true) - { - supported = false; - } - } - - return supported; - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector gemm_descs, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) - { - return Argument{ - p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector& gemm_descs, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override - { - return std::make_unique( - p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGroupedGemm_Wmma_Fixed_Nk" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 << ", " - << MPerWmma << ", " - << NPerWmma << ", " - << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector << ", " - << CShuffleMRepeatPerShuffle << ", " - << CShuffleNRepeatPerShuffle << ", " - << getGemmSpecializationString(GemmSpec) - << ">"; - // clang-format on - - return str.str(); - } - - static void SetElementwiseOps(Argument& arg, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) - { - arg.a_element_op_ = a_element_op; - arg.b_element_op_ = b_element_op; - arg.c_element_op_ = c_element_op; - } - - // polymorphic - void SetElementwiseOps(BaseArgument* p_arg, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) const override - { - - SetElementwiseOps( - *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); - } - - static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) - { - arg.grouped_gemm_kernel_args_dev = kernel_args; - } - - // polymorphic - void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override - { - return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); - } - - size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override - { - auto arg = *dynamic_cast(p_arg); - - return arg.group_count_ * - sizeof(GroupedGemmMultiABDKernelArgument); - } - - size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override - { - auto p_arg_ = dynamic_cast(p_arg); - if(p_arg_) - { - return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); - } - else - throw std::runtime_error( - "The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK::Argument structure!"); - } - - void SetWorkSpacePointer(BaseArgument* p_arg, - void* p_workspace, - const StreamConfig& stream_config = StreamConfig{}) const override - { - auto p_arg_ = dynamic_cast(p_arg); - p_arg_->p_workspace_ = p_workspace; - - hip_check_error( - hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); - } - - static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } - - // polymorphic - void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override - { - return SetKBatch(*dynamic_cast(p_arg), k_batch); - } - - void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const - { - Argument* pArg_ = dynamic_cast(p_arg); - if(!pArg_) - { - throw std::runtime_error("Failed to cast argument pointer!"); - } - - pArg_->gemm_kernel_host_args_ = p_host_kernel_args; - std::copy(pArg_->gemm_desc_kernel_arg_.begin(), - pArg_->gemm_desc_kernel_arg_.end(), - static_cast(pArg_->gemm_kernel_host_args_)); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 36e66017c6..fb4e01b961 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -605,7 +605,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK if(arg.grouped_gemm_kernel_args_dev == nullptr) { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); } float ave_time = 0; @@ -688,11 +688,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) - { - return false; - } - // Split-K autodeduction is not supported if(arg.k_batch_ < 1) { @@ -725,26 +720,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } } - for(index_t i = 0; i < arg.group_count_; i++) - { - if(get_warp_size() == 64) - { - if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != - true) - { - supported = false; - } - } - else - { - if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != - true) - { - supported = false; - } - } - } - return supported; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 311a1c0bf4..7653724b21 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -696,7 +696,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK& ty); } -template -auto concat_tuple_of_reference(ck::Tuple& tx, ck::Tuple& ty) -{ - return ck::unpack2( - [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, - tx, - ty); -} - template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuple& ty) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp deleted file mode 100644 index 2d766e621b..0000000000 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/device_base.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/utility/functional4.hpp" -#include "ck/utility/tuple_helper.hpp" - -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -namespace ck { -namespace tensor_operation { -namespace host { - -template -struct ReferenceGemmMultiABD : public device::BaseOperator -{ - // Argument - struct Argument : public device::BaseArgument - { - Argument(const AsTensorTuple& as_m_k, - const BsTensorTuple& bs_k_n, - const DsTensorTuple& ds_m_n, - Tensor& e_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : as_m_k_{as_m_k}, - bs_k_n_{bs_k_n}, - ds_m_n_{ds_m_n}, - e_m_n_{e_m_n}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const AsTensorTuple& as_m_k_; - const BsTensorTuple& bs_k_n_; - const DsTensorTuple& ds_m_n_; - Tensor& e_m_n_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public device::BaseInvoker - { - using Argument = ReferenceGemmMultiABD::Argument; - - float Run(const Argument& arg) - { - static constexpr index_t NumATensor = AsTensorTuple::Size(); - static constexpr index_t NumBTensor = BsTensorTuple::Size(); - static constexpr index_t NumDTensor = DsTensorTuple::Size(); - - const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0]; - const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1]; - const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1]; - - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = generate_tie( - [&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, - Number{}); - auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); - unpack(arg.a_element_op_, data_refs); - } - } - - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = generate_tie( - [&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, - Number{}); - auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); - unpack(arg.b_element_op_, data_refs); - } - } - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - // compulsory - auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n)); - // optional (if multiple Ds) - auto data_refs2 = generate_tie( - [&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, - Number{}); - auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); - unpack(arg.cde_element_op_, data_refs); - } - } - - return 0; - } - - float Run(const device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - - static auto MakeArgument(const AsTensorTuple& as_m_k, - const BsTensorTuple& bs_k_n, - const DsTensorTuple& ds_m_n, - Tensor& e_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceGemmMultiABD" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -} // namespace host -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 0879bea4ea..6d97ec3a05 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -10,6 +10,7 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" namespace ck { namespace tensor_operation { @@ -20,7 +21,6 @@ using Multiply = ck::tensor_operation::element_wise::Multiply; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -#if defined(CK_USE_XDL) // RRR void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( std::vector, @@ -179,167 +179,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan PassThrough, Multiply, PassThrough>>>& instances); -#endif - -#if defined(CK_USE_WMMA) -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances); - -// RCR -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances); - -// CRR -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances); -#endif // CK_USE // GEMM + Add + Gelu template > op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -408,38 +246,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } @@ -483,7 +289,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -512,38 +317,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } @@ -587,7 +360,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -616,38 +388,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } @@ -691,7 +431,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -720,38 +459,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index fc60f48727..9d9a0e691c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -1,17 +1,13 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_AND_WMMA_KERNELS +# ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp - - device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp - device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp deleted file mode 100644 index a29f8513d8..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = Sequence; - -using BF16 = bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = tensor_layout::gemm::RowMajor; -using Col = tensor_layout::gemm::ColumnMajor; - -using Multiply = element_wise::Multiply; -using PassThrough = element_wise::PassThrough; -using AddFastGelu = element_wise::AddFastGelu; -using Add = element_wise::Add; -using FastGelu = element_wise::FastGelu; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple< - // clang-format off - //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple, - Tuple, - AddFastGelu, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple, - Tuple, - Add, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple<>, - Tuple<>, - PassThrough, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple<>, - Tuple<>, - FastGelu, - GemmMNKPadding>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp deleted file mode 100644 index 2eaaaf009a..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = Sequence; - -using BF16 = bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = tensor_layout::gemm::RowMajor; -using Col = tensor_layout::gemm::ColumnMajor; - -using Multiply = element_wise::Multiply; -using PassThrough = element_wise::PassThrough; -using AddFastGelu = element_wise::AddFastGelu; -using Add = element_wise::Add; -using FastGelu = element_wise::FastGelu; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple, - Tuple, - AddFastGelu, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple, - Tuple, - Add, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple<>, - Tuple<>, - PassThrough, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple<>, - Tuple<>, - FastGelu, - GemmMNKPadding>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp deleted file mode 100644 index 3320b4afa6..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = Sequence; - -using BF16 = bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = tensor_layout::gemm::RowMajor; -using Col = tensor_layout::gemm::ColumnMajor; - -using Multiply = element_wise::Multiply; -using PassThrough = element_wise::PassThrough; -using AddFastGelu = element_wise::AddFastGelu; -using Add = element_wise::Add; -using FastGelu = element_wise::FastGelu; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< - // clang-format off - //######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | - //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple, - Tuple, - AddFastGelu, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple, - Tuple, - Add, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple<>, - Tuple<>, - PassThrough, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple<>, - Tuple<>, - FastGelu, - GemmMNKPadding>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index 6e72d379d0..23e3b7f511 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that -// portion of the instances are failing. As a workaround these have been commented out. template , S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 5eedb8b5ee..0560f159fc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that -// portion of the instances are failing. As a workaround these have been commented out. template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 7d1fcb5552..95365c82e7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that -// portion of the instances are failing. As a workaround these have been commented out. template +auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + template c_m_n({M, N}); + using AComputeType = typename std::conditional<(NumATensor > 1), EDataType, remove_cvref_t>>::type; + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(a_element_op, data_refs); + } + } + using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultiABD; + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(b_element_op, data_refs); + } + } - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = ref_gemm.MakeArgument( - as_m_k, bs_k_n, ds_m_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(cde_element_op, data_refs); + } + } } std::array as_device_buf; diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp deleted file mode 100644 index eea72b324d..0000000000 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ /dev/null @@ -1,534 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/env.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/convolution_parameter.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/utility/fill.hpp" - -namespace ck { -namespace profiler { - -template -auto reserveVector(std::size_t size) -{ - std::vector vec; - vec.reserve(size); - return vec; -} - -template -bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideDs, - const std::vector& StrideE, - const std::vector& kbatch_list = {1}, - int n_warmup = 1, - int n_iter = 10) -{ - bool pass = true; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - const std::size_t group_count = Ms.size(); - const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0); - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - if(group_count != Ns.size() || group_count != Ks.size() || group_count != StrideAs.size() || - group_count != StrideBs.size() || (NumDTensor > 0 && group_count != StrideDs.size())) - { - throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideAs/Bs/Ds/E size\n"); - } - - auto generateInputTupleA = [&](std::size_t g) { - if constexpr(NumATensor == 0) - { - static_assert("Gemm problem should have at least 1 A tensor."); - } - else - { - using ALayout = remove_cvref_t{}, AsLayout>>; - return generate_tuple( - [&](auto i) { - using ADataType = remove_cvref_t>; - return Tensor( - f_host_tensor_descriptor(Ms[g], Ks[g], StrideAs[g], ALayout{})); - }, - Number{}); - } - }; - auto generateInputTupleB = [&](std::size_t g) { - if constexpr(NumBTensor == 0) - { - static_assert("Gemm problem should have at least 1 B tensor."); - } - else - { - using BLayout = remove_cvref_t{}, BsLayout>>; - return generate_tuple( - [&](auto i) { - using BDataType = remove_cvref_t>; - return Tensor( - f_host_tensor_descriptor(Ks[g], Ns[g], StrideBs[g], BLayout{})); - }, - Number{}); - } - }; - auto generateInputTupleD = [&](std::size_t g) { - if constexpr(NumDTensor == 0) - { - return ck::Tuple<>(); - } - else - { - using DLayout = remove_cvref_t{}, DsLayout>>; - return generate_tuple( - [&](auto i) { - using DDataType = remove_cvref_t>; - return Tensor( - f_host_tensor_descriptor(Ms[g], Ns[g], StrideDs[g], DLayout{})); - }, - Number{}); - } - }; - - using AsTensorTuple = decltype(generateInputTupleA(0)); - using BsTensorTuple = decltype(generateInputTupleB(0)); - using DsTensorTuple = decltype(generateInputTupleD(0)); - - auto g_as_m_k = reserveVector(group_count); - auto g_bs_k_n = reserveVector(group_count); - auto g_ds_m_n = reserveVector(group_count); - auto g_e_m_n_host_results = reserveVector>(group_count); - auto g_e_m_n_device_results = reserveVector>(group_count); - - for(std::size_t g = 0; g < group_count; g++) - { - auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); - auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); - auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); - - g_e_m_n_host_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); - g_e_m_n_device_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); - - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "group: " << g << std::endl; - static_for<0, NumATensor, 1>{}([&](auto i) { - std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; - }); - static_for<0, NumBTensor, 1>{}([&](auto i) { - std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; - }); - static_for<0, NumDTensor, 1>{}([&](auto i) { - std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; - }); - std::cout << "e_m_n: " << g_e_m_n_device_results[g].mDesc << std::endl; - } - - std::size_t num_thread = 1; - switch(init_method) - { - case 0: break; - case 1: - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - }); - - break; - default: - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - }); - } - } - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto cde_element_op = CDEElementOp{}; - - using DeviceMemPtr = std::unique_ptr; - std::vector> g_as_device_buf(group_count); - std::vector> g_bs_device_buf(group_count); - std::vector> g_ds_device_buf(group_count); - std::vector g_e_device_buf(group_count); - - std::vector> g_as_device_view(group_count); - std::vector> g_bs_device_view(group_count); - std::vector> g_ds_device_view(group_count); - std::vector g_e_device_view(group_count); - - auto g_gemm_descs = reserveVector(group_count); - - auto grouped_gemm_kernel_args_host = - reserveVector>( - group_count); - - for(std::size_t g = 0; g < group_count; g++) - { - std::array as_stride; - std::array bs_stride; - std::array ds_stride; - - auto& as_m_k = g_as_m_k[g]; - auto& as_device_buf = g_as_device_buf[g]; - auto& as_device_view = g_as_device_view[g]; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - as_device_buf[i] = std::make_unique(sizeof(ADataType) * Ms[g] * Ks[g]); - as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); - as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); - as_stride[i] = StrideAs[g]; - }); - - auto& bs_k_n = g_bs_k_n[g]; - auto& bs_device_buf = g_bs_device_buf[g]; - auto& bs_device_view = g_bs_device_view[g]; - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - bs_device_buf[i] = std::make_unique(sizeof(BDataType) * Ks[g] * Ns[g]); - bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data()); - bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer(); - bs_stride[i] = StrideBs[g]; - }); - - auto& ds_m_n = g_ds_m_n[g]; - auto& ds_device_buf = g_ds_device_buf[g]; - auto& ds_device_view = g_ds_device_view[g]; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - ds_device_buf[i] = std::make_unique(sizeof(DDataType) * Ms[g] * Ns[g]); - ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); - ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); - ds_stride[i] = StrideDs[g]; - }); - - g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); - g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); - - g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ - sum_of_m, - Ns[g], - Ks[g], - std::vector(as_stride.begin(), as_stride.end()), - std::vector(bs_stride.begin(), bs_stride.end()), - std::vector(ds_stride.begin(), ds_stride.end()), - StrideE[g]}); - - tensor_operation::device:: - GroupedGemmMultiABDKernelArgument - kernelArg{as_device_view, - bs_device_view, - ds_device_view, - g_e_device_view[g], - Ms[g], - Ns[g], - Ks[g], - as_stride, - bs_stride, - ds_stride, - StrideE[g]}; - - grouped_gemm_kernel_args_host.push_back(std::move(kernelArg)); - } - - using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; - - const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - if(op_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); - } - - std::string best_gemm_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - float best_kbatch = 0; - - if(do_verification) - { - using AComputeType = - typename std::conditional<(NumATensor > 1), - EDataType, - remove_cvref_t>>::type; - - using BComputeType = - typename std::conditional<(NumBTensor > 1), - EDataType, - remove_cvref_t>>::type; - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultiABD; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - for(std::size_t i = 0; i < group_count; i++) - { - auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i], - g_bs_k_n[i], - g_ds_m_n[i], - g_e_m_n_host_results[i], - a_element_op, - b_element_op, - cde_element_op); - - ref_invoker.Run(ref_argument); - } - } - - // profile device GEMM instances - for(auto& gemm_ptr : op_ptrs) - { - auto argument_ptr = gemm_ptr->MakeArgumentPointer( - g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs); - - if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Gemm incompatible with runtime set parameters. Skipping..." - << std::endl; - } - - continue; - } - - DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); - gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer()); - - DeviceMem grouped_gemm_kernel_args_dev( - gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); - hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), - grouped_gemm_kernel_args_host.data(), - gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), - hipMemcpyHostToDevice)); - - gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), - grouped_gemm_kernel_args_dev.GetDeviceBuffer()); - gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - std::string gemm_name = gemm_ptr->GetTypeString(); - - for(const auto kbatch_curr : kbatch_list) - { - gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - for(std::size_t g = 0; g < group_count; g++) - { - g_e_device_buf[g]->SetZero(); - } - - float ave_time = invoker_ptr->Run( - argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - - if(do_verification) - { - bool instance_pass = true; - for(std::size_t g = 0; g < group_count; g++) - { - g_e_device_buf[g]->FromDevice( - g_e_m_n_device_results[g].mData.data(), - g_e_m_n_device_results[g].mDesc.GetElementSize() * sizeof(EDataType)); - - instance_pass = - instance_pass && ck::utils::check_err(g_e_m_n_device_results[g], - g_e_m_n_host_results[g]); - - if(do_log) - { - static_for<0, NumATensor, 1>{}([&](auto i) { - LogRangeAsType( - std::cout << "a[" << g << "]: ", g_as_m_k[g](i).mData, ",") - << std::endl; - }); - static_for<0, NumBTensor, 1>{}([&](auto i) { - LogRangeAsType( - std::cout << "b[" << g << "]: ", g_bs_k_n[g](i).mData, ",") - << std::endl; - }); - static_for<0, NumDTensor, 1>{}([&](auto i) { - LogRangeAsType( - std::cout << "d[" << g << "]: ", g_ds_m_n[g](i).mData, ",") - << std::endl; - }); - LogRangeAsType( - std::cout << "e_device: ", g_e_m_n_device_results[g].mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "e_host : ", g_e_m_n_host_results[g].mData, ",") - << std::endl; - } - } - - std::cout << "Instance: " << gemm_name << " verification " - << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; - - pass = pass && instance_pass; - } - - if(time_kernel) - { - std::size_t flop = 0, num_btype = 0; - for(std::size_t g = 0; g < group_count; g++) - { - flop += std::size_t(2) * Ms[g] * Ns[g] * Ks[g]; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - num_btype += sizeof(ADataType) * Ms[g] * Ks[g]; - }); - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - num_btype += sizeof(BDataType) * Ks[g] * Ns[g]; - }); - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - num_btype += sizeof(DDataType) * Ms[g] * Ns[g]; - }); - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops - << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " - << kbatch_curr << std::endl; - - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - best_kbatch = kbatch_curr; - } - } - } - else - { - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" - << std::endl; - } - } - } - - if(time_kernel) - { - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch - << std::endl; - } - return pass; -} - -} // namespace profiler -} // namespace ck diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index bc79c85e59..450950cbd6 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,12 +18,6 @@ if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() - - add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp) - if(result EQUAL 0) - target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance) - add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk) - endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp deleted file mode 100644 index 610e7f2b77..0000000000 --- a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include - -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/utility/data_type.hpp" - -#include "ck/ck.hpp" -#include "ck/utility/type.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" - -#include "gtest/gtest.h" - -static ck::index_t param_mask = 0xffffff; -static ck::index_t instance_index = -1; - -using FP32 = float; -using FP16 = ck::half_t; -using BF16 = ck::bhalf_t; -using I8 = int8_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using Add = ck::tensor_operation::element_wise::Add; -using Multiply = ck::tensor_operation::element_wise::Multiply; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; - -// clang-format off -using KernelTypes = ::testing::Types< - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu> ->; -// clang-format on - -template -class TestGroupedGemmMultiABDFixedNK : public testing::Test -{ - protected: - using AsDataType = std::tuple_element_t<0, Tuple>; - using BsDataType = std::tuple_element_t<1, Tuple>; - using DsDataType = std::tuple_element_t<2, Tuple>; - using EDataType = std::tuple_element_t<3, Tuple>; - using AccDataType = float; - using AsLayout = std::tuple_element_t<4, Tuple>; - using BsLayout = std::tuple_element_t<5, Tuple>; - using DsLayout = std::tuple_element_t<6, Tuple>; - using ELayout = std::tuple_element_t<7, Tuple>; - using AElementOp = PassThrough; - using BElementOp = Multiply; - using CDEElementOp = std::tuple_element_t<8, Tuple>; - - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - public: - static constexpr bool verify_ = true; - static constexpr int init_method_ = 1; // integer value initialization - static constexpr bool log_ = false; - static constexpr bool bench_ = false; // measure kernel performance - static constexpr int n_warmup_ = 0; - static constexpr int n_iter_ = 1; - - std::vector k_batches_ = {1}; - - private: - template - void SetStrides(std::vector& strides, - const std::vector& rows, - const std::vector& cols) const - { - if(std::is_same_v) - { - for(const auto c : cols) - { - strides.emplace_back(c); - } - } - else if(std::is_same_v) - { - for(const auto r : rows) - { - strides.emplace_back(r); - } - } - } - - template - void SetTupleStrides(std::vector& strides, - const std::vector& rows, - const std::vector& cols) const - { - if constexpr(Layouts::Size() > 0) - { - // As of now multi ABD implementation supports only tensors with matching layouts. - using Layout = ck::remove_cvref_t{}, Layouts>>; - SetStrides(strides, rows, cols); - } - } - - public: - void Run(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs = {}, - const std::vector& StrideBs = {}, - const std::vector& StrideDs = {}, - const std::vector& StrideE = {}) - { - std::vector stride_as = StrideAs; - std::vector stride_bs = StrideBs; - std::vector stride_ds = StrideDs; - std::vector stride_e = StrideE; - - if(stride_as.empty()) - { - SetTupleStrides(stride_as, Ms, Ks); - } - if(stride_bs.empty()) - { - SetTupleStrides(stride_bs, Ks, Ns); - } - if(stride_ds.empty()) - { - SetTupleStrides(stride_ds, Ms, Ns); - } - if(stride_e.empty()) - { - SetStrides(stride_e, Ms, Ns); - } - - RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e); - } - - void RunSingle(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideDs, - const std::vector& StrideE) - { - bool pass = - ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideDs, - StrideE, - k_batches_, - n_warmup_, - n_iter_); - EXPECT_TRUE(pass); - } -}; - -TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes); - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) -{ - const std::vector Ms{3, 4}; - constexpr int N = 8; - constexpr int K = 64; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) -{ - const std::vector Ms{3, 5, 16, 7, 8}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) -{ - const std::vector Ms{167, 183, 177, 153, 139, 204}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular) -{ - const std::vector Ms{64, 128, 256}; - constexpr int N = 768; - constexpr int K = 320; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) - { - // Run with default arguments. - } - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -}