mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[Conv] Enable bwd weight splitk autodeduction with cap (#3656)
* Enable bwd weight splitk autodeduction with cap
* Fix error threshold calculations
* Add missing logic to wmma multiple d kernel
* Fix threshold calculation
* Update test with new applicability
[ROCm/composable_kernel commit: fabac7e2c3]
This commit is contained in:
@@ -11,8 +11,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -11,8 +11,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -162,7 +162,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
@@ -171,9 +170,11 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -338,16 +339,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if constexpr(!IsTwoStageNeeded)
|
||||
{
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return;
|
||||
}
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
int max_occupancy = 0;
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// TODO: implement
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(
|
||||
@@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
@@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
|
||||
@@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
@@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
@@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
|
||||
@@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
@@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -594,7 +594,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -611,6 +610,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
|
||||
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
@@ -620,7 +622,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1399,13 +1400,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
// check device
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
|
||||
@@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
|
||||
// Calculate number of accumulations accounting for split_k
|
||||
const int num_accums =
|
||||
static_cast<int>(output.GetElementSize() / conv_param.K_ / split_k_value);
|
||||
|
||||
// Additional tolerance for split_k accumulation if needed
|
||||
int total_accums = num_accums;
|
||||
if(split_k_value > 1)
|
||||
{
|
||||
total_accums = std::max(num_accums, static_cast<int>(split_k_value));
|
||||
}
|
||||
|
||||
// Perform GPU verification (max value computed internally on GPU)
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Get maximum accumulated value from reference
|
||||
const std::size_t tensor_size =
|
||||
weight_device_result.mDesc.GetElementSpaceSize();
|
||||
max_accumulated_value =
|
||||
gpu_reduce_max<WeiDataType>(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size);
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
auto atol =
|
||||
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k,
|
||||
num_accums / num_accums_split_k);
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
num_accums_split_k);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
// Perform GPU verification
|
||||
auto gpu_result =
|
||||
ck::profiler::gpu_verify<WeiDataType, ComputeType, AccDataType>(
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_wei_buf.GetDeviceBuffer(),
|
||||
total_accums,
|
||||
tensor_size);
|
||||
ck::profiler::gpu_verify<WeiDataType>(wei_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_wei_buf.GetDeviceBuffer(),
|
||||
rtol,
|
||||
atol,
|
||||
tensor_size);
|
||||
|
||||
if(!gpu_result)
|
||||
{
|
||||
|
||||
@@ -184,5 +184,5 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce)
|
||||
this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
this->split_k_ = -1;
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user