// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include <string>
#include <type_traits>

namespace ck_tile {

template <typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_,
          typename LayoutA_,
          typename LayoutB_,
          typename LayoutC_>
struct GemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using LayoutA          = remove_cvref_t<LayoutA_>;
    using LayoutB          = remove_cvref_t<LayoutB_>;
    using LayoutC          = remove_cvref_t<LayoutC_>;

    static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;

    using ADataType    = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType    = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
    using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
    {
        return TilePartitioner::GridSize(M_size, N_size, Batch_size);
    }

    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    struct GemmCommonKargs
    {
        const void* a_ptr;
        const void* b_ptr;
        void* c_ptr;
        float epsilon;
        ck_tile::index_t M;
        ck_tile::index_t N;
        ck_tile::index_t K;
        ck_tile::index_t stride_A;
        ck_tile::index_t stride_B;
        ck_tile::index_t stride_C;
    };

    CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                            const void* b_ptr,
                                                            void* c_ptr,
                                                            float epsilon,
                                                            ck_tile::index_t M,
                                                            ck_tile::index_t N,
                                                            ck_tile::index_t K,
                                                            ck_tile::index_t stride_A,
                                                            ck_tile::index_t stride_B,
                                                            ck_tile::index_t stride_C)
    {
        return GemmCommonKargs{
            a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
    {
        const auto [i_m, i_n] = TilePartitioner{}();
        // options
        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
        // Convert pointers to tensor views. The guaranteed vector-access width
        // (AlignmentA/B/C) is assumed to be a compile-time constant exposed by
        // the pipeline.
        auto a_tensor_view = [&]() {
            if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(1, kargs.stride_A),
                    number<GemmPipeline::AlignmentA>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::AlignmentA>{},
                    number<1>{});
            }
        }();

        auto b_tensor_view = [&]() {
            if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
                    number<GemmPipeline::AlignmentB>{},
                    number<1>{});
            }
            else
            {
                // Default NK layout
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::AlignmentB>{},
                    number<1>{});
            }
        }();

        auto a_pad_view = pad_tensor_view(
            a_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            sequence<0, GemmPipeline::kPadA ? 1 : 0>{});

        auto ABlockWindow = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            {i_m, 0});

        auto b_pad_view = pad_tensor_view(
            b_tensor_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
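        // B tile window: this workgroup's kN x kK sub-tile of B, starting at
        // K = 0. The pipeline is expected to walk the A and B windows across K
        // once per main-loop iteration (ceil(K / kK) iterations in total).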
        auto BBlockWindow = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            {i_n, 0});

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;

        auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);

        CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
            if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<GemmPipeline::AlignmentC>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<GemmPipeline::AlignmentC>{},
                    number<1>{});
            }
        }();

        auto c_pad_view = pad_tensor_view(
            c_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            sequence<0, GemmPipeline::kPadC ? 1 : 0>{});

        auto CBlockWindow_pad = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            {i_m, i_n});

        EpiloguePipeline{}(CBlockWindow_pad, acc);
    }
};

} // namespace ck_tile
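// Example usage (a minimal host-side sketch; MyTilePartitioner, MyGemmPipeline
// and MyEpiloguePipeline are hypothetical placeholders, not part of this
// header -- any types satisfying the interfaces used above will do):
//
//   using Kernel = ck_tile::GemmKernel<MyTilePartitioner,
//                                      MyGemmPipeline,
//                                      MyEpiloguePipeline,
//                                      ck_tile::tensor_layout::gemm::RowMajor,    // A: MK
//                                      ck_tile::tensor_layout::gemm::ColumnMajor, // B: NK (default)
//                                      ck_tile::tensor_layout::gemm::RowMajor>;   // C: MN
//
//   auto kargs = Kernel::MakeKargs(a_dev, b_dev, c_dev, /*epsilon=*/0.f,
//                                  M, N, K, stride_a, stride_b, stride_c);
//
//   const auto grid  = Kernel::GridSize(M, N, /*Batch_size=*/1);
//   const auto block = Kernel::BlockSize();
//   // Launch Kernel{} over (grid, block) with kargs, e.g. via
//   // ck_tile::launch_kernel; the kernel uses Kernel::GetSmemSize() bytes of LDS.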