diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index eff6b6f3fa..f98308d687 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) + add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp new file mode 100644 index 0000000000..97a3f89e5e --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#define DIRECT_LOAD 1 + +#if DIRECT_LOAD +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#endif + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +#if DIRECT_LOAD + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle_LdsDirectLoad + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| 
Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; +// clang-format on + +#else + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#endif + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp new file mode 100644 index 0000000000..dd33e577bf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -0,0 +1,423 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + +struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeType>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Invoker + struct Invoker : public BaseInvoker + { + + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + + const auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0Padded = karg.K0Padded; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + 
GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 4cee1ed34b..cd36b9e51a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include "ck/utility/amd_lds.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -491,22 +492,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } - template - __device__ static auto AllocateBlockBuffers(void* p_shared, - int32_t num_elems, - int32_t offset_elems, - int32_t max_lds_align) - { - const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align); - return generate_tuple( - [&](auto i) { - const int32_t local_offset = i * single_buffer_offset; - return make_dynamic_buffer( - static_cast(p_shared) + local_offset + offset_elems, num_elems); - }, - Number{}); - } - template ( - p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align); + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage; auto b_block_buffers = - AllocateBlockBuffers(p_shared, - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), - b_buffers_offset, - max_lds_align); + ck::lds_utils::AllocateLdsBuffers( + p_shared, + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp new file mode 100644 index 0000000000..94306a4c95 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -0,0 +1,962 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); +#else + ignore = karg; + ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdlops_splitk_lds_direct_load +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0Padded; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + 
StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0Padded(K0Padded_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0Padded:" << K0Padded << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, 
Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), 
max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) * + sizeof(ComputeType), + c_block_size * sizeof(FloatC)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + + const auto num_k_loop = karg.K0Padded / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; + return KPad; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) + { + const index_t num_loop = K0Padded / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + 
make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; + + template + __device__ static void Run(const Argument& karg, + void* __restrict__ p_shared_block, + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [KBatch, M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + 
make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_K0_M_K1, + FloatA, + ComputeType, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_K0_N_K1, + FloatB, + ComputeType, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, // ComputeType A + ComputeType, // ComputeType B + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + 
a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_lds.hpp b/include/ck/utility/amd_lds.hpp new file mode 100644 index 0000000000..c218fded96 --- /dev/null +++ b/include/ck/utility/amd_lds.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +namespace lds_utils { + +/** \brief Allocate a given number of buffers in LDS and return them as a tuple. + * + * \tparam DataType Data type of elements to be stored in LDS. + * \tparam NumBuffers Number of buffers to be allocated. + * \param lds_ptr Address of the beginning of LDS space. + * \param num_elems_per_buffer Number of elements to allocate per single buffer. + * \param start_offset_elems Number of elements to move from the start of LDS for the allocation of + * the first buffer. \param lds_alignment Alignment of every buffer allocation given as a number of + * elements. \return Tuple of dynamic buffers representing memory allocated in LDS. 
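+ *
+ * Usage sketch (illustrative only): the descriptor objects, `ComputeType`, `p_shared`,
+ * `max_lds_align` and the prefetch depth below are assumptions standing in for whatever the
+ * calling kernel already has, mirroring the call sites in the gridwise GEMM kernels.
+ * \code
+ * constexpr index_t NumPrefetch = 2;                       // assumed number of pipeline stages
+ * const auto a_size = a_block_desc.GetElementSpaceSize();  // elements per A tile (assumed desc)
+ * const auto b_size = b_block_desc.GetElementSpaceSize();  // elements per B tile (assumed desc)
+ * // A buffers start at the beginning of LDS, one aligned slot per prefetch stage.
+ * auto a_block_buffers = ck::lds_utils::AllocateLdsBuffers<ComputeType, NumPrefetch>(
+ *     p_shared, a_size, 0, max_lds_align);
+ * // B buffers start right after all A stages, keeping the same per-buffer alignment.
+ * const auto b_offset = math::integer_least_multiple(a_size, max_lds_align) * NumPrefetch;
+ * auto b_block_buffers = ck::lds_utils::AllocateLdsBuffers<ComputeType, NumPrefetch>(
+ *     p_shared, b_size, b_offset, max_lds_align);
+ * \endcode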
+ */ +template +__device__ static auto AllocateLdsBuffers(void* lds_ptr, + int32_t num_elems_per_buffer, + int32_t start_offset_elems, + int32_t lds_alignment) +{ + const DataType* lds_start = static_cast(lds_ptr) + start_offset_elems; + const int32_t single_buffer_offset = + math::integer_least_multiple(num_elems_per_buffer, lds_alignment); + return generate_tuple( + [&](auto i) { + const int32_t local_offset = i * single_buffer_offset; + return make_dynamic_buffer(lds_start + local_offset, + num_elems_per_buffer); + }, + Number{}); +} + +} // namespace lds_utils +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index 8ad6ddca9d..974da56649 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -36,6 +36,11 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); + +void add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); #endif #ifdef CK_ENABLE_FP32 void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( @@ -192,6 +197,7 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index ec4c27598f..aaa0d7e960 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -8,6 +8,7 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..f0a54ee400 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#######################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#######################################| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 4, 16, 16, 16, 1, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, 
PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 4, 32, 16, 16, 1, 2, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 4, 32, 16, 16, 1, 2, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, 
F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck
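
For reference, the host-side flow that run_splitK_gemm_example.inc and the instance-factory clients exercise against the new device operation is sketched below. This is a minimal sketch under assumptions, not part of the change: it presumes the DeviceGemmInstance, ADataType/BDataType/CDataType and AElementOp/BElementOp/CElementOp aliases from the example file above, already-populated DeviceMem buffers a_dev/b_dev/c_dev, and problem sizes M, N, K with strides StrideA, StrideB, StrideC and a split factor KBatch.

    // Hedged sketch of the host-side call sequence; all names referenced here are assumptions
    // taken from the example file above, not new API.
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    auto argument = gemm.MakeArgument(static_cast<const ADataType*>(a_dev.GetDeviceBuffer()),
                                      static_cast<const BDataType*>(b_dev.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                      M, N, K,
                                      StrideA, StrideB, StrideC,
                                      AElementOp{}, BElementOp{}, CElementOp{},
                                      KBatch);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
        return false;
    }

    // With KBatch > 1 the invoker zero-initializes C and launches the AtomicAdd kernel variant;
    // with KBatch == 1 it launches the Set variant. Run() returns the averaged kernel time.
    const float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});

The StreamConfig{nullptr, true} form (default stream, kernel timing enabled) follows the pattern used in other CK examples; a default-constructed StreamConfig{} also works when timing is not needed.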