Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 10:09:41 +00:00)
Conv:TF32: add more instances - 1 (#2867)
* conv: tf32: add more instances
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* remove gnhwc/ngchw/ngcdhw instances
[ROCm/composable_kernel commit: df97a286d5]
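Beyond the new instance lists named in the message, the commit reworks the blockwise GEMM XDLops pipeline headers around a new `ComputeDataTypeBuf` alias: it is defined once in `BlockwiseGemmXdlops_pipeline_base` and then substituted for `ComputeDataType` wherever the pipelines declare VGPR thread buffers, the per-XDL `a_thread_vec`/`b_thread_vec` vectors, `mfma_input_type`, or the destination type of `ThreadwiseTensorSliceTransfer_v4`. The alias maps `ck::tf32_t` to `float`, presumably because TF32 operands are carried in registers as ordinary 32-bit floats while the XDLops/MFMA path still selects TF32 instructions. The sketch below is a minimal, self-contained illustration of that alias; the free template-alias wrapper, the toy `tf32_t` definition, and the `main` stub are illustration-only and are not code from the repository.

```cpp
#include <type_traits>

// Hypothetical stand-in for ck::tf32_t: a tag type that selects the TF32
// compute path but whose register-resident values are plain 32-bit floats.
namespace ck { struct tf32_t {}; }

// The alias introduced by this commit (in the diff it is a member alias of
// BlockwiseGemmXdlops_pipeline_base using ck's conditional_t), expressed here
// as a free template alias so the example stands alone:
template <typename ComputeDataType>
using ComputeDataTypeBuf =
    std::conditional_t<std::is_same<ComputeDataType, ck::tf32_t>::value, float, ComputeDataType>;

// TF32 buffers collapse to float; every other compute type is left untouched.
static_assert(std::is_same_v<ComputeDataTypeBuf<ck::tf32_t>, float>);
static_assert(std::is_same_v<ComputeDataTypeBuf<float>, float>);
static_assert(std::is_same_v<ComputeDataTypeBuf<double>, double>);

int main() { return 0; }
```

With the alias in place, the substitutions in the hunks below are type-only and should leave non-TF32 instantiations unchanged.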
@@ -54,6 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base
     static constexpr auto xdlops_gemm =
         XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};

+    using ComputeDataTypeBuf =
+        conditional_t<std::is_same<ComputeDataType, ck::tf32_t>::value, float, ComputeDataType>;
+
     static constexpr index_t AMmaKStride = KPack;
     static constexpr index_t BMmaKStride = KPack;

@@ -376,7 +379,7 @@ struct BlockwiseGemmXdlops_pipeline_base
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

     using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
-                                                          ComputeDataType,
+                                                          ComputeDataTypeBuf,
                                                           decltype(a_block_desc_m0_m1_m2_k),
                                                           decltype(a_thread_desc_),
                                                           Sequence<1, 1, 1, KPack>,
@@ -386,7 +389,7 @@ struct BlockwiseGemmXdlops_pipeline_base
                                                           A_K1>;

     using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                          ComputeDataType,
+                                                          ComputeDataTypeBuf,
                                                           decltype(b_block_desc_n0_n1_n2_k),
                                                           decltype(b_thread_desc_),
                                                           Sequence<1, 1, 1, KPack>,
@@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
     using Base::AMmaKStride;
     using Base::BMmaKStride;

+    using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
+
     static constexpr index_t PrefetchStages = 1;
     static constexpr index_t PrefillStages = 1;
     static constexpr index_t GlobalBufferNum = 1;
@@ -185,9 +187,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                CThreadBuffer& c_thread_buf,
                                index_t num_loop) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
             b_thread_desc_.GetElementSpaceSize());

         // Global prefetch 1
@@ -240,20 +242,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
             static_for<0, KRepeat, 1>{}([&](auto k0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        vector_type<ComputeDataType, KPack> a_thread_vec;
-                        vector_type<ComputeDataType, KPack> b_thread_vec;
+                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
+                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                         static_for<0, KPack, 1>{}([&](auto ik) {
-                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, k0, ik))>{}];
-                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
+                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, I0, k0, ik))>{}];
                         });

                         using mfma_input_type =
-                            typename vector_type<ComputeDataType,
+                            typename vector_type<ComputeDataTypeBuf,
                                                  xdlops_gemm.K1PerXdlops>::type;

                         constexpr index_t c_offset =
@@ -301,20 +303,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
             static_for<0, KRepeat, 1>{}([&](auto k0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        vector_type<ComputeDataType, KPack> a_thread_vec;
-                        vector_type<ComputeDataType, KPack> b_thread_vec;
+                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
+                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                         static_for<0, KPack, 1>{}([&](auto ik) {
-                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, k0, ik))>{}];
-                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
+                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, I0, k0, ik))>{}];
                         });

                         using mfma_input_type =
-                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
+                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                         constexpr index_t c_offset =
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -439,6 +441,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
@@ -486,9 +490,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -551,20 +555,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -640,20 +644,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -704,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
I1));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -714,7 +718,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
|
||||
@@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t AMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
|
||||
static constexpr index_t BMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
|
||||
|
||||
@@ -222,10 +224,12 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataTypeBuf) /
|
||||
// sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataTypeBuf) /
|
||||
// sizeof(ADataType) : sizeof(ComputeDataTypeBuf)
|
||||
// / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
@@ -351,9 +355,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// assume kperblock = scaleblockk
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
@@ -516,17 +520,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
@@ -535,7 +539,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -646,17 +650,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
@@ -665,7 +669,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -737,17 +741,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
@@ -756,7 +760,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -791,17 +795,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
@@ -810,7 +814,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -842,7 +846,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
@@ -852,7 +856,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
|
||||
@@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
@@ -205,12 +207,12 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
{
|
||||
// assume kperblock = scaleblockk
|
||||
ignore = num_loop_per_scale;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -279,20 +281,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -360,20 +362,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
|
||||
@@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
@@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -284,20 +286,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -355,20 +357,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -410,20 +412,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -461,20 +463,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -628,6 +630,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::WaveSize;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
@@ -716,9 +720,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -786,20 +790,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -885,20 +889,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -961,20 +965,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -1037,20 +1041,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -1129,7 +1133,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
I1));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -1139,7 +1143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
|
||||
@@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
@@ -257,9 +259,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
{
|
||||
// assume kperblock = scaleblockk
|
||||
ignore = num_loop_per_scale;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
@@ -351,20 +353,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -457,20 +459,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -547,20 +549,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
@@ -605,20 +607,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
|
||||
@@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
@@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -285,20 +287,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -356,20 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -411,20 +413,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -462,20 +464,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -629,6 +631,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::WaveSize;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
@@ -732,12 +736,12 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
{
|
||||
ignore = num_loop_per_scale;
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -821,20 +825,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -942,20 +946,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -1039,20 +1043,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -1123,20 +1127,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -1223,7 +1227,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
I1));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -1233,7 +1237,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
|
||||
@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
@@ -196,10 +198,10 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
@@ -295,9 +297,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t num_loop) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -364,20 +366,20 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -424,20 +426,20 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
@@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
@@ -196,10 +198,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
@@ -329,9 +331,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
|
||||
"Pipeline v3 only support scaleblocksliceN=1");
|
||||
// assume kperblock = scaleblockk
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
@@ -476,20 +478,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
@@ -578,20 +580,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
|
||||
@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
@@ -195,10 +197,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
@@ -307,13 +309,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// B scale buffer
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -429,20 +431,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -491,20 +493,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 2;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
@@ -264,9 +266,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
|
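Note (editor sketch, not part of the diff): pipeline v4 keeps two register-resident copies of the A/B thread tiles (a_thread_bufs / b_thread_bufs) and alternates which copy is filled from LDS while the other feeds the MFMAs; only their element type changes in this commit. A rough host-side sketch of that ping-pong pattern, with plain arrays standing in for StaticallyIndexedArray and the fill/consume steps simplified:

#include <array>
#include <cstdio>

constexpr int KPack = 4;
using ThreadBuf = std::array<float, KPack>;

int main()
{
    std::array<ThreadBuf, 2> a_thread_bufs{}; // double-buffered registers
    int mfma_reg_buf = 0;                     // buffer currently consumed by the MFMAs

    for(int iter = 0; iter < 4; ++iter)
    {
        ThreadBuf& fill = a_thread_bufs[1 - mfma_reg_buf]; // prefetch into the other copy
        for(int ik = 0; ik < KPack; ++ik) { fill[ik] = float(iter); }

        float acc = 0.f;
        for(float v : a_thread_bufs[mfma_reg_buf]) { acc += v; } // consume the current copy
        std::printf("iter %d acc %f\n", iter, acc);

        mfma_reg_buf = 1 - mfma_reg_buf; // swap roles each hot-loop iteration
    }
}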
||||
@@ -369,22 +371,22 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -439,20 +441,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -492,20 +494,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -524,20 +526,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 2;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
@@ -277,13 +279,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// B scale buffer
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
|
||||
@@ -478,22 +480,22 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -549,20 +551,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -603,20 +605,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -635,20 +637,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
@@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
|
||||
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
@@ -346,9 +348,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
@@ -405,8 +407,8 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto vmem_buf) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
if constexpr(k0 == (KRepeat - 1))
|
||||
@@ -427,18 +429,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
typename vector_type<ComputeDataTypeBuf,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
@@ -481,8 +483,8 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
// tail
|
||||
auto ReadWriteCompFunc = [&](auto vmem_buf) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
if constexpr(k0 == (KRepeat - 1))
|
||||
@@ -497,18 +499,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -540,25 +542,25 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
HotLoopScheduler();
|
||||
};
|
||||
auto ReadCompFunc = [&]() {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -591,16 +593,16 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
|
||||
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -637,7 +639,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
@@ -647,7 +649,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
ComputeDataTypeBuf,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
|
||||
@@ -107,8 +107,11 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
using BComputeDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
|
||||
#else
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
// Element data type is used in LDS and registers; ComputeDataType_ is the type used inside the mfma, e.g. tf32.
|
||||
using AElementDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
|
||||
using BElementDataType =
|
||||
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -199,8 +202,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
|
||||
b_block_space_size_aligned * sizeof(BComputeDataType),
|
||||
return math::max(a_block_space_size_aligned * sizeof(AElementDataType) +
|
||||
b_block_space_size_aligned * sizeof(BElementDataType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
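Note (editor sketch, not part of the diff): GetSharedMemorySize now multiplies by the element (storage) types, so a tf32 kernel reserves LDS for float elements rather than for the tf32 tag type. A small sketch of the idea, with illustrative element counts:

#include <cstddef>
#include <type_traits>

struct tf32_t {}; // illustrative stand-in for ck::tf32_t

template <typename ComputeT>
std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems)
{
    // size by the storage element type, not by the compute tag
    using ElementT = std::conditional_t<std::is_same_v<ComputeT, tf32_t>, float, ComputeT>;
    return (a_elems + b_elems) * sizeof(ElementT);
}

int main() { return lds_bytes<tf32_t>(128 * 32, 128 * 32) == 32768 ? 0 : 1; }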
||||
@@ -621,7 +624,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
AsDataType,
|
||||
Tuple<AComputeDataType>,
|
||||
Tuple<AElementDataType>,
|
||||
decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(tie(a_block_desc_ak0_m_ak1)),
|
||||
AElementwiseOperation,
|
||||
@@ -649,7 +652,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
BsDataType,
|
||||
Tuple<BComputeDataType>,
|
||||
Tuple<BElementDataType>,
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(tie(b_block_desc_bk0_n_bk1)),
|
||||
BElementwiseOperation,
|
||||
@@ -679,27 +682,28 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
(((is_same<AComputeDataType_, half_t>::value ||
|
||||
is_same<AComputeDataType_, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
|
||||
(is_same<AComputeDataType_, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<AComputeDataType_, f8_t>::value ||
|
||||
is_same<AComputeDataType_, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MfmaSelector<AComputeDataType_,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
BComputeDataType_,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AElementDataType,
|
||||
BElementDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
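Note (editor sketch, not part of the diff): the single-rate check and the MfmaSelector are switched to the original AComputeDataType_/BComputeDataType_ template arguments. Once tf32 has been aliased to float for storage, a test on the alias can no longer distinguish tf32 from fp32, so instruction selection has to look at the un-aliased tag. Sketch of that distinction, with tf32_t again a stand-in:

#include <type_traits>

struct tf32_t {}; // illustrative stand-in for ck::tf32_t

template <typename ComputeTag>
constexpr bool wants_xf32_mfma = std::is_same_v<ComputeTag, tf32_t>;

// the storage alias loses the information the selector needs
using Storage = std::conditional_t<wants_xf32_mfma<tf32_t>, float, tf32_t>;

static_assert(wants_xf32_mfma<tf32_t>);        // decision taken on the tag...
static_assert(std::is_same_v<Storage, float>); // ...while the data is stored as float
static_assert(!wants_xf32_mfma<Storage>);      // the alias alone would pick the fp32 path

int main() {}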
||||
@@ -709,8 +713,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched,
|
||||
AComputeDataType,
|
||||
BComputeDataType>();
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -719,10 +723,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<AElementDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<BElementDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
|
||||
@@ -164,6 +164,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using ElementDataTypeAB = conditional_t<is_same_v<FloatAB, ck::tf32_t>, float, FloatAB>;
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
|
||||
@@ -236,8 +238,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
// Argument
|
||||
struct Argument : public Problem, public tensor_operation::device::BaseArgument
|
||||
{
|
||||
__host__ Argument(const FloatAB* p_a_grid_,
|
||||
const FloatAB* p_b_grid_,
|
||||
__host__ Argument(const ElementDataTypeAB* p_a_grid_,
|
||||
const ElementDataTypeAB* p_b_grid_,
|
||||
FloatC* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
@@ -252,8 +254,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
}
|
||||
|
||||
const FloatAB* p_a_grid;
|
||||
const FloatAB* p_b_grid;
|
||||
const ElementDataTypeAB* p_a_grid;
|
||||
const ElementDataTypeAB* p_b_grid;
|
||||
FloatC* p_c_grid;
|
||||
};
|
||||
|
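Note (editor sketch, not part of the diff): with Argument and Run taking ElementDataTypeAB, a tf32-configured GridwiseGemm still consumes ordinary float device pointers; only the MFMA path treats them at reduced mantissa width. A hedged host-side sketch, where launch() and its signature are placeholders rather than the CK API:

#include <type_traits>
#include <vector>

struct tf32_t {}; // illustrative stand-in for ck::tf32_t

template <typename T>
using ElementOf = std::conditional_t<std::is_same_v<T, tf32_t>, float, T>;

template <typename FloatAB>
void launch(const ElementOf<FloatAB>* p_a, const ElementOf<FloatAB>* p_b)
{
    (void)p_a; (void)p_b; // placeholder for building the Argument and invoking the kernel
}

int main()
{
    std::vector<float> a(1024, 1.0f), b(1024, 1.0f);
    launch<tf32_t>(a.data(), b.data()); // a tf32 GEMM still takes plain float buffers
}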
||||
@@ -329,7 +331,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto b_block_space_size_aligned =
|
||||
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
|
||||
return (a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(ElementDataTypeAB);
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -450,8 +453,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatABAdjusted,
|
||||
FloatABAdjusted,
|
||||
ElementDataTypeAB,
|
||||
ElementDataTypeAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
@@ -459,7 +462,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
K1>;
|
||||
K1,
|
||||
FloatABAdjusted,
|
||||
FloatABAdjusted>;
|
||||
|
||||
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
}
|
||||
@@ -471,8 +476,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M_N>
|
||||
__device__ static void Run(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
__device__ static void Run(const ElementDataTypeAB* p_a_grid,
|
||||
const ElementDataTypeAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
@@ -533,8 +538,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatABAdjusted,
|
||||
ElementDataTypeAB,
|
||||
ElementDataTypeAB,
|
||||
decltype(a_grid_desc_k0_m_k1),
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -564,8 +569,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatABAdjusted,
|
||||
ElementDataTypeAB,
|
||||
ElementDataTypeAB,
|
||||
decltype(b_grid_desc_k0_n_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -595,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
// sanity check
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
FloatABAdjusted,
|
||||
FloatABAdjusted,
|
||||
ElementDataTypeAB,
|
||||
ElementDataTypeAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
@@ -605,7 +610,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
K1,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
FloatABAdjusted,
|
||||
FloatABAdjusted>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -614,10 +621,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
|
||||
static_cast<ElementDataTypeAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<ElementDataTypeAB*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_k0_n_k1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
@@ -1647,8 +1647,8 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16>
|
||||
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
|
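Note (editor sketch, not part of the diff): the accumulator fragment for __builtin_amdgcn_mfma_f32_16x16x8_xf32 is corrected from float16_t to float4_t. The instruction produces a 16x16 float tile spread across the 64 lanes of a wavefront, so each lane owns 256/64 = 4 accumulator values. The arithmetic, as a compile-time check:

constexpr int MPerXdl    = 16;
constexpr int NPerXdl    = 16;
constexpr int WaveSize   = 64;
constexpr int AccPerLane = (MPerXdl * NPerXdl) / WaveSize; // 4 floats per lane

static_assert(AccPerLane == 4, "16x16x8 xf32 mfma accumulates into a float4 fragment per lane");

int main() {}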
||||
@@ -16,6 +16,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
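Note (editor sketch, not part of the diff): for reference, tf32 keeps float's 8-bit exponent but only a 10-bit mantissa. One simple emulation (truncation rather than proper rounding, and not claimed to be the reference CK uses) is to clear the 13 low mantissa bits of a float:

#include <cstdint>
#include <cstdio>
#include <cstring>

float to_tf32(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFFE000u; // drop the 13 least-significant mantissa bits
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}

int main() { std::printf("%.9f -> %.9f\n", 1.000244141f, to_tf32(1.000244141f)); }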
||||
@@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
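Note (editor sketch, not part of the diff): the new *_f32_tf32_* instance lists reuse the fp32 tuning parameters and only append TF32 as the trailing A/B compute types. A stripped-down sketch of that pattern, where DeviceOp and its parameter list are placeholders for the much longer CK template:

struct tf32_t {}; // illustrative stand-in for ck::tf32_t

template <typename AComputeType = float, typename BComputeType = AComputeType>
struct DeviceOp {}; // placeholder for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle

using F32Instance  = DeviceOp<>;               // fp32 MFMA path
using TF32Instance = DeviceOp<tf32_t, tf32_t>; // same tuning, xf32 MFMA path

int main() {}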
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -24,6 +24,7 @@ using BF8 = ck::bf8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -205,6 +206,27 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
typename DsDataTypes = Tuple<>,
|
||||
typename OutElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_xdl_f32_tf32_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// double rate mfma instances on gfx950
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
|
||||
@@ -253,6 +253,25 @@ using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
typename DsDataTypes = Tuple<>,
|
||||
typename OutElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_xdl_f32_tf32_generic_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE |
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>
    // clang-format on
>;
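A tuple alias like the one above only becomes usable once it is registered, which in this library is normally done through add_device_operation_instances. The sketch below is illustrative only: the 2-D NHWGC/GKYXC/NHWGK layouts, the Default specialization and the function name are assumptions, not part of this change.

// Hypothetical registration helper, following the pattern of the existing f32
// instance files but with the TF32 compute types appended to the interface.
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_generic_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Tuple<>,
                                                                NHWGK,
                                                                F32,
                                                                F32,
                                                                Tuple<>,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough,
                                                                TF32,
                                                                TF32>>>& instances)
{
    // Appends every entry of the tuple alias defined above to the caller's vector.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_generic_instances<
            2,
            NHWGC,
            GKYXC,
            Tuple<>,
            NHWGK,
            ConvolutionForwardSpecialization::Default>{});
}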

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,

@@ -16,6 +16,7 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

@@ -99,6 +100,27 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename DsDataTypes = Tuple<>,
          typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -24,6 +24,7 @@ using BF8 = ck::bf8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -64,7 +65,7 @@ using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple<
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Latency friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
@@ -163,6 +164,41 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple<
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          BlockGemmPipelineScheduler BlkGemmPipeSched,
          typename DsDataTypes = Tuple<>,
          typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_f32_tf32_mem_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
|
||||
// Memory friendly
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
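The alias above is still generic over layout and scheduler; a concrete instantiation is what actually gets registered. A minimal sketch for a 3-D NDHWGC case follows (the layout tags and the Default specialization are assumptions):

// Illustrative instantiation: memory-friendly TF32 instances for 3-D grouped
// convolution, using the intra-wave block-GEMM pipeline scheduler.
using conv3d_fwd_f32_tf32_mem_intra_instances =
    device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
                                                       NDHWGC,
                                                       GKZYXC,
                                                       Tuple<>,
                                                       NDHWGK,
                                                       ConvolutionForwardSpecialization::Default,
                                                       BlockGemmPipelineScheduler::Intrawave>;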

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,

@@ -16,6 +16,7 @@ namespace instance {

using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

@@ -142,6 +143,27 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename DsDataTypes = Tuple<>,
          typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Instances with NumGroupsPerBatch > 1
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 8>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 16>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 32>
|
||||
// clang-format on
|
||||
>;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,

@@ -16,6 +16,7 @@ namespace instance {

using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

@@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_scale_f32_instances = std::tuple<
// clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scale_f32_tf32_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
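These entries use Scale as the CDE element-wise operation, i.e. the convolution output is multiplied by a runtime scalar supplied when the kernel argument is created. Roughly as sketched below; the constructor shape shown is an assumption:

// Hypothetical: the scaling factor applied to the output by the instances above.
const auto scale_op = ck::tensor_operation::element_wise::Scale{0.5f};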
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace instance {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -89,7 +90,7 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple<
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4>,
|
||||
|
||||
@@ -97,6 +98,27 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
// instances for small conv.K and conv.C
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>
|
||||
// clang-format on
|
||||
>;
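The ScaleAdd variants consume a tuple of two F32 tensors for each of the A and B operands. The interface type a caller would request from the instance factory to reach these TF32 entries looks roughly like this (the 3-D layout names are assumptions):

// Illustrative interface type only; it mirrors the template arguments of the
// instances above, with TF32 supplied for both compute types.
using ScaleAddConvFwdTF32 =
    DeviceGroupedConvFwdMultipleABD<3,
                                    NDHWGC,
                                    GKZYXC,
                                    ck::Tuple<>,
                                    NDHWGK,
                                    ck::Tuple<F32, F32>,
                                    ck::Tuple<F32, F32>,
                                    ck::Tuple<>,
                                    F32,
                                    ScaleAdd,
                                    ScaleAdd,
                                    PassThrough,
                                    TF32,
                                    TF32>;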

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,

@@ -428,26 +428,39 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
#ifdef CK_ENABLE_FP32
        if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                     is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
                     is_same_v<BComputeType, float>)
                     is_same_v<OutDataType, float>)
        {
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(op_ptrs);
            add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                op_ptrs);
            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                op_ptrs);
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
                op_ptrs);
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
                op_ptrs);
        }
        if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                     is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
                     is_same_v<BComputeType, TF32>)
        {
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs);
            if constexpr(is_same_v<AComputeType, BComputeType> && is_same_v<BComputeType, TF32>)
            {
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                    op_ptrs);
            }
            else
            {
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
                    op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
                    op_ptrs);
            }
        }
#endif
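A client reaches the TF32 registrations added above by asking the factory for an operation whose compute types are ck::tf32_t. A minimal sketch, assuming the namespace spelling and 3-D layout tags used elsewhere in this change:

// Query every registered 3-D grouped conv fwd instance that computes in TF32.
using ConvFwdTF32Op = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
    float, float, ck::Tuple<>, float,
    PassThrough, PassThrough, PassThrough,
    ck::tf32_t, ck::tf32_t>;

const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<ConvFwdTF32Op>::GetInstances();
// Each returned pointer can then be exercised through its usual
// MakeArgumentPointer()/MakeInvokerPointer() interface.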
@@ -127,24 +127,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
@@ -197,32 +217,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
|
||||
is_same_v<BComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
@@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instance
                                                            PassThrough,
                                                            AddClamp>>>& instances);

void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Tuple<NHWGK>,
                                                                NHWGK,
                                                                F32,
                                                                F32,
                                                                Tuple<F32>,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                AddClamp,
                                                                TF32,
                                                                TF32>>>& instances);

void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -508,6 +524,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -522,6 +554,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwg
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -536,6 +584,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_ins
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -550,6 +614,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intr
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -564,6 +644,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inte
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -622,6 +718,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndh
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -636,6 +748,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_nd
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -650,6 +778,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -664,6 +808,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_i
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -678,6 +838,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_i
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instanc
                                                            PassThrough,
                                                            PassThrough,
                                                            Bilinear>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                ck::Tuple<NDHWGK>,
                                                                NDHWGK,
                                                                F32,
                                                                F32,
                                                                ck::Tuple<F32>,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                Bilinear,
                                                                TF32,
                                                                TF32>>>& instances);
#endif

#ifdef CK_ENABLE_INT8

@@ -137,8 +153,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                     is_same_v<OutDataType, float>)
        {
            add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                op_ptrs);
            if constexpr(is_same_v<ComputeType, TF32>)
            {
                add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                    op_ptrs);
            }
            else
            {
                add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                    op_ptrs);
            }
        }
#endif
#ifdef CK_ENABLE_FP16

@@ -125,23 +125,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
@@ -193,30 +214,42 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}

if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}

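Which of the two branches is taken is decided purely by the AComputeType/BComputeType arguments of the requested device operation. A minimal usage sketch, assuming the factory's usual GetInstances() entry point and the aliases used in this header (F32, TF32, NDHWGC, GKZYXC, NDHWGK, PassThrough, Clamp):

// Hedged client-side sketch: ask the factory for TF32-compute grouped conv3d fwd + clamp instances.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    3,           // number of spatial dimensions
    NDHWGC,      // input layout
    GKZYXC,      // weight layout
    ck::Tuple<>, // no extra D tensors
    NDHWGK,      // output layout
    F32, F32, ck::Tuple<>, F32,
    PassThrough, PassThrough, Clamp,
    TF32, TF32>; // AComputeType / BComputeType select the *_f32_tf32_* instance lists

const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();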
@@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -508,6 +524,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -522,6 +554,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -536,6 +584,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -550,6 +614,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_ins
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
@@ -564,6 +644,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_ins
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -622,6 +718,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -636,6 +748,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -650,6 +778,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_insta
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -664,6 +808,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -678,6 +838,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -111,6 +111,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -281,6 +296,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW
|
||||
|
||||
@@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -169,6 +184,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW
|
||||
|
||||
@@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -169,6 +184,21 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW
|
||||
|
||||
@@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough,
Scale>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F32,
F32,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
Scale,
TF32,
TF32>>>& instances);
#endif

#ifdef CK_ENABLE_INT8
@@ -137,7 +153,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16

@@ -211,6 +211,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
|
||||
@@ -55,6 +55,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -120,6 +136,22 @@ void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_ins
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -84,6 +84,22 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
@@ -176,6 +192,22 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW
|
||||
|
||||
@@ -9,6 +9,7 @@ set(GROUPED_CONV2D_FWD
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp
@@ -28,12 +29,14 @@ set(GROUPED_CONV2D_FWD
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# merged groups
# NHWGC, GKYXC, NHWGK
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp
@@ -44,9 +47,11 @@ set(GROUPED_CONV2D_FWD
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp
@@ -61,6 +66,7 @@ set(GROUPED_CONV2D_FWD
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp

@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
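Each of these registration functions hands a tuple of concrete device-operation types to add_device_operation_instances. A simplified sketch of what such a helper boils down to (my own illustration using std::tuple, not the library's implementation, which works on CK's own tuple type):

#include <memory>
#include <tuple>
#include <vector>

// Illustration: default-construct every type in the instance tuple and append it,
// as a base-class pointer, to the caller's vector of instances.
template <typename BaseOp, typename... Instances>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& out,
                                           const std::tuple<Instances...>&)
{
    (out.push_back(std::make_unique<Instances>()), ...);
}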
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
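The mem_inter file above and the mem_intra file below register the same device_grouped_conv_fwd_xdl_f32_tf32_mem_instances list and differ only in the loop-scheduler tag passed as the last template argument. Illustrative aliases (the alias names are mine; the instantiations are the ones used in these files):

// Interwave vs. Intrawave scheduling of the same TF32 mem instance list.
using Tf32MemInterDefault = device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
    2, NHWGC, GKYXC, Empty_Tuple, NHWGK, ConvFwdDefault, Interwave>;
using Tf32MemIntraDefault = device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
    2, NHWGC, GKYXC, Empty_Tuple, NHWGK, ConvFwdDefault, Intrawave>;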
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd3x3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -85,6 +85,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 2
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
# merged groups
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
@@ -114,6 +124,15 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
#mem
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
@@ -143,6 +162,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
@@ -171,6 +200,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
#comp
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
@@ -200,7 +239,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
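Because the bias_bn_clamp TF32 lists are large, they are compiled in shards: the generate_sharded_instantiations calls in the CMake hunks above choose NUM_SHARDS, and each generated translation unit instantiates the *_shard<Shards, ShardIndex> template defined in the file above for one index. A hedged sketch of what one generated shard might look like (the wrapper name is hypothetical; NUM_SHARDS for the comp list is 4 in the CMake above):

// Hypothetical generated TU: registers shard 3 of 4 of the TF32 comp instances.
namespace ck::tensor_operation::device::instance {

void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances_shard3(
    device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances& instances)
{
    add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances_shard<4, 3>(
        instances);
}

} // namespace ck::tensor_operation::device::instance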
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
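The shard functions select their subset of the instance tuple at the type level with ck::util::filter_tuple_by_modulo_t. A standalone illustration of the idea (assumed semantics: keep the elements whose index falls into the requested residue class; this is not the CK implementation):

#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Illustration: keep the elements of a std::tuple whose index % Shards == ShardIndex.
template <typename Tuple, std::size_t Shards, std::size_t ShardIndex, typename Seq>
struct filter_by_modulo_impl;

template <typename Tuple, std::size_t Shards, std::size_t ShardIndex, std::size_t... Is>
struct filter_by_modulo_impl<Tuple, Shards, ShardIndex, std::index_sequence<Is...>>
{
    using type = decltype(std::tuple_cat(
        std::conditional_t<Is % Shards == ShardIndex,
                           std::tuple<std::tuple_element_t<Is, Tuple>>,
                           std::tuple<>>{}...));
};

template <typename Tuple, std::size_t Shards, std::size_t ShardIndex>
using filter_by_modulo_t = typename filter_by_modulo_impl<
    Tuple, Shards, ShardIndex, std::make_index_sequence<std::tuple_size_v<Tuple>>>::type;

// Example: every other element starting at index 0.
static_assert(std::is_same_v<filter_by_modulo_t<std::tuple<int, char, float, double>, 2, 0>,
                             std::tuple<int, float>>);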
@@ -21,10 +21,16 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp16_comp_part2_instance.cpp

xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp
)

@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
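Each registration function in this change instantiates the generic ConvFwdDefault path plus the specializations that its name encodes: 1x1 filter with zero padding, 1x1 filter with stride 1 and zero padding, and, for the merged-groups files, a 3x3 filter. Purely to spell out what the 1x1 naming implies (this predicate is illustrative and not a CK function):

// Illustrative only: a runtime check mirroring the ConvFwd1x1S1P0 naming convention.
inline bool is_1x1_stride1_pad0(int filter_y, int filter_x,
                                int stride_y, int stride_x,
                                int pad_y, int pad_x)
{
    return filter_y == 1 && filter_x == 1 &&
           stride_y == 1 && stride_x == 1 &&
           pad_y == 0 && pad_x == 0;
}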
@@ -21,10 +21,16 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance
    xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp16_comp_part2_instance.cpp

    xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
    xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
    xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp
    xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
    xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
    xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
    xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
    xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp
    xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp
    xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp
)

@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -20,10 +20,12 @@ set(GROUPED_CONV3D_FWD
    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp

    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp
@@ -31,13 +33,16 @@ set(GROUPED_CONV3D_FWD
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp

    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp

    xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp

@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd3x3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -94,6 +94,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 2
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
# merged groups
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
@@ -123,6 +133,15 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
#mem
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
@@ -154,6 +173,15 @@ generate_sharded_instantiations(
|
||||
)
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
|
||||
@@ -180,6 +208,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
#comp
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
@@ -210,6 +248,15 @@ generate_sharded_instantiations(
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
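The bias+batchnorm+clamp shard templates like the one above are compiled through the generate_sharded_instantiations CMake helper shown earlier in this diff: each generated translation unit instantiates the _shard<Shards, ShardIndex> function for one shard index, and ck::util::filter_tuple_by_modulo_t narrows the long instance tuple to that shard's slice. The following is only a conceptual sketch of such a modulo filter over std::tuple, under the assumption that CK's trait behaves analogously on its own tuple type; it is not CK's implementation:

#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Conceptual sketch: keep the tuple elements whose index satisfies i % Shards == ShardIndex,
// so each generated shard instantiates a disjoint subset of a long instance list.
template <typename Tuple, std::size_t Shards, std::size_t ShardIndex, typename Seq>
struct filter_by_modulo_impl;

template <typename Tuple, std::size_t Shards, std::size_t ShardIndex, std::size_t... Is>
struct filter_by_modulo_impl<Tuple, Shards, ShardIndex, std::index_sequence<Is...>>
{
    using type = decltype(std::tuple_cat(
        std::declval<std::conditional_t<Is % Shards == ShardIndex,
                                        std::tuple<std::tuple_element_t<Is, Tuple>>,
                                        std::tuple<>>>()...));
};

template <typename Tuple, std::size_t Shards, std::size_t ShardIndex>
using filter_by_modulo_t = typename filter_by_modulo_impl<
    Tuple,
    Shards,
    ShardIndex,
    std::make_index_sequence<std::tuple_size_v<Tuple>>>::type;

// Example: splitting four instance types across two shards.
static_assert(std::is_same_v<filter_by_modulo_t<std::tuple<char, short, int, long>, 2, 0>,
                             std::tuple<char, int>>);
static_assert(std::is_same_v<filter_by_modulo_t<std::tuple<char, short, int, long>, 2, 1>,
                             std::tuple<short, long>>);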
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -19,10 +19,15 @@ set(GROUPED_CONV3D_FWD
    xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
    xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp

    xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
)

@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_BILINEAR
|
||||
xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bilinear_instance ${GROUPED_CONV3D_FWD_BILINEAR})
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
ck::Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
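For readers scanning which epilogue each instance list targets: PassThrough forwards the GEMM result unchanged, Clamp only clamps it, AddClamp adds the bias tensor before clamping, Bilinear blends the conv output with the extra D tensor, and, judging by its name, BiasNormalizeInInferClamp additionally folds batch-norm inference parameters in before the clamp. The snippet below is a simplified scalar restatement inferred from the operator names; it is not the CK element-wise op code, and the alpha/beta scaling shown for Bilinear is an assumption:

#include <algorithm>

// Simplified scalar sketches of the epilogues named in these instance lists (illustrative only).
inline float clamp_op(float conv, float lo, float hi) { return std::min(std::max(conv, lo), hi); }

inline float add_clamp_op(float conv, float bias, float lo, float hi)
{
    return clamp_op(conv + bias, lo, hi); // bias comes from the single D tensor
}

inline float bilinear_op(float conv, float residual, float alpha, float beta)
{
    return alpha * conv + beta * residual; // alpha/beta scaling is an assumption
}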
@@ -19,10 +19,15 @@ set(GROUPED_CONV3D_FWD
    xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
    xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
    xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
    xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
    xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
    xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp

    xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
)

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/host_utility/device_prop.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
        NDHWGC,
        GKZYXC,
        Tuple<>,
        NDHWGK,
        F32,
        F32,
        Tuple<>,
        F32,
        PassThrough,
        PassThrough,
        Clamp,
        TF32,
        TF32>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwdDefault,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(instances,
        device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1P0,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1S1P0,
            Tuple<>,
            Clamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

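For context, a minimal consumer-side sketch (not part of this patch) of how the registration function added above is typically exercised: the caller owns a vector of the abstract DeviceGroupedConvFwdMultipleABD interface, the function appends one object per concrete TF32 kernel, and the caller can then inspect or launch them. The include paths and the forward declaration are assumptions made to keep the sketch self-contained; the template arguments simply mirror the signature in the file above (F32 = float, TF32 = ck::tf32_t).

// Illustrative sketch only; header paths below are assumptions, not verified parts of this patch.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" // assumed interface header
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"                        // assumed layout header
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"              // assumed element-wise header

namespace layout = ck::tensor_layout::convolution;
namespace elem   = ck::tensor_operation::element_wise;

// Same template arguments as the function signature above.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    3,
    layout::NDHWGC,
    layout::GKZYXC,
    ck::Tuple<>,
    layout::NDHWGK,
    float,
    float,
    ck::Tuple<>,
    float,
    elem::PassThrough,
    elem::PassThrough,
    elem::Clamp,
    ck::tf32_t,
    ck::tf32_t>;

namespace ck::tensor_operation::device::instance {
// Forward declaration of the function defined in the new file above (normally provided by a
// CK library header; declared here only so the sketch is self-contained when linking against
// the instance library).
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
    std::vector<std::unique_ptr<DeviceOp>>& instances);
} // namespace ck::tensor_operation::device::instance

int main()
{
    std::vector<std::unique_ptr<DeviceOp>> instances;
    ck::tensor_operation::device::instance::
        add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(instances);

    // Each registered kernel reports an identifying string; a real client would also build
    // Argument/Invoker objects and check IsSupportedArgument() before launching.
    for(const auto& p : instances)
        std::cout << p->GetTypeString() << '\n';
}
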
@@ -0,0 +1,43 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
        NDHWGC,
        GKZYXC,
        Tuple<>,
        NDHWGK,
        F32,
        F32,
        Tuple<>,
        F32,
        PassThrough,
        PassThrough,
        Clamp,
        TF32,
        TF32>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwdDefault,
            Tuple<>,
            Clamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
        NDHWGC,
        GKZYXC,
        Tuple<>,
        NDHWGK,
        F32,
        F32,
        Tuple<>,
        F32,
        PassThrough,
        PassThrough,
        Clamp,
        TF32,
        TF32>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwdDefault,
            Interwave,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(instances,
        device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1P0,
            Interwave,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1S1P0,
            Interwave,
            Tuple<>,
            Clamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
        NDHWGC,
        GKZYXC,
        Tuple<>,
        NDHWGK,
        F32,
        F32,
        Tuple<>,
        F32,
        PassThrough,
        PassThrough,
        Clamp,
        TF32,
        TF32>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwdDefault,
            Intrawave,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(instances,
        device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1P0,
            Intrawave,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1S1P0,
            Intrawave,
            Tuple<>,
            Clamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
        NDHWGC,
        GKZYXC,
        Tuple<>,
        NDHWGK,
        F32,
        F32,
        Tuple<>,
        F32,
        PassThrough,
        PassThrough,
        Clamp,
        TF32,
        TF32>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwdDefault,
            Tuple<>,
            Clamp>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd3x3,
            Tuple<>,
            Clamp>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_BILINEAR
    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
    xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)

add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR})

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
        NDHWGC,
        GKZYXC,
        ck::Tuple<>,
        NDHWGK,
        F32,
        F32,
        ck::Tuple<>,
        F32,
        PassThrough,
        PassThrough,
        Scale,
        TF32,
        TF32>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwdDefault>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1P0>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3,
            NDHWGC,
            GKZYXC,
            Tuple<>,
            NDHWGK,
            ConvFwd1x1S1P0>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

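Every new file in this patch funnels its kernel list through add_device_operation_instances. As a rough conceptual sketch of that registration pattern (CK's actual helper lives in ck/library/tensor_operation_instance/add_device_operation_instance.hpp and iterates a ck::Tuple with CK's own metaprogramming utilities; the std::tuple version below is only an approximation): default-construct one object per concrete kernel type and hand ownership to the caller's vector of the abstract interface.

// Simplified approximation of the registration helper used throughout this patch;
// not CK's actual implementation.
#include <memory>
#include <tuple>
#include <vector>

template <typename BaseOp, typename... KernelTypes>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances,
                                           const std::tuple<KernelTypes...>&)
{
    // One concrete kernel object per type in the tuple, appended to the caller's list.
    (instances.push_back(std::make_unique<KernelTypes>()), ...);
}
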