mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[rocm-libraries] ROCm/rocm-libraries#8518 (commit 1ad69c3)
[CK] Add support for large tensor index handling into conv bwd data (#8518) ## Motivation <!-- Explain the purpose of this PR and the goals it aims to achieve. --> ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b5713be6cd
commit
65bef78383
@@ -59,130 +59,133 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
// offset base pointer for each work-group
|
||||
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
|
||||
|
||||
index_t left = 0;
|
||||
index_t right = gemms_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
|
||||
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
|
||||
left <= right)
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<EGlobalMemoryDataOperation>())
|
||||
{
|
||||
if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
// offset base pointer for each work-group
|
||||
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block);
|
||||
|
||||
if constexpr(GridwiseGemm::DirectLoadEnabled)
|
||||
{
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
|
||||
|
||||
index_t left = 0;
|
||||
index_t right = gemms_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
|
||||
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::DirectLoadEnabled)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
const auto a_grid_desc_ak0_m_ak1_transformed =
|
||||
GridwiseGemm::template TransformGrid<AGridDesc_AK0_M_AK1,
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_);
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1_transformed,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
false,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1_transformed,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
const auto a_grid_desc_ak0_m_ak1_transformed =
|
||||
GridwiseGemm::template TransformGrid<AGridDesc_AK0_M_AK1,
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_);
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1_transformed,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1_transformed),
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
false,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1_transformed,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
false,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
false,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
k_idx,
|
||||
gridDim.z,
|
||||
blockIdx.x - gemm_kernel_args[group_id].BlockStart_);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -192,7 +195,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = num_k_per_block;
|
||||
|
||||
#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -968,17 +971,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
|
||||
{
|
||||
if(arg.stride_overflow)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "Unsupported! stride_overflow is set but LargeTensors is not enabled!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_warp_size() != 64)
|
||||
{
|
||||
// TODO: Enable for warp size 32
|
||||
return false;
|
||||
}
|
||||
// Reject if the current warp size has no valid XDL configuration
|
||||
// Warp size 32 is temporary not supported but leave it for the future
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(MXdlPerWave64 == 0)
|
||||
@@ -994,6 +998,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
|
||||
}
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<AComputeType,
|
||||
BComputeType,
|
||||
Wave32MaxMNPerXdl,
|
||||
Wave32MaxMNPerXdl>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check device
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
@@ -1009,8 +1021,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
std::cout << "SplitK(" << arg.k_batch_ << ") tests are not supported!" << " In "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
@@ -1021,8 +1034,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
std::cout << "SplitK(" << arg.k_batch_ << ") tests are not supported!" << " In "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return false;
|
||||
@@ -1638,6 +1652,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3"
|
||||
<< (get_warp_size() != 64 ? "_WmmaPorted" : "")
|
||||
<< (DirectLoad ? "_DirectLoad" : "")
|
||||
<< (LargeTensors ? "_Large_Tensor" : "")
|
||||
<< "<"
|
||||
|
||||
@@ -1440,7 +1440,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
if constexpr(!LargeTensors)
|
||||
{
|
||||
if(arg.stride_overflow)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "Unsupported! stride_overflow is set but LargeTensors is not enabled!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check device
|
||||
@@ -2023,6 +2032,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
if(get_warp_size() != 64) {
|
||||
str << "_WmmaPorted";
|
||||
}
|
||||
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
|
||||
@@ -1482,7 +1482,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
if constexpr(!LargeTensors)
|
||||
{
|
||||
if(arg.stride_overflow)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "Unsupported! stride_overflow is set but LargeTensors is not enabled!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
@@ -21,11 +21,19 @@ class TestGroupedConvndBwdData : public ::testing::Test
|
||||
using InLayout = std::tuple_element_t<3, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
std::vector<ck::index_t> split_ks{1};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
if(ck::is_wmma_supported())
|
||||
{
|
||||
GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32";
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto split_k : split_ks)
|
||||
|
||||
@@ -47,6 +47,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
if(ck::is_wmma_supported())
|
||||
{
|
||||
GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32";
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
|
||||
|
||||
@@ -24,6 +24,14 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
if(ck::is_wmma_supported())
|
||||
{
|
||||
GTEST_SKIP() << "Skipping test: WMMA architecture doesn't support FP32";
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
|
||||
Reference in New Issue
Block a user