[WIP] initial implementation

This commit is contained in:
Graner, Johannes
2026-01-30 01:58:18 -05:00
parent 9b168082b7
commit d5bbd4c3f1
9 changed files with 560 additions and 239 deletions

View File

@@ -287,7 +287,8 @@ template <index_t NDimSpatial,
typename AComputeType = ADataType,
typename BComputeType = AComputeType,
index_t MaxTransposeTransferInScalarPerVector = 1,
index_t MaxTransposeTransferOutScalarPerVector = 1>
index_t MaxTransposeTransferOutScalarPerVector = 1,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
ALayout, // output image
@@ -387,7 +388,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
true, /*SplitConvN*/
ABDataType,
EDataType,
1,
NumGroupsToMerge,
index_t,
CTranspose>;
@@ -964,10 +965,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum));
gemms_grid_size_.push_back(grid_size);
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
// A/B/Ds/E Batch Stride (multiply by NumGroupsToMerge for group merging)
compute_ptr_offset_of_batch_.BatchStrideA_ =
a_g_n_k_wos_strides_transposed[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideB_ =
b_g_k_c_xs_strides_transposed[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideE_ =
e_g_n_c_wis_strides_transposed[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
@@ -1147,7 +1151,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
{
float ave_time = 0;
const index_t gdy = arg.num_group_;
const index_t gdy = arg.num_group_ / NumGroupsToMerge;
const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_;
const ADataType* p_a_grid = arg.p_a_grid_;
@@ -1542,6 +1546,21 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
// Validation: Check that NumGroupsToMerge divides ConvG evenly
if constexpr(NumGroupsToMerge > 1)
{
if(ConvG % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! Conv_G % NumGroupsToMerge != 0: Conv_G=" << ConvG
<< ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl;
}
return false;
}
}
// Specifialization
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
@@ -1908,7 +1927,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< TransposeTransferInScalarPerVectorAligned <<", "
<< "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
}
if constexpr(NumGroupsToMerge > 1)
{
str << ", NumGroupsToMerge: " << NumGroupsToMerge;
}
str << ">";

View File

@@ -508,9 +508,19 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(WoStride_, KStrideTensorA_));
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(WoStride_, KStrideTensorA_));
}
else
{
// Add NumGroupsToMerge dimension for group merging
const index_t BatchStride = NStrideTensorA_;
return make_naive_tensor_descriptor(
make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_),
make_tuple(WoStride_, BatchStride, KStrideTensorA_));
}
}
else
{
@@ -607,14 +617,105 @@ struct TransformConvBwdDataToGemm_v1
__host__ __device__ auto MakeWeiGridDesc() const
{
// Power-of-2 constraint for NumGroupsToMerge (required for XOR transform)
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
NumGroupsToMerge == 64);
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
if constexpr(NumGroupsToMerge == 1)
{
// Original V1 logic - no group merging
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
}
else
{
// Group merging logic - XOR + merge pattern from bwd_weight V2
// Add NumGroupsToMerge for M dimension and 1 as placeholder for N dimension
constexpr auto NumGroupsToMergeNumber = Number<NumGroupsToMerge>{};
const auto desc = make_naive_tensor_descriptor_packed(
make_tuple(NumGroupsToMergeNumber, K_, Y_ * X_, I1, C_));
// Pad placeholder dimension from 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMergeNumber),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pad_transform(I1, I0, NumGroupsToMergeNumber - I1),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// XOR transform to select diagonal (where group indices match)
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(
make_tuple(NumGroupsToMergeNumber, NumGroupsToMergeNumber)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
// Merge to create M and N dimensions for GEMM
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(
make_merge_transform(make_tuple(K_, NumGroupsToMergeNumber)), // M dimension
make_merge_transform(
make_tuple(Y_ * X_, NumGroupsToMergeNumber, C_))), // N dimension
make_tuple(Sequence<1, 0>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
if constexpr(NumGroupsToMerge == 1)
{
// Original V1 logic - no group merging
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
}
else
{
// Group merging logic for 3D - XOR + merge pattern from bwd_weight V2
constexpr auto NumGroupsToMergeNumber = Number<NumGroupsToMerge>{};
const auto desc = make_naive_tensor_descriptor_packed(
make_tuple(NumGroupsToMergeNumber, K_, Z_ * Y_ * X_, I1, C_));
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMergeNumber),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pad_transform(I1, I0, NumGroupsToMergeNumber - I1),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(
make_tuple(NumGroupsToMergeNumber, NumGroupsToMergeNumber)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(
make_merge_transform(make_tuple(K_, NumGroupsToMergeNumber)),
make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMergeNumber, C_))),
make_tuple(Sequence<1, 0>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
@@ -670,20 +771,49 @@ struct TransformConvBwdDataToGemm_v1
math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock;
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
if constexpr(NumGroupsToMerge == 1)
{
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
else
{
// Group merging: out_grid_desc has (N * Ho * Wo, NumGroupsToMerge, K)
// Merge NumGroupsToMerge with K to form M dimension
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
make_pass_through_transform(NumGroupsToMerge * K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
}
else
{
@@ -947,26 +1077,54 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t K0PerBlock = GemmKPerBlock / BK1;
const index_t BK0 =
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
if constexpr(NumGroupsToMerge == 1)
{
const index_t K0PerBlock = GemmKPerBlock / BK1;
const index_t BK0 =
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
else
{
// Group merging: wei_grid_desc already has merged M, N from MakeWeiGridDesc()
// It returns (K * NumGroupsToMerge, Y * X * NumGroupsToMerge * C)
const auto wei_grid_desc = MakeWeiGridDesc();
const index_t K0PerBlock = GemmKPerBlock / BK1;
const index_t BK0 =
math::integer_divide_ceil(K_ * NumGroupsToMerge, BK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
make_pass_through_transform(Y_ * X_ * NumGroupsToMerge * C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
}
else
{
@@ -1161,32 +1319,86 @@ struct TransformConvBwdDataToGemm_v1
// C: input tensor
if constexpr(NDimSpatial == 2)
{
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
return in_gemmm_gemmn_grid_desc;
}
else
{
// Group merging: Add NumGroupsToMerge dimension and merge with C
// in_grid_desc is (N, Hi, Wi, C) from MakeInGridDesc()
const index_t BatchStride = NStrideTensorC_;
// Create descriptor with NumGroupsToMerge dimension
const auto in_n_hi_wi_numgroups_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorC_, HiStride_, WiStride_, BatchStride, CStrideTensorC_));
const auto in_n_y_ho_x_wo_numgroups_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_numgroups_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5>{},
Sequence<6>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_numgroups_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_merge_transform(make_tuple(NumGroupsToMerge, C_))),
make_tuple(
Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5, 6>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
}
else if constexpr(NDimSpatial == 3)
{

View File

@@ -68,12 +68,37 @@ using device_grouped_conv_bwd_data_xdl_f16_16_16_instances =
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>
// clang-format on
>;
// Dedicated tuple for NumGroupsToMerge > 1 testing (Filter1x1Stride1Pad0 only)
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f16_16_16_group_merge_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// NumGroupsToMerge = 2 test instances
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 2>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 2>,
// NumGroupsToMerge = 4 test instances
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 4>
// clang-format on
>;

View File

@@ -80,6 +80,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_group_merge_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32

View File

@@ -23,6 +23,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_group_merge_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(

View File

@@ -6,6 +6,7 @@ add_instance_library(
device_grouped_conv2d_bwd_data_instance
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_group_merge_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Test instances for NumGroupsToMerge > 1 (Filter1x1Stride1Pad0 only)
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_group_merge_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_bwd_data_xdl_f16_16_16_group_merge_instances<
2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -78,6 +78,9 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_element_space_size);
DeviceMem in_device_buf(sizeof(InDataType) * in_element_space_size);
if(do_verification > 0)
do_verification = 2;
// Initialize tensors based on do_verification:
// - do_verification=2: GPU-side initialization
// - do_verification=0,1: CPU-side initialization
@@ -410,7 +413,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
// std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8};
if(split_k > 0)
{

View File

@@ -14,118 +14,118 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
set(PROFILER_OPS
profile_gemm.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_bwd_gamma_beta.cpp
profile_groupnorm_bwd_gamma_beta.cpp
profile_layernorm_fwd.cpp
profile_max_pool2d_fwd.cpp
profile_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
profile_avg_pool2d_bwd.cpp
profile_max_pool2d_bwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
profile_gemm_quantization.cpp
# profile_gemm.cpp
# profile_reduce.cpp
# profile_groupnorm_bwd_data.cpp
# profile_groupnorm_fwd.cpp
# profile_layernorm_bwd_data.cpp
# profile_layernorm_bwd_gamma_beta.cpp
# profile_groupnorm_bwd_gamma_beta.cpp
# profile_layernorm_fwd.cpp
# profile_max_pool2d_fwd.cpp
# profile_pool3d_fwd.cpp
# profile_avg_pool3d_bwd.cpp
# profile_max_pool3d_bwd.cpp
# profile_avg_pool2d_bwd.cpp
# profile_max_pool2d_bwd.cpp
# profile_softmax.cpp
# profile_batchnorm_fwd.cpp
# profile_batchnorm_bwd.cpp
# profile_batchnorm_infer.cpp
# profile_conv_tensor_rearrange.cpp
# profile_transpose.cpp
# profile_permute_scale.cpp
# profile_gemm_quantization.cpp
)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
endif()
if(CK_EXPERIMENTAL_BUILDER)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
endif()
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
# endif()
# if(CK_EXPERIMENTAL_BUILDER)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
# endif()
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
# list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
# list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
# list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
# list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
# list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
# list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]"))
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
# list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
endif()
if(DL_KERNELS)
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
if(CK_ENABLE_INT8)
list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
# list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
endif()
set(PROFILER_SOURCES profiler.cpp)
@@ -152,131 +152,131 @@ endif()
set(DEVICE_INSTANCES "")
list(APPEND DEVICE_INSTANCES device_gemm_instance)
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
list(APPEND DEVICE_INSTANCES device_softmax_instance)
list(APPEND DEVICE_INSTANCES device_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
list(APPEND DEVICE_INSTANCES device_transpose_instance)
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
# list(APPEND DEVICE_INSTANCES device_softmax_instance)
# list(APPEND DEVICE_INSTANCES device_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
# list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
# list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
# list(APPEND DEVICE_INSTANCES device_transpose_instance)
# list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
# list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
endif()
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
endif()
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" ))
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
endif()
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(CK_EXPERIMENTAL_BUILDER)
list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
# list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
endif()
endif()
if(DL_KERNELS)
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(CK_ENABLE_INT8)
list(APPEND DEVICE_INSTANCES device_quantization_instance)
# list(APPEND DEVICE_INSTANCES device_quantization_instance)
endif()
set(PROFILER_LIBS utility getopt::getopt)