mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Merge commit '77123600ee4b6fae077a2145b68b00a8b2ce9460' into develop
This commit is contained in:
@@ -634,7 +634,7 @@ option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_hp_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
|
||||
@@ -167,7 +167,7 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -600,12 +600,12 @@ struct Tensor
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
@@ -0,0 +1,759 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument
|
||||
karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = batch;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// @brief \"Universal\" Batched GEMM operation without SplitK support.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations applied to the A, B, and C tensors, respectively.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through its design
|
||||
/// and versatilty.
|
||||
///
|
||||
/// @note This Kernel implementation currently does not support the SplitK algorithm.
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled). Currently not supported!
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and
|
||||
// permuteB arguments so for now we are not including this functionality.
|
||||
static_assert(PermuteA == false,
|
||||
"Permute A functionality not supported by DeviceBatchedGemm operations.\n");
|
||||
static_assert(PermuteB == false,
|
||||
"Permute B functionality not supported by DeviceBatchedGemm operations.\n");
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by DeviceBatchedGemm base class.
|
||||
false>; // PermuteB not supported by DeviceBatchedGemm base class.
|
||||
|
||||
// Argument
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t BatchStrideA_,
|
||||
index_t BatchStrideB_,
|
||||
index_t BatchStrideC_,
|
||||
index_t Batch_,
|
||||
index_t k_batch_,
|
||||
bool is_reduce_ = false)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
is_reduce_),
|
||||
Batch(Batch_),
|
||||
compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
|
||||
{
|
||||
}
|
||||
|
||||
index_t Batch;
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
|
||||
};
|
||||
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size prepares kernel
|
||||
/// arguments as well as perform specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
/// @note If appropriately configured it may measure kernel execution time.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
// The normal approach to batching would be to increase the grid size by just stretching
|
||||
// out the grid Z dimension (which is the outermost dimension), but this depends on
|
||||
// lower level functions not directly using the Z dimension for other calculations. As
|
||||
// it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
|
||||
// Therefore, for now we will use the grid Y dimension for batching. This may be a bit
|
||||
// fragile.
|
||||
gdy *= arg.Batch;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
// Packed sizes are 1 for all implemented data types but we include it anyway
|
||||
// for future compatibility.
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
// Note: the grid descriptors and size_a / size_b do *not* take batching into
|
||||
// account, so we have to manually multiply overall buffer sizes for rotating
|
||||
// memory by batch.
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
arg_.Batch * size_a_buffer,
|
||||
arg_.Batch * size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
// Note: we multiply by batch since we want to clear the C matrix for
|
||||
// the whole batch. Untested since we don't have k batching ATM.
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_,
|
||||
arg_.compute_ptr_offset_of_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto clear_workspace = [&]() {
|
||||
// clear c mem
|
||||
if(arg.KBatch > 1)
|
||||
// Note: we multiply by batch since we want to clear the C matrix for
|
||||
// the whole batch. Untested since we don't have k batching ATM.
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.Batch * arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg,
|
||||
arg.compute_ptr_offset_of_batch);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Implement
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
|
||||
std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.KBatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
// TODO: This is not part of the DeviceBatchedGemm base class but it was part of
|
||||
// DeviceBatchedGemmV2. Remove?
|
||||
// index_t GetKPerBlock() override { return KPerBlock; }
|
||||
// bool GetPermuteA() override { return PermuteA; }
|
||||
// bool GetPermuteB() override { return PermuteB; }
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch,
|
||||
1 /* KBatch */};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch,
|
||||
1); // KBatch
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemm_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -230,7 +230,7 @@ struct HostTensorDescriptor
|
||||
* @param iss Vector containing the multi-dimensional indices
|
||||
* @return The calculated linear offset as a size_t
|
||||
*/
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -540,9 +540,12 @@ struct HostTensor
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,46 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif // CK_ENABLE_BF16
|
||||
#endif // CK_USE_WMMA
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
@@ -124,6 +164,8 @@ void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -154,6 +196,66 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_BF16
|
||||
#endif // CK_USE_WMMA
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
@@ -258,6 +360,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_USE_XDL
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -295,11 +295,8 @@ FOREACH(subdir_path ${dir_list})
|
||||
|
||||
if(MIOPEN_REQ_LIBS_ONLY)
|
||||
message(STATUS "Removing all sources that are not required for MIOpen")
|
||||
if("${cmake_instance}" MATCHES "gemm" OR
|
||||
"${cmake_instance}" MATCHES "mha" OR
|
||||
"${cmake_instance}" MATCHES "contraction" OR
|
||||
"${cmake_instance}" MATCHES "reduce")
|
||||
set(add_inst 0)
|
||||
if(NOT "${cmake_instance}" MATCHES "conv")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -328,7 +325,7 @@ ENDFOREACH()
|
||||
|
||||
|
||||
|
||||
if(CK_DEVICE_OTHER_INSTANCES)
|
||||
if(CK_DEVICE_OTHER_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
add_library(device_other_operations ${CK_DEVICE_OTHER_INSTANCES})
|
||||
add_library(composablekernels::device_other_operations ALIAS device_other_operations)
|
||||
set_target_properties(device_other_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(BATCHED_GEMM_INSTANCES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
|
||||
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -67,7 +67,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
@@ -92,6 +91,7 @@ endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
@@ -164,7 +164,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
@@ -206,6 +205,7 @@ endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp)
|
||||
add_gtest_executable(test_batched_gemm_xdl test_batched_gemm_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
|
||||
target_link_libraries(test_batched_gemm_xdl PRIVATE utility device_batched_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_wmma test_batched_gemm_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_wmma PRIVATE utility device_batched_gemm_instance)
|
||||
endif()
|
||||
|
||||
193
test/batched_gemm/test_batched_gemm_wmma.cpp
Normal file
193
test/batched_gemm/test_batched_gemm_wmma.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_batched_gemm_impl.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
|
||||
|
||||
struct GemmParams
|
||||
{
|
||||
ck::index_t M;
|
||||
ck::index_t N;
|
||||
ck::index_t K;
|
||||
ck::index_t BatchCount;
|
||||
};
|
||||
|
||||
class TestBatchedGemm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
std::vector<GemmParams> params;
|
||||
|
||||
template <typename DataType>
|
||||
void Run()
|
||||
{
|
||||
using namespace ck::tensor_operation::device;
|
||||
|
||||
bool pass = true;
|
||||
for(auto& param : params)
|
||||
{
|
||||
const auto M = param.M;
|
||||
const auto N = param.N;
|
||||
const auto K = param.K;
|
||||
const auto BatchCount = param.BatchCount;
|
||||
|
||||
pass =
|
||||
pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Row,
|
||||
Row,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
pass =
|
||||
pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Row,
|
||||
Col,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
pass =
|
||||
pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Col,
|
||||
Row,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
|
||||
|
||||
pass =
|
||||
pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Col,
|
||||
Col,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// #ifdef CK_ENABLE_INT8
|
||||
// TEST_F(TestBatchedGemm, i8)
|
||||
// {
|
||||
// this->params.push_back({64, 64, 64, 2});
|
||||
// this->params.push_back({64, 64, 64, 1});
|
||||
// this->params.push_back({60, 60, 60, 2});
|
||||
// this->params.push_back({68, 68, 68, 2});
|
||||
// this->params.push_back({40, 40, 40, 2});
|
||||
// this->params.push_back({256, 256, 128, 3});
|
||||
// this->template Run<int8_t>();
|
||||
// }
|
||||
// #endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
TEST_F(TestBatchedGemm, bf16)
|
||||
{
|
||||
this->params.push_back({64, 64, 64, 2});
|
||||
this->params.push_back({64, 64, 64, 1});
|
||||
this->params.push_back({40, 40, 40, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
|
||||
// Tests with larger MNK
|
||||
this->params.push_back({512, 256, 128, 1});
|
||||
this->params.push_back({256, 240, 192, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
this->params.push_back({240, 128, 128, 5});
|
||||
this->template Run<ck::bhalf_t>();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
TEST_F(TestBatchedGemm, fp16)
|
||||
{
|
||||
this->params.push_back({64, 64, 64, 2});
|
||||
this->params.push_back({64, 64, 64, 1});
|
||||
this->params.push_back({40, 40, 40, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
|
||||
// Tests with larger MNK
|
||||
this->params.push_back({512, 256, 128, 1});
|
||||
this->params.push_back({256, 240, 192, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
this->params.push_back({240, 128, 128, 5});
|
||||
this->template Run<ck::half_t>();
|
||||
}
|
||||
#endif
|
||||
|
||||
// #ifdef CK_ENABLE_FP32
|
||||
// TEST_F(TestBatchedGemm, fp32)
|
||||
// {
|
||||
// this->params.push_back({64, 64, 64, 2});
|
||||
// this->params.push_back({64, 64, 64, 1});
|
||||
// this->params.push_back({60, 60, 60, 2});
|
||||
// this->params.push_back({68, 68, 68, 2});
|
||||
// this->params.push_back({40, 40, 40, 2});
|
||||
// this->params.push_back({256, 256, 128, 3});
|
||||
// this->template Run<float>();
|
||||
// }
|
||||
// #endif
|
||||
Reference in New Issue
Block a user