Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver.

This commit is contained in:
Ville Pietilä
2025-07-29 12:32:40 +00:00
parent eb1ae702b7
commit 0eb78b53b1
2 changed files with 77 additions and 12 deletions

View File

@@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
}
};
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
max_occupancy_ = std::max(1, max_occupancy);
}
int max_occupancy_;
};
// Invoker
struct Invoker : public BaseInvoker
{
@@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
return str.str();
}
static ck::index_t GetMaxOccupancy()
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
return active_workgroups_per_cu.max_occupancy_;
}
};
} // namespace device

View File

@@ -13,6 +13,8 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
namespace ck {
namespace tensor_operation {
@@ -118,8 +120,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
p_wei_grid_{p_wei_grid},
split_k_{split_k}
p_wei_grid_{p_wei_grid}
{
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
@@ -143,6 +144,19 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if (split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else {
split_k_ = split_k;
}
if constexpr(IsTwoStageNeeded)
{
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
@@ -237,7 +251,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
bool is_filter_data_packed;
CElementwiseGridDesc elementwise_desc_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
const ck::index_t split_k_;
ck::index_t split_k_;
};
// Invoker
@@ -303,15 +317,6 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.split_k_ < 0)
{
// TODO: Add split-K autodeduction.
// This will probably require adding interface to the GEMM operation for
// querying the optimal split-K value, as we cannot easily access the actual GEMM kernel
// from here.
return false;
}
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())