mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver.
This commit is contained in:
@@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
}
|
||||
};
|
||||
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
int max_occupancy = 0;
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
GridwiseGemm,
|
||||
Argument,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
|
||||
GridwiseGemm,
|
||||
Argument,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
@@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
// Returns the max_occupancy_ value of a function-local static ActiveWorkgroupsPerCU,
// so the object (and the query its constructor performs) is created on the first call
// and reused by later calls.
static ck::index_t GetMaxOccupancy()
{
    static ActiveWorkgroupsPerCU cached_occupancy_query;
    const ck::index_t occupancy = cached_occupancy_query.max_occupancy_;
    return occupancy;
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -118,8 +120,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
p_wei_grid_{p_wei_grid},
|
||||
split_k_{split_k}
|
||||
p_wei_grid_{p_wei_grid}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
@@ -143,6 +144,19 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
|
||||
if (split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else {
|
||||
split_k_ = split_k;
|
||||
}
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
|
||||
@@ -237,7 +251,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
bool is_filter_data_packed;
|
||||
CElementwiseGridDesc elementwise_desc_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
const ck::index_t split_k_;
|
||||
ck::index_t split_k_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -303,15 +317,6 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.split_k_ < 0)
|
||||
{
|
||||
// TODO: Add split-K autodeduction.
|
||||
// This will probably require adding interface to the GEMM operation for
|
||||
// querying the optimal split-K value, as we cannot easily access the actual GEMM kernel
|
||||
// from here.
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
Reference in New Issue
Block a user