Device implementation of explicit gemm for grouped conv bwd weight

Based on batched gemm multiple D
This commit is contained in:
Enrico Degregori
2025-08-08 16:23:04 +00:00
parent 207cc39ee4
commit 70238cab87
47 changed files with 2820 additions and 149 deletions

View File

@@ -0,0 +1,756 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename ComputePtrOffsetOfStridedBatch,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_batched_gemm_multi_d_wmma_cshuffle_v3(
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
// argument through implicit conversion to base class!
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
// The normal approach to batching would be to increase the grid size by just stretching out
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
// functions not directly using the Z dimension for other calculations. As it turns out, k
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
// we will use the grid Y dimension for batching. This may be a bit fragile.
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
static_for<0, GridwiseGemm::NumATensor, 1>{}(
[&](auto i) { splitk_batch_offset.a_k_split_offset[i] += a_batch_offset; });
static_for<0, GridwiseGemm::NumBTensor, 1>{}(
[&](auto i) { splitk_batch_offset.b_k_split_offset[i] += b_batch_offset; });
splitk_batch_offset.c_reduce_offset += c_batch_offset;
// populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer
karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
});
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
ignore = compute_ptr_offset_of_batch;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
: public DeviceBatchedGemmV2MultiD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
using CDataType_ = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
false,
false>;
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(
index_t BatchStrideA,
index_t BatchStrideB,
std::array<ck::index_t, GridwiseGemm::NumDTensor> BatchStrideDs,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(BatchStrideA_) * g_idx;
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(BatchStrideB_) * g_idx;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, GridwiseGemm::NumDTensor> ds_offset_;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
ds_offset_[i] = static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
});
return ds_offset_;
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(BatchStrideC_) * g_idx;
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
std::array<ck::index_t, GridwiseGemm::NumDTensor> BatchStrideDs_;
index_t BatchStrideC_;
};
struct Argument : public GridwiseGemm::Argument
{
index_t Batch;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
Argument() = default;
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, GridwiseGemm::NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, GridwiseGemm::NumDTensor> StrideDs_,
index_t StrideE_,
index_t BatchStrideA_,
index_t BatchStrideB_,
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs_,
index_t BatchStrideE_,
index_t Batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
index_t KBatch_)
: GridwiseGemm::Argument{std::array<const void*, 1>{p_a_grid_},
std::array<const void*, 1>{p_b_grid_},
p_ds_grid_,
p_e_grid_,
M_,
N_,
K_,
std::array<index_t, 1>{StrideA_},
std::array<index_t, 1>{StrideB_},
StrideDs_,
StrideE_,
KBatch_,
a_element_op_,
b_element_op_,
cde_element_op_,
false},
Batch{Batch_},
compute_ptr_offset_of_batch{
BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
{
}
template <typename EType>
void SetEPointer(void* ptr)
{
this->p_e_grid = static_cast<EType*>(ptr);
}
};
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// TODO
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
max_occupancy_ = std::max(1, max_occupancy);
}
int max_occupancy_;
};
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
gdy *= arg.Batch;
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
// Packed sizes are 1 for all implemented data types but we include it anyway
// for future compatibility.
std::array<std::size_t, 1> size_as_buffers;
size_as_buffers[0] = arg_.Batch *
a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
std::array<std::size_t, 1> size_bs_buffers;
size_bs_buffers[0] = arg_.Batch *
b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiABD<Argument,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType>
rotating_mem(arg_,
stream_config.rotating_count,
size_as_buffers,
size_bs_buffers,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg_.p_e_grid,
0,
arg.Batch * arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_,
arg_.compute_ptr_offset_of_batch);
}
else
{
const auto clear_workspace = [&]() {
if(arg.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg.p_e_grid,
0,
arg.Batch * arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
};
ave_time =
launch_and_time_kernel_with_preprocess(stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg,
arg.compute_ptr_offset_of_batch);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
// TODO: Implement
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
GridwiseGemm,
ComputePtrOffsetOfStridedBatch,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t KBatch = 1)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
Batch,
a_element_op,
b_element_op,
cde_element_op,
KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, GridwiseGemm::NumDTensor>& p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
Batch,
a_element_op,
b_element_op,
cde_element_op,
KBatch);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceBatchedGemmMultipleD_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(ELayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
static ck::index_t GetMaxOccupancy()
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
return active_workgroups_per_cu.max_occupancy_;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -350,6 +350,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
{
}
template <typename EType>
void SetEPointer(void* ptr)
{
this->p_c_grid = static_cast<EType*>(ptr);
}
};
using Argument = ArgumentBase<GridwiseGemm64>;

