From cfac9497e28a7489d5cde5bf2b4f40691dd5659c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:18:05 -0700 Subject: [PATCH 1/8] remove gfx12 targets from daily builds with rocm6.2 (#1560) --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index e61fb71e8e..a79ed859f2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1138,7 +1138,7 @@ pipeline { execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \ + -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ } steps{ From 2e1165c1a73552dbacf08ccd351314ae95de14f7 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:21:57 -0700 Subject: [PATCH 2/8] fix the target selection logic (#1561) --- CMakeLists.txt | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ad6307cb3..3f22bb4b61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,7 +132,11 @@ if(GPU_ARCHS) unset(GPU_TARGETS CACHE) unset(AMDGPU_TARGETS CACHE) endif() - +if(GPU_TARGETS) + set(USER_GPU_TARGETS 1) +else() + set(USER_GPU_TARGETS 0) +endif() find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 @@ -162,7 +166,7 @@ endif() if(GPU_ARCHS) set(CK_GPU_TARGETS ${GPU_ARCHS}) else() - if(GPU_TARGETS) + if(USER_GPU_TARGETS) set(CK_GPU_TARGETS ${GPU_TARGETS}) endif() endif() From 6f27bc987248633255cc400437bd017dca70cf1e Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 10 Oct 2024 03:02:22 -0700 Subject: [PATCH 3/8] Ck tile gemm cshuffle & CK Tile GEMM restructure 
(#1535) * Make the cshuffle compilable * modify the reference on gpu and cpu. Correct access of cshuffle * fix the cpu reference code * Complete the in tile shuffle logic * restructure the kernel template input * change the naming pattern of ck_tile gemm pipeline * Re-format files using remod.py * Solve the fmha conflict with gemm * Comment Addressed from Carlus --------- Co-authored-by: Po Yen, Chen --- example/ck_tile/03_gemm/gemm_basic.cpp | 55 ++++-- .../ck_tile/core/container/thread_buffer.hpp | 2 +- .../ck_tile/host/reference/reference_gemm.hpp | 47 ++++- include/ck_tile/ops/epilogue.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 171 ++++++++++++++++++ ...block_fmha_bwd_pipeline_default_policy.hpp | 133 ++++++++------ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 81 +++++---- include/ck_tile/ops/gemm.hpp | 11 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 15 +- ... => gemm_pipeline_agmem_bgmem_creg_v1.hpp} | 10 +- ...ne_agmem_bgmem_creg_v1_default_policy.hpp} | 4 +- ... => gemm_pipeline_agmem_bgmem_creg_v2.hpp} | 6 +- ...ne_agmem_bgmem_creg_v2_default_policy.hpp} | 9 +- ..._problem.hpp => gemm_pipeline_problem.hpp} | 17 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 27 +++ 15 files changed, 447 insertions(+), 142 deletions(-) create mode 100644 include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v1.hpp => gemm_pipeline_agmem_bgmem_creg_v1.hpp} (95%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp => gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp} (99%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v2.hpp => gemm_pipeline_agmem_bgmem_creg_v2.hpp} (97%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp => gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp} (57%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_problem.hpp =>
gemm_pipeline_problem.hpp} (65%) create mode 100644 include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 9f790f6acb..e3c8d72590 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -41,18 +41,39 @@ template ; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + + // The rank and permutation will also be generate out by the CodeGen part. + constexpr ck_tile::index_t kOutputRank = 2; + + // Whether doing the CShuffle (transpose before the global memory), depending on the output + // layout. + constexpr bool CShuffleEpilogue = + std::is_same_v; + + using GemmEpilogue = std::conditional_t< + CShuffleEpilogue, + ck_tile::CShuffleEpilogue>, + ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
- using Kernel = - ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKargs(args.p_a, args.p_b, @@ -255,15 +276,13 @@ int main(int argc, char* argv[]) ck_tile::sequence, ck_tile::sequence>; - using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem; + using CodegenGemmTraits = ck_tile:: + TileGemmTraits; - using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; invoke_gemm c_host_gpu_ref(c_dimensions); ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); - ck_tile::reference_gemm_gpu( + ck_tile::reference_gemm_gpu( a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c); c_buf.FromDevice(c_host_gpu_ref.data()); diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index a7dad5233b..279a48acb3 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -58,7 +58,7 @@ struct thread_buffer { template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } - + template ::value, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto _get_as() const diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index a0ddd02d9e..a496c91e00 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const int N = b_n_k.mDesc.get_lengths()[0]; + const int N = (std::is_same_v) + ? 
b_n_k.mDesc.get_lengths()[0] + : b_n_k.mDesc.get_lengths()[1]; const int K = (std::is_same_v) ? a_m_k.mDesc.get_lengths()[1] : a_m_k.mDesc.get_lengths()[0]; @@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, ADataType v_a = (std::is_same_v) ? a_element_op(a_m_k(m, k)) : a_element_op(a_m_k(k, m)); - BDataType v_b = b_element_op(b_n_k(n, k)); + BDataType v_b = (std::is_same_v) + ? b_element_op(b_n_k(n, k)) + : b_element_op(b_n_k(k, n)); v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); } - c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + CDataType& c_ref = (std::is_same_v) + ? c_m_n(m, n) + : c_m_n(n, m); + c_ref = ck_tile::type_convert(acc_element_op(v_acc)); } }; make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); } -template +template __global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C, @@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, if(row < M && col < N) { AccDataType acc = 0.0; - for(int k = 0; k < K; ++k) { - acc += static_cast(A[row * strideA + k]) * - static_cast(B[col * strideB + k]); + // Adjust indexing based on matrix layout + int a_index = (std::is_same_v) + ? row * strideA + k + : k * strideA + row; + int b_index = (std::is_same_v) + ? col * strideB + k + : k * strideB + col; + acc += static_cast(A[a_index]) * static_cast(B[b_index]); } - C[row * strideC + col] = acc; // Store as AccDataType + int c_index = (std::is_same_v) + ? 
row * strideC + col + : col * strideC + row; + C[c_index] = acc; } } -template +template void reference_gemm_gpu(DeviceMem& a_device, DeviceMem& b_device, DeviceMem& c_device, @@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; - naive_gemm_kernel + naive_gemm_kernel <<>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); errC = hipMemcpy( c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 388f52c898..a98f60b364 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -3,5 +3,6 @@ #pragma once +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp new file mode 100644 index 0000000000..9625b137bd --- /dev/null +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +#define CK_TILE_MAX_RANK 5 + +namespace ck_tile { + +// this epilogue aiming to store a matrix with different layout from the shared memory to the global +// memory. 
+template +struct CShuffleEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kTilePermute = kTilePermute_; + static constexpr index_t kRank = kRank_; + static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4}; + static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = { + TileSize0, TileSize1, TileSize2, TileSize3, TileSize4}; +}; + +template +struct CShuffleEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + const index_t* kPerm = Problem::kPerm; + static constexpr bool kTilePermute = Problem::kTilePermute; + static constexpr index_t kRank = Problem::kRank; + const index_t* tile_sizes = Problem::tile_sizes; + + // No additional shared memory needed + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) + { + using DataType = typename OAccTile::DataType; + + // Get thread buffer + auto& thread_buf = o_acc_tile.get_thread_buffer(); + + // Create a temporary buffer to hold the permuted data + thread_buffer permuted_thread_buf; + + // Get the lengths of each dimension + auto thread_tensor_lengths = o_acc_tile.get_lengths(); + + // Total number of elements + index_t total_elements = OAccTile::kThreadElementSpaceSize; + + // Iterate over all elements + for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) + { + // Convert linear index to multi-dimensional indices + array indices; + index_t remaining = linear_idx; + static_for<0, kRank, 1>{}([&](auto i) { + constexpr auto rev_i = kRank - 1 - i; + indices(rev_i) = remaining % thread_tensor_lengths.get(number{}); + remaining /= thread_tensor_lengths.get(number{}); + }); + + 
// Apply the permutation + array permuted_indices; + static_for<0, kRank, 1>{}( + [&](auto i) { permuted_indices(i) = indices.get(number{}); }); + + // Compute offsets + index_t dst_offset = 0; + index_t stride = 1; + + static_for<0, kRank, 1>{}([&](auto i) { + constexpr auto rev_i = kRank - 1 - i; + dst_offset += permuted_indices[rev_i] * stride; + stride *= thread_tensor_lengths.get(number{}); + }); + + // Move the data + permuted_thread_buf(dst_offset) = thread_buf[linear_idx]; + } + + // Copy the permuted data back to the original thread buffer + for(index_t i = 0; i < total_elements; ++i) + { + thread_buf.set_as(i, permuted_thread_buf.get(i)); + } + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) + { + const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); + + // Compute the tile coordinates by dividing the window origin by the tile sizes + index_t tile_coords[CK_TILE_MAX_RANK] = {0}; + for(index_t i = 0; i < kRank; ++i) + { + tile_coords[i] = current_window_origin[i] / tile_sizes[i]; + // printf("The tile_coord is: %d", tile_coords[i]); + } + + // Apply the permutation to the tile coordinates + index_t permuted_tile_coords[CK_TILE_MAX_RANK]; + for(index_t i = 0; i < kRank; ++i) + { + permuted_tile_coords[i] = tile_coords[kPerm[i]]; + // printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]); + } + + // Compute the permuted window origin + index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0}; + for(index_t i = 0; i < kRank; ++i) + { + permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i]; + // printf("The new permuted_window_origin is: %d", permuted_window_origin[i]); + } + + typename ODramWindowTmp::BottomTensorIndex step = {}; + for(index_t i = 0; i < kRank; ++i) + { + step[i] = permuted_window_origin[i] - current_window_origin[i]; + } + + // Move the window + move_tile_window(o_dram_window_tmp, step); + + // Permute the data within the tile if 
necessary + if constexpr(kTilePermute) + { + permute_tile_data(o_acc_tile); + } + + // Store the tile data to the permuted location + if constexpr(kPadM || kPadN) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 8647a7d25a..e1f05d39db 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -5,8 +5,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" @@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::QDataType, @@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename 
Problem::BlockFmhaShape::Gemm0BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::OGradDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::OGradDataType, - typename Problem::VDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm2BlockWarps, - typename Problem::BlockFmhaShape::Gemm2WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::OGradDataType, @@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm2BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::QDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm3BlockWarps, - typename Problem::BlockFmhaShape::Gemm3WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename 
Problem::BlockFmhaShape::Gemm3BlockWarps, + typename Problem::BlockFmhaShape::Gemm3WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::KDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm4BlockWarps, - typename Problem::BlockFmhaShape::Gemm4WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm4BlockWarps, + typename Problem::BlockFmhaShape::Gemm4WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } // these are for global load diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index ae9e320f67..4ea0c4c9f2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -5,8 +5,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" @@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename 
Problem::QDataType, - typename Problem::KDataType, - typename Problem::SaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2{}; } }; @@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::SaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmASmemBSmemCRegV1{}; + return BlockGemmASmemBSmemCRegV1{}; } }; @@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename 
Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + TileGemmTraits>; auto warp_gemm = [&]() { if constexpr(std::is_same_v && @@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2{}; } }; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e9005462b0..dc5983e4d1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,12 +23,13 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e24d7f9ea0..48329c8ba5 100644 --- 
a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -11,20 +11,12 @@ namespace ck_tile { -template +template struct GemmKernel { using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize; using ADataType = remove_cvref_t; @@ -32,6 +24,10 @@ struct GemmKernel using CAccDataType = remove_cvref_t; using CODataType = remove_cvref_t; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) { return TilePartitioner::GridSize(M_size, N_size, Batch_size); @@ -184,6 +180,7 @@ struct GemmKernel c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); + EpiloguePipeline{}(CBlockWindow_pad, acc); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp similarity index 95% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index bec8a204cc..5ed7d036ea 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -4,15 +4,15 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCRegV1 
+template +struct GemmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadC = Problem::kPadC; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { return ck_tile::integer_divide_ceil( diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp similarity index 99% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 3048adad67..8639f00fbb 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -7,9 +7,9 @@ namespace ck_tile { -// Default policy for BlockGemmPipelineAGmemBGmemCRegV1 +// Default policy for GemmPipelineAGmemBGmemCRegV1 // Default policy class should not be templated, put template on member functions instead -struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy +struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { #if 0 // 2d diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp similarity index 97% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index ab5fe79114..bff7fc0a0e 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,15 +4,15 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCRegV2 +template +struct GemmPipelineAGmemBGmemCRegV2 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp similarity index 57% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp index 0596408501..7dad55d6b9 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp @@ -7,12 +7,11 @@ namespace ck_tile { -// Default policy for BlockGemmPipelineAGmemBGmemCRegV2 +// Default policy for GemmPipelineAGmemBGmemCRegV2 // Default policy class should not be templated, put template on member functions instead // NOTE: policy should be binded to its corresponding operation. 
It's just a coincidence that -// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as -// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy -using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy = - BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy; +// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as +// GemmPipelineAGmemBGmemCRegV1DefaultPolicy +using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp similarity index 65% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 8dfba08ad7..d7b3b24a4a 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -13,20 +13,23 @@ template -struct BlockGemmPipelineProblem + typename TileGemmTraits_> +struct GemmPipelineProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using GemmTraits = remove_cvref_t; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadA = kPadA_; - static constexpr bool kPadB = kPadB_; - static constexpr bool kPadC = kPadC_; + static constexpr bool kPadA = GemmTraits::kPadA; + static constexpr bool kPadB = GemmTraits::kPadB; + static constexpr bool kPadC = GemmTraits::kPadC; + + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType); static constexpr index_t AlignmentB = kPadB ? 
1 : VectorLoadSize / sizeof(BDataType); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp new file mode 100644 index 0000000000..98da1510c7 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmTraits +{ + static constexpr bool kPadA = kPadA_; + static constexpr bool kPadB = kPadB_; + static constexpr bool kPadC = kPadC_; + + using LayoutA = LayoutA_; + using LayoutB = LayoutB_; + using LayoutC = LayoutC_; +}; + +} // namespace ck_tile From d18fc0797ff483dee4446e643798be699713d22c Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:37:09 -0500 Subject: [PATCH 4/8] Fix default stride value (#1559) --- example/01_gemm/run_gemm_example_streamk_v2.inc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 6679f95157..32bd3a19a6 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == 0) { - // give a chance if stride is -1, return a default packed stride + // give a chance if stride is 0, return a default packed stride if constexpr(std::is_same_v) { return static_cast(col); From 14c52befdaadc392e93450df6b5501f70c43f34d Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Thu, 10 Oct 2024 16:57:23 -0400 Subject: [PATCH 5/8] removed API usage header (#1566) --- 
docs/reference/API_Reference_Guide.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/reference/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst index 22222b0cf0..0d2d41c1eb 100644 --- a/docs/reference/API_Reference_Guide.rst +++ b/docs/reference/API_Reference_Guide.rst @@ -12,12 +12,6 @@ API reference guide This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design principles that are used to write new classes that extend CK functionality. -================= -Using CK API -================= - -This section describes how to use the CK library API. - ================= CK Datatypes ================= From f46a9eee9dbcf44697b3dad27f0675ca6d877d99 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:31:56 -0700 Subject: [PATCH 6/8] only build tests and examples if user sets GPU_TARGETS (#1565) --- CMakeLists.txt | 2 +- README.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f22bb4b61..cfcfa24b37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -549,7 +549,7 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) add_subdirectory(library) -if(NOT GPU_ARCHS) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name diff --git a/README.md b/README.md index 34ac0919ae..4366ec0329 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets supported by the current compiler (this may take a long time). + Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line. 
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you From 11444e4cf2d158500a10dbd2ace3bbd27cc65776 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 11 Oct 2024 14:29:46 -0700 Subject: [PATCH 7/8] [CI] remove the --rm docker container flags (#1568) --- Jenkinsfile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index a79ed859f2..132257ad80 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE 
--security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -660,7 +660,7 @@ def process_results(Map conf=[:]){ def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group - def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } From 29d384d0b2f266ba8fbf3f7728d2bba4f5a7b852 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:05:11 +0200 Subject: [PATCH 8/8] Implement GetWorkSpaceSize from BaseOperator. (#1564) --- .../gpu/device/device_cgemm.hpp | 6 +++--- .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index 8484212118..44dedeeef9 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "device_base.hpp" @@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator CElementwiseOperation c_element_op, ck::index_t KBatch = 1) = 0; - virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::unique_ptr MakeInvokerPointer() = 0; virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, - index_t StrideC) = 0; + index_t StrideC) const = 0; }; template (base_arg); + + if(!parg) + { + std::ostringstream err; + err << "Provided argument pointer is not of an Argument class!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return GetWorkspaceSize( + parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC); + } }; } // namespace device