[rocm-libraries] ROCm/rocm-libraries#8518 (commit 1ad69c3)

[CK] Add support for large tensor index handling into conv
 bwd data (#8518)

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
jakpiase
2026-06-17 15:51:36 +00:00
committed by assistant-librarian[bot]
parent b5713be6cd
commit 65bef78383
6 changed files with 190 additions and 129 deletions

View File

@@ -59,130 +59,133 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
{
#if defined(__gfx9__)
// offset base pointer for each work-group
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
left <= right)
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<EGlobalMemoryDataOperation>())
{
if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
// offset base pointer for each work-group
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block);
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
left <= right)
{
if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
#if defined(__gfx950__)
const auto a_grid_desc_ak0_m_ak1_transformed =
GridwiseGemm::template TransformGrid<AGridDesc_AK0_M_AK1,
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_);
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
true,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1_transformed,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
else
{
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
false,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1_transformed,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
const auto a_grid_desc_ak0_m_ak1_transformed =
GridwiseGemm::template TransformGrid<AGridDesc_AK0_M_AK1,
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_);
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
true,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1_transformed,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
else
{
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
false,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1_transformed,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
#endif
}
else
{
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
true,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
false,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
true,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
false,
EGlobalMemoryDataOperation,
TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
k_idx,
gridDim.z,
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
}
}
}
#else
@@ -192,7 +195,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
}
} // namespace
@@ -968,17 +971,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
{
if(arg.stride_overflow)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout
<< "Unsupported! stride_overflow is set but LargeTensors is not enabled!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
if(get_warp_size() != 64)
{
// TODO: Enable for warp size 32
return false;
}
// Reject if the current warp size has no valid XDL configuration
// Warp size 32 is temporary not supported but leave it for the future
if(get_warp_size() == 64)
{
if constexpr(MXdlPerWave64 == 0)
@@ -994,6 +998,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
}
}
if(!ck::is_xdl_wmma_supported<AComputeType,
BComputeType,
Wave32MaxMNPerXdl,
Wave32MaxMNPerXdl>())
{
return false;
}
// check device
if constexpr(DirectLoad)
{
@@ -1009,8 +1021,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
std::cout << "SplitK(" << arg.k_batch_ << ") tests are not supported!" << " In "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
@@ -1021,8 +1034,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
std::cout << "SplitK(" << arg.k_batch_ << ") tests are not supported!" << " In "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
@@ -1638,6 +1652,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
// clang-format off
str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3"
<< (get_warp_size() != 64 ? "_WmmaPorted" : "")
<< (DirectLoad ? "_DirectLoad" : "")
<< (LargeTensors ? "_Large_Tensor" : "")
<< "<"

View File

@@ -1440,7 +1440,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(!LargeTensors)
{
if(arg.stride_overflow)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout
<< "Unsupported! stride_overflow is set but LargeTensors is not enabled!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
// check device
@@ -2023,6 +2032,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// clang-format off
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
if(get_warp_size() != 64) {
str << "_WmmaPorted";
}
if constexpr(DirectLoad) {
str << "_DirectLoad";
}

View File

@@ -1482,7 +1482,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
if constexpr(!LargeTensors)
{
if(arg.stride_overflow)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout
<< "Unsupported! stride_overflow is set but LargeTensors is not enabled!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
namespace ctc = tensor_layout::convolution;

View File

@@ -21,11 +21,19 @@ class TestGroupedConvndBwdData : public ::testing::Test
using InLayout = std::tuple_element_t<3, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
std::vector<ck::index_t> split_ks{1, 2};
std::vector<ck::index_t> split_ks{1};
template <ck::index_t NDimSpatial>
void Run()
{
if constexpr(std::is_same_v<DataType, float>)
{
if(ck::is_wmma_supported())
{
GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32";
}
}
EXPECT_FALSE(conv_params.empty());
bool pass = true;
for(auto split_k : split_ks)

View File

@@ -47,6 +47,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
template <ck::index_t NDimSpatial>
void Run()
{
if constexpr(std::is_same_v<DataType, float>)
{
if(ck::is_wmma_supported())
{
GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32";
}
}
EXPECT_FALSE(conv_params.empty());
bool pass = true;

View File

@@ -24,6 +24,14 @@ class TestGroupedConvndFwd : public ::testing::Test
template <ck::index_t NDimSpatial>
void Run()
{
if constexpr(std::is_same_v<DataType, float>)
{
if(ck::is_wmma_supported())
{
GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32";
}
}
EXPECT_FALSE(conv_params.empty());
bool pass = true;
for(auto& param : conv_params)