View File

@@ -32,7 +32,7 @@ template <ck::index_t NDimSpatial,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
typename DeviceGemmV3Op>
struct DeviceGroupedConvBwdWeight_Explicit_Xdl
struct DeviceGroupedConvBwdWeight_Explicit
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
@@ -56,7 +56,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
sizeof(WeiDataType) % 4 != 0 &&
DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit;
using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;
static constexpr index_t ElementwiseBlockSize = 256;
@@ -288,8 +288,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
{
// Modify to use workspace as output
GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
explicit_gemm_args_with_workspace.p_c_grid =
static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
explicit_gemm_args_with_workspace.template SetEPointer<TwoStageIntermediateType>(
arg.p_workspace_);
float avg_time =
explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
const index_t grid_size =
@@ -494,7 +494,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Explicit_Xdl"
str << "DeviceGroupedConvBwdWeight_Explicit"
<< "<" << DeviceGemmV3Op{}.GetTypeString() << ">";
// clang-format on

View File

@@ -332,6 +332,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
struct Problem
{
__host__ Problem() = default;
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
@@ -406,6 +407,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument() = default;
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
@@ -472,9 +474,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;

View File

@@ -7,7 +7,7 @@
#include <type_traits>
#include "ck/utility/functional2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp"
namespace ck {
namespace tensor_operation {
@@ -32,17 +32,17 @@ void add_explicit_gemm_device_operation_instances(
ck::static_for<0, std::tuple_size_v<DeviceGemmV3Ops>, 1>{}([&](auto i) {
using DeviceGemmOp = std::tuple_element_t<i, DeviceGemmV3Ops>;
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit_Xdl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
DeviceGemmOp>;
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
DeviceGemmOp>;
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
"wrong! NewOpInstance should be derived from BaseOp");

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using BF16 = bhalf_t;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <typename InOutDataType, GemmSpecialization GemmSpec>
using device_gemm_wmma_universal_km_kn_mn_comp_instances = std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_wmma_universal_km_kn_mn_mem_instances = std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// Memory friendly
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances = std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_wmma_universal_km_kn_mn_odd_n_instances = std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -21,6 +21,7 @@
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_backward_weight_wmma.inc"
#include "grouped_convolution_backward_weight_explicit_wmma.inc"
#endif
namespace ck {
namespace tensor_operation {
@@ -395,21 +396,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -452,23 +456,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
op_ptrs);
}
#endif
@@ -641,21 +645,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -698,23 +705,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
op_ptrs);
}
#endif
@@ -888,6 +895,25 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -918,6 +944,25 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
}
#endif
}
@@ -1024,6 +1069,25 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -1054,6 +1118,25 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
}
#endif
}

View File

@@ -0,0 +1,459 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// 2D
#ifdef CK_ENABLE_BF16
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// 3D
#ifdef CK_ENABLE_BF16
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -10,7 +10,7 @@ namespace instance {
// 2D
#ifdef CK_ENABLE_BF16
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -22,7 +22,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -34,7 +34,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -46,7 +46,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -70,7 +70,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -82,7 +82,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -94,7 +94,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -106,7 +106,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -121,7 +121,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -133,7 +133,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -145,7 +145,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -157,7 +157,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -169,7 +169,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -181,7 +181,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -193,7 +193,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -205,7 +205,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -217,7 +217,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -232,7 +232,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
// 3D
#ifdef CK_ENABLE_BF16
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -244,7 +244,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -256,7 +256,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -268,7 +268,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -280,7 +280,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -292,7 +292,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -304,7 +304,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -316,7 +316,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -328,7 +328,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -343,7 +343,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -355,7 +355,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -367,7 +367,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -379,7 +379,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -391,7 +391,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -403,7 +403,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -415,7 +415,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -427,7 +427,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -439,7 +439,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -1,26 +1,47 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GROUPED_CONVND_EXP_BWD_WEIGHT
# Explicit instances are common for 2d and 3d
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp
)
add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT})

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<BF16, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<BF16, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMNKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_odd_n_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_odd_n_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<F16, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<F16, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Interwave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Interwave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMNKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_odd_n_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_wmma_universal_km_kn_mn_odd_n_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -1,12 +1,12 @@
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_bilinear test_grouped_convnd_bwd_weight_bilinear.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_bilinear PRIVATE utility device_grouped_conv3d_bwd_weight_bilinear_instance)
elseif(GPU_TARGETS MATCHES "gfx1[12]")
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_bilinear test_grouped_convnd_bwd_weight_bilinear.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_bilinear PRIVATE utility device_grouped_conv3d_bwd_weight_bilinear_instance)