mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Disable bwd weight split-k autodeduce for single stage kernels (#2856)
* Disable bwd weight split-k autodeduce for single stage kernels * update interface tests --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -144,18 +144,39 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
|
||||
if(split_k < 0)
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_ = split_k;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_ = split_k;
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
split_k_ = split_k;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
@@ -318,6 +339,16 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if constexpr(!IsTwoStageNeeded)
|
||||
{
|
||||
if(arg.split_k_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -671,6 +671,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
@@ -683,6 +684,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
grid_size);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -939,6 +941,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -553,6 +553,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
@@ -565,6 +566,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
grid_size);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -934,6 +936,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -524,6 +524,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -549,6 +550,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1275,6 +1277,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
|
||||
Reference in New Issue
Block a user