mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[WIP] initial implementation
This commit is contained in:
@@ -287,7 +287,8 @@ template <index_t NDimSpatial,
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType,
|
||||
index_t MaxTransposeTransferInScalarPerVector = 1,
|
||||
index_t MaxTransposeTransferOutScalarPerVector = 1>
|
||||
index_t MaxTransposeTransferOutScalarPerVector = 1,
|
||||
index_t NumGroupsToMerge = 1>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
@@ -387,7 +388,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType,
|
||||
1,
|
||||
NumGroupsToMerge,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
|
||||
@@ -964,10 +965,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum));
|
||||
gemms_grid_size_.push_back(grid_size);
|
||||
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
|
||||
// A/B/Ds/E Batch Stride (multiply by NumGroupsToMerge for group merging)
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ =
|
||||
a_g_n_k_wos_strides_transposed[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ =
|
||||
b_g_k_c_xs_strides_transposed[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ =
|
||||
e_g_n_c_wis_strides_transposed[0] * NumGroupsToMerge;
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
|
||||
@@ -1147,7 +1151,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
float ave_time = 0;
|
||||
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t gdy = arg.num_group_ / NumGroupsToMerge;
|
||||
const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_;
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
@@ -1542,6 +1546,21 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
// Validation: Check that NumGroupsToMerge divides ConvG evenly
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
if(ConvG % NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported! Conv_G % NumGroupsToMerge != 0: Conv_G=" << ConvG
|
||||
<< ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Specialization
|
||||
if constexpr(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -1908,7 +1927,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
<< TransposeTransferInScalarPerVectorAligned <<", "
|
||||
<< "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
|
||||
}
|
||||
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
str << ", NumGroupsToMerge: " << NumGroupsToMerge;
|
||||
}
|
||||
|
||||
str << ">";
|
||||
|
||||
|
||||
@@ -508,9 +508,19 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(WoStride_, KStrideTensorA_));
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(WoStride_, KStrideTensorA_));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Add NumGroupsToMerge dimension for group merging
|
||||
const index_t BatchStride = NStrideTensorA_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_),
|
||||
make_tuple(WoStride_, BatchStride, KStrideTensorA_));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -607,14 +617,105 @@ struct TransformConvBwdDataToGemm_v1
|
||||
|
||||
__host__ __device__ auto MakeWeiGridDesc() const
|
||||
{
|
||||
// Power-of-2 constraint for NumGroupsToMerge (required for XOR transform)
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
|
||||
NumGroupsToMerge == 64);
|
||||
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
// Original V1 logic - no group merging
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group merging logic - XOR + merge pattern from bwd_weight V2
|
||||
// Add NumGroupsToMerge for M dimension and 1 as placeholder for N dimension
|
||||
constexpr auto NumGroupsToMergeNumber = Number<NumGroupsToMerge>{};
|
||||
const auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(NumGroupsToMergeNumber, K_, Y_ * X_, I1, C_));
|
||||
|
||||
// Pad placeholder dimension from 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMergeNumber),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pad_transform(I1, I0, NumGroupsToMergeNumber - I1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
// XOR transform to select diagonal (where group indices match)
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(NumGroupsToMergeNumber, NumGroupsToMergeNumber)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
|
||||
|
||||
// Merge to create M and N dimensions for GEMM
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMergeNumber)), // M dimension
|
||||
make_merge_transform(
|
||||
make_tuple(Y_ * X_, NumGroupsToMergeNumber, C_))), // N dimension
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
// Original V1 logic - no group merging
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group merging logic for 3D - XOR + merge pattern from bwd_weight V2
|
||||
constexpr auto NumGroupsToMergeNumber = Number<NumGroupsToMerge>{};
|
||||
const auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(NumGroupsToMergeNumber, K_, Z_ * Y_ * X_, I1, C_));
|
||||
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMergeNumber),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pad_transform(I1, I0, NumGroupsToMergeNumber - I1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(NumGroupsToMergeNumber, NumGroupsToMergeNumber)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMergeNumber)),
|
||||
make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMergeNumber, C_))),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -670,20 +771,49 @@ struct TransformConvBwdDataToGemm_v1
|
||||
math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group merging: out_grid_desc has (N * Ho * Wo, NumGroupsToMerge, K)
|
||||
// Merge NumGroupsToMerge with K to form M dimension
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
make_pass_through_transform(NumGroupsToMerge * K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -947,26 +1077,54 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group merging: wei_grid_desc already has merged M, N from MakeWeiGridDesc()
|
||||
// It returns (K * NumGroupsToMerge, Y * X * NumGroupsToMerge * C)
|
||||
const auto wei_grid_desc = MakeWeiGridDesc();
|
||||
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(K_ * NumGroupsToMerge, BK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(Y_ * X_ * NumGroupsToMerge * C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1161,32 +1319,86 @@ struct TransformConvBwdDataToGemm_v1
|
||||
// C: input tensor
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
const auto in_gemmm_gemmn_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group merging: Add NumGroupsToMerge dimension and merge with C
|
||||
// in_grid_desc is (N, Hi, Wi, C) from MakeInGridDesc()
|
||||
const index_t BatchStride = NStrideTensorC_;
|
||||
|
||||
// Create descriptor with NumGroupsToMerge dimension
|
||||
const auto in_n_hi_wi_numgroups_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorC_, HiStride_, WiStride_, BatchStride, CStrideTensorC_));
|
||||
|
||||
const auto in_n_y_ho_x_wo_numgroups_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_numgroups_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_numgroups_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, C_))),
|
||||
make_tuple(
|
||||
Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5, 6>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
|
||||
@@ -68,12 +68,37 @@ using device_grouped_conv_bwd_data_xdl_f16_16_16_instances =
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
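// Dedicated instance tuple for NumGroupsToMerge > 1 testing (Filter1x1Stride1Pad0 only).
// In each instance below, the arguments following LoopScheduler::Default are, in order:
// AComputeType, BComputeType, MaxTransposeTransferInScalarPerVector,
// MaxTransposeTransferOutScalarPerVector and NumGroupsToMerge (the last argument).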
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f16_16_16_group_merge_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// NumGroupsToMerge = 2 test instances
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 2>,
|
||||
// NumGroupsToMerge = 4 test instances
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, F16, F16, 1, 1, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, F16, F16, 1, 1, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -80,6 +80,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_group_merge_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
|
||||
@@ -23,6 +23,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// Declaration of the registration function for the f16 group-merge
// (NumGroupsToMerge > 1) instances of the 2D grouped convolution backward-data
// XDL kernel, using the GNHWK / GKYXC / GNHWC layouts and no D tensors
// (Empty_Tuple).
// `instances` is the caller-provided vector of device-operation pointers that
// the function receives by non-const reference.
// NOTE(review): the definition presumably lives in the dedicated
// *_f16_group_merge_instance.cpp translation unit - confirm against the build list.
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_group_merge_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
|
||||
@@ -6,6 +6,7 @@ add_instance_library(
|
||||
device_grouped_conv2d_bwd_data_instance
|
||||
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_group_merge_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Test instances for NumGroupsToMerge > 1 (Filter1x1Stride1Pad0 only)
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_group_merge_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_group_merge_instances<
|
||||
2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -78,6 +78,9 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_element_space_size);
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_element_space_size);
|
||||
|
||||
if(do_verification > 0)
|
||||
do_verification = 2;
|
||||
|
||||
// Initialize tensors based on do_verification:
|
||||
// - do_verification=2: GPU-side initialization
|
||||
// - do_verification=0,1: CPU-side initialization
|
||||
@@ -410,7 +413,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
// std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8};
|
||||
|
||||
if(split_k > 0)
|
||||
{
|
||||
|
||||
@@ -14,118 +14,118 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
|
||||
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
|
||||
|
||||
set(PROFILER_OPS
|
||||
profile_gemm.cpp
|
||||
profile_reduce.cpp
|
||||
profile_groupnorm_bwd_data.cpp
|
||||
profile_groupnorm_fwd.cpp
|
||||
profile_layernorm_bwd_data.cpp
|
||||
profile_layernorm_bwd_gamma_beta.cpp
|
||||
profile_groupnorm_bwd_gamma_beta.cpp
|
||||
profile_layernorm_fwd.cpp
|
||||
profile_max_pool2d_fwd.cpp
|
||||
profile_pool3d_fwd.cpp
|
||||
profile_avg_pool3d_bwd.cpp
|
||||
profile_max_pool3d_bwd.cpp
|
||||
profile_avg_pool2d_bwd.cpp
|
||||
profile_max_pool2d_bwd.cpp
|
||||
profile_softmax.cpp
|
||||
profile_batchnorm_fwd.cpp
|
||||
profile_batchnorm_bwd.cpp
|
||||
profile_batchnorm_infer.cpp
|
||||
profile_conv_tensor_rearrange.cpp
|
||||
profile_transpose.cpp
|
||||
profile_permute_scale.cpp
|
||||
profile_gemm_quantization.cpp
|
||||
# profile_gemm.cpp
|
||||
# profile_reduce.cpp
|
||||
# profile_groupnorm_bwd_data.cpp
|
||||
# profile_groupnorm_fwd.cpp
|
||||
# profile_layernorm_bwd_data.cpp
|
||||
# profile_layernorm_bwd_gamma_beta.cpp
|
||||
# profile_groupnorm_bwd_gamma_beta.cpp
|
||||
# profile_layernorm_fwd.cpp
|
||||
# profile_max_pool2d_fwd.cpp
|
||||
# profile_pool3d_fwd.cpp
|
||||
# profile_avg_pool3d_bwd.cpp
|
||||
# profile_max_pool3d_bwd.cpp
|
||||
# profile_avg_pool2d_bwd.cpp
|
||||
# profile_max_pool2d_bwd.cpp
|
||||
# profile_softmax.cpp
|
||||
# profile_batchnorm_fwd.cpp
|
||||
# profile_batchnorm_bwd.cpp
|
||||
# profile_batchnorm_infer.cpp
|
||||
# profile_conv_tensor_rearrange.cpp
|
||||
# profile_transpose.cpp
|
||||
# profile_permute_scale.cpp
|
||||
# profile_gemm_quantization.cpp
|
||||
)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
endif()
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
|
||||
endif()
|
||||
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
# endif()
|
||||
# if(CK_EXPERIMENTAL_BUILDER)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
|
||||
# endif()
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
endif()
|
||||
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
|
||||
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]"))
|
||||
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
if(CK_ENABLE_INT8)
|
||||
list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
|
||||
# list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
|
||||
endif()
|
||||
|
||||
set(PROFILER_SOURCES profiler.cpp)
|
||||
@@ -152,131 +152,131 @@ endif()
|
||||
|
||||
|
||||
set(DEVICE_INSTANCES "")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
|
||||
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" ))
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(CK_ENABLE_INT8)
|
||||
list(APPEND DEVICE_INSTANCES device_quantization_instance)
|
||||
# list(APPEND DEVICE_INSTANCES device_quantization_instance)
|
||||
endif()
|
||||
|
||||
set(PROFILER_LIBS utility getopt::getopt)
|
||||
|
||||
Reference in New Issue
Block a user