Fix split-K selection in explicit grouped conv bwd weight struct

This commit is contained in:
Enrico Degregori
2025-12-12 09:49:17 +00:00
parent 0c67e9731a
commit 29743bc0f4

View File

@@ -144,18 +144,39 @@ struct DeviceGroupedConvBwdWeight_Explicit
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if(split_k < 0)
if constexpr(IsTwoStageNeeded)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else
{
k_batch_ = split_k;
}
}
else
{
k_batch_ = split_k;
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else
#endif
{
k_batch_ = split_k;
}
}
if constexpr(IsTwoStageNeeded)
@@ -317,6 +338,16 @@ struct DeviceGroupedConvBwdWeight_Explicit
static bool IsSupportedArgument(const Argument& arg)
{
// When split-K auto-deduction is compiled out for one-stage kernels, a
// negative split_k_ cannot be honored: in the argument setup a negative
// split_k is what triggers deriving k_batch_ from occupancy, and that path
// is disabled under this macro. Report such arguments as unsupported.
// NOTE(review): assumes arg.split_k_ holds the caller's original split_k
// value — confirm against the Argument definition.
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if constexpr(!IsTwoStageNeeded)
{
if(arg.split_k_ < 0)
{
return false;
}
}
#endif
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())