Optimize grouped conv bwd wei split_k off calc

This commit is contained in:
Bartlomiej Kocot
2025-11-03 15:10:55 +00:00
parent ab1a8356b6
commit 6f61dd56c5
5 changed files with 609 additions and 47 deletions

View File

@@ -55,13 +55,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -77,18 +84,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
#endif // end of if (defined(__gfx9__))
}
@@ -113,14 +127,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -139,8 +160,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
@@ -148,10 +169,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
#endif // end of if (defined(__gfx9__))
}
@@ -779,6 +807,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
e_in_transpose_desc_.GetLength(I1)}
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
ce_grid_desc_m_n_.GetLength(I1)};
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
split_k_offset_a_hack_ =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check that there is no K padding and that N is divisible by k_batch
split_k_offset_b_hack_ =
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
split_k_stride_a_ =
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
}
std::size_t GetWorkspaceATensorSizeBytes() const
@@ -864,6 +909,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_a_hack_, split_k_offset_b_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -966,7 +1014,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_a_hack_,
arg.split_k_offset_b_hack_);
}
else
{
@@ -982,7 +1034,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_a_hack_,
arg.split_k_offset_b_hack_);
}
};
@@ -1886,14 +1942,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}

View File

@@ -57,7 +57,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -84,7 +88,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = p_a_grid;
@@ -99,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = batch_count;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
@@ -634,6 +646,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
split_k_offset_a_hack_ =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check that there is no K padding and that N is divisible by k_batch
split_k_offset_b_hack_ =
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
split_k_stride_a_ =
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
}
std::size_t GetWorkspaceATensorSizeBytes() const
@@ -727,6 +756,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_a_hack_, split_k_offset_b_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -873,7 +905,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_batch_,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_a_hack_,
arg.split_k_offset_b_hack_);
};
if(has_main_k0_block_loop)

View File

@@ -53,13 +53,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -74,15 +81,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;
@@ -91,6 +101,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
#endif // end of if (defined(__gfx9__)
}
@@ -114,14 +129,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -140,8 +162,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
@@ -149,7 +171,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_a_hack,
split_k_offset_b_hack);
}
#else
ignore = karg;
@@ -158,6 +183,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
#endif // end of if (defined(__gfx9__)
}
@@ -594,6 +623,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
c_grid_desc_m_n_,
GridwiseGemm64::CalculateMBlock(GemmM),
GridwiseGemm64::CalculateNBlock(GemmN));
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
// Check that there is no K padding and that N * OutSpatialDims is divisible by k_batch
split_k_offset_a_hack_ =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check that there is no K padding and that N is divisible by k_batch
split_k_offset_b_hack_ =
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
split_k_stride_a_ =
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
}
const ADataType* p_a_grid_;
@@ -626,6 +672,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_a_hack_, split_k_offset_b_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -715,7 +764,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_a_hack_,
arg.split_k_offset_b_hack_);
}
else
{
@@ -731,7 +784,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_a_hack_,
arg.split_k_offset_b_hack_);
}
};

View File

@@ -45,7 +45,7 @@ template <typename ALayout,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraMCustom,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
@@ -53,7 +53,7 @@ template <typename ALayout,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraNCustom,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -673,12 +673,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0)
const index_t k_id = 0,
const index_t k_batch = 1,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -744,7 +750,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -775,7 +781,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1035,12 +1041,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0)
const index_t k_id = 0,
const index_t k_batch = 1,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -1106,7 +1118,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1137,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),

View File

@@ -646,6 +646,415 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
/// @brief Split-K GEMM body: copies A/B tiles from global memory into LDS, runs the blockwise
///        XDL GEMM pipeline, then writes the accumulators to C through an LDS shuffle.
///
/// This overload optionally selects the k-batch slice by offsetting the A/B base pointers
/// instead of indexing the leading dimension of the grid descriptors.
///
/// @param p_a_grid   Base pointer of A in global memory.
/// @param p_b_grid   Base pointer of B in global memory.
/// @param p_c_grid   Base pointer of C in global memory.
/// @param p_shared   LDS scratch; holds the A/B tiles during the main loop and is reused for
///                   the C shuffle buffer afterwards.
/// @param a_b_k0_m_k1_grid_desc  4-D grid descriptor of A (indexed as batch, K0, M, K1 below).
/// @param b_b_k0_n_k1_grid_desc  4-D grid descriptor of B (indexed as batch, K0, N, K1 below).
/// @param c_grid_desc_mblock_mperblock_nblock_nperblock  Grid descriptor of C.
/// @param a_element_op / b_element_op / c_element_op     Element-wise operators applied by the
///                   A/B loads and the C store.
/// @param c_block_cluster_adaptor  Maps the 1-D block id to (k_batch_id, m_block, n_block).
/// @param split_k_stride_a  Offset, in A elements, between consecutive k-batches; only used
///                          when split_k_offset_a_hack is true.
/// @param split_k_stride_b  Offset, in B elements, between consecutive k-batches; only used
///                          when split_k_offset_b_hack is true.
/// @param split_k_offset_a_hack  When true, A is addressed as
///                          p_a_grid + k_batch_id * split_k_stride_a, the A copy starts at
///                          batch index 0, and the A buffer size is divided by the batch length.
/// @param split_k_offset_b_hack  Same as above, for B.
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
                           const FloatB* __restrict__ p_b_grid,
                           FloatC* __restrict__ p_c_grid,
                           void* __restrict__ p_shared,
                           const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                           const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                           const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                           const AElementwiseOperation& a_element_op,
                           const BElementwiseOperation& b_element_op,
                           const CElementwiseOperation& c_element_op,
                           const CBlockClusterAdaptor& c_block_cluster_adaptor,
                           const long_index_t split_k_stride_a,
                           const long_index_t split_k_stride_b,
                           bool split_k_offset_a_hack,
                           bool split_k_offset_b_hack)
{
    const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

    // divide block work by [M, N]
    const auto block_work_idx =
        c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

    const index_t k_batch_id = block_work_idx[I0];

    // With the hack enabled, the k-batch slice is selected by moving the base pointer.
    // The multiplication is carried out in long_index_t because the stride operand is 64-bit.
    const long_index_t split_k_offset_a =
        split_k_offset_a_hack ? k_batch_id * split_k_stride_a : 0;
    const long_index_t split_k_offset_b =
        split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;

    // Each block then only sees its own slice, so the buffer size is the full element space
    // divided by the length of dimension 0 (presumably the k-batch count — confirm against the
    // descriptor construction). Each operand uses its own descriptor.
    const long_index_t a_space_size_divisor =
        split_k_offset_a_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
    const long_index_t b_space_size_divisor =
        split_k_offset_b_hack ? b_b_k0_n_k1_grid_desc.GetLength(I0) : 1;

    const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid + split_k_offset_a,
        a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / a_space_size_divisor);
    const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_grid + split_k_offset_b,
        b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / b_space_size_divisor);
    auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

    // Blocks mapped outside the valid C tile range have nothing to do.
    if(!c_block_cluster_adaptor.ValidCTileIndex(
           make_tuple(block_work_idx[I1], block_work_idx[I2]),
           make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                      c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
    {
        return;
    }

    // HACK: this force m/n_block_data_idx_on_grid into SGPR
    const index_t m_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
    const index_t n_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

    // lds max alignment
    constexpr auto max_lds_align = K1;

    // A matrix in LDS memory, dst of blockwise copy
    constexpr auto a_k0_m_k1_block_desc   = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
    constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();

    // B matrix in LDS memory, dst of blockwise copy
    constexpr auto b_k0_n_k1_block_desc   = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
    constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

    // A matrix blockwise copy
    // When the pointer was already offset to the k-batch slice, start at batch index 0.
    auto a_blockwise_copy =
        ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                            AElementwiseOperation,
                                            ck::tensor_operation::element_wise::PassThrough,
                                            InMemoryDataOperationEnum::Set,
                                            Sequence<1, K0PerBlock, MPerBlock, K1>,
                                            ABlockTransferThreadClusterLengths_K0_M_K1,
                                            ABlockTransferThreadClusterArrangeOrder,
                                            FloatA,
                                            FloatAAdjusted,
                                            decltype(a_b_k0_m_k1_grid_desc),
                                            decltype(a_b_k0_m_k1_block_desc),
                                            ABlockTransferSrcAccessOrder,
                                            Sequence<0, 2, 1, 3>,
                                            ABlockTransferSrcVectorDim,
                                            3,
                                            ABlockTransferSrcScalarPerVector,
                                            ABlockTransferDstScalarPerVector_K1,
                                            1,
                                            1,
                                            AThreadTransferSrcResetCoordinateAfterRun,
                                            true>(
            a_b_k0_m_k1_grid_desc,
            make_multi_index(
                split_k_offset_a_hack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
            a_element_op,
            a_b_k0_m_k1_block_desc,
            make_multi_index(0, 0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});

    // B matrix blockwise copy
    // When the pointer was already offset to the k-batch slice, start at batch index 0.
    auto b_blockwise_copy =
        ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                            BElementwiseOperation,
                                            ck::tensor_operation::element_wise::PassThrough,
                                            InMemoryDataOperationEnum::Set,
                                            Sequence<1, K0PerBlock, NPerBlock, K1>,
                                            BBlockTransferThreadClusterLengths_K0_N_K1,
                                            BBlockTransferThreadClusterArrangeOrder,
                                            FloatB,
                                            FloatBAdjusted,
                                            decltype(b_b_k0_n_k1_grid_desc),
                                            decltype(b_b_k0_n_k1_block_desc),
                                            BBlockTransferSrcAccessOrder,
                                            Sequence<0, 2, 1, 3>,
                                            BBlockTransferSrcVectorDim,
                                            3,
                                            BBlockTransferSrcScalarPerVector,
                                            BBlockTransferDstScalarPerVector_K1,
                                            1,
                                            1,
                                            BThreadTransferSrcResetCoordinateAfterRun,
                                            true>(
            b_b_k0_n_k1_grid_desc,
            make_multi_index(
                split_k_offset_b_hack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
            b_element_op,
            b_b_k0_n_k1_block_desc,
            make_multi_index(0, 0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});

    // GEMM definition
    //   c_mtx += transpose(a_mtx) * b_mtx
    //     a_mtx[K0PerBlock, MPerBlock] is in LDS
    //     b_mtx[K0PerBlock, NPerBlock] is in LDS
    //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
    //       register
    // sanity check
    constexpr bool is_single_rate_mfma =
        (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
          K1 <= 4) ||
         (is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
          K1 < 32));
    constexpr auto is_scale_mfma = false;

    constexpr index_t KPack = math::max(K1,
                                        MfmaSelector<ComputeTypeA,
                                                     MPerXdl,
                                                     NPerXdl,
                                                     ComputeTypeB,
                                                     is_single_rate_mfma,
                                                     is_scale_mfma>::selected_mfma.k_per_blk);

    auto blockwise_gemm =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                            FloatAAdjusted,
                                                            FloatBAdjusted,
                                                            FloatAcc,
                                                            decltype(a_k0_m_k1_block_desc),
                                                            decltype(b_k0_n_k1_block_desc),
                                                            MPerXdl,
                                                            NPerXdl,
                                                            MRepeat,
                                                            NRepeat,
                                                            KPack,
                                                            ComputeTypeA,
                                                            ComputeTypeB>{};

    auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

    // LDS allocation for A and B: be careful of alignment
    constexpr auto a_block_space_size =
        math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

    // Both source windows advance by K0PerBlock along the K0 dimension each iteration.
    constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
    constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
    // B tile is placed right after the (aligned) A tile in the same LDS allocation.
    auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
        b_k0_n_k1_block_desc.GetElementSpaceSize());

    // gridwise GEMM pipeline
    const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

    GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
                                                      a_b_k0_m_k1_block_desc,
                                                      a_blockwise_copy,
                                                      a_grid_buf,
                                                      a_block_buf,
                                                      a_block_slice_copy_step,
                                                      b_b_k0_n_k1_grid_desc,
                                                      b_b_k0_n_k1_block_desc,
                                                      b_blockwise_copy,
                                                      b_grid_buf,
                                                      b_block_buf,
                                                      b_block_slice_copy_step,
                                                      blockwise_gemm,
                                                      c_thread_buf,
                                                      K0BlockMainLoop);

    // output: register to global memory
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
        constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
        constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
        constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
        constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
        constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
        constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
        constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

        constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        // The C shuffle buffer reuses the start of the same LDS allocation as the A/B tiles.
        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatC*>(p_shared),
            c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        static_assert(M1 == MWave, "");
        static_assert(N1 == NWave, "");
        static_assert(M2 * M3 * M4 == MPerXdl, "");
        static_assert(N2 == NPerXdl, "");

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_block_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(
                make_freeze_transform(I0), // freeze mblock
                make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
                                                  M1,
                                                  M2,
                                                  M3,
                                                  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                make_freeze_transform(I0),              // freeze nblock
                make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
                                                  N1,
                                                  N2))), // N1 = NWave, N2 = NPerXdl
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(
                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

        // calculate origin of thread output tensor on global memory
        //     blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // VGPR to LDS
        auto c_thread_copy_vgpr_to_lds =
            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                               FloatC,
                                               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               ck::tensor_operation::element_wise::PassThrough,
                                               Sequence<CShuffleMRepeatPerShuffle,
                                                        CShuffleNRepeatPerShuffle,
                                                        I1,
                                                        I1,
                                                        M2,
                                                        I1,
                                                        M4,
                                                        I1>,
                                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                               7,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                ck::tensor_operation::element_wise::PassThrough{}};

        // LDS to global
        auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
            ThisThreadBlock,            // index_t BlockSize,
            CElementwiseOperation,      // ElementwiseOperation,
            CGlobalMemoryDataOperation, // DstInMemOp,
            Sequence<1,
                     CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            FloatC,               // typename SrcData,
            FloatC,               // typename DstData,
            decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
            3,                                          // index_t VectorDim,
            CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
            true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
            false> // bool ThreadTransferDstResetCoordinateAfterRun
            {c_block_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(0, 0, 0, 0),
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
             c_element_op};

        constexpr auto mxdlperwave_forward_step =
            make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
        constexpr auto nxdlperwave_forward_step =
            make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
        constexpr auto nxdlperwave_backward_step =
            make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);

        static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
            constexpr auto mxdlperwave = mxdlperwave_iter;

            static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                // The N sweep direction alternates with every M step, so the destination
                // window only ever moves by one step between consecutive stores.
                constexpr bool nxdlperwave_forward_sweep =
                    (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                constexpr index_t nxdlperwave_value =
                    nxdlperwave_forward_sweep
                        ? nxdlperwave_iter
                        : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                // make sure it's safe to do ds_write
                block_sync_lds();

                // VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(
                    c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                    make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                    c_thread_buf,
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    c_block_buf);

                // make sure it's safe to do ds_read
                block_sync_lds();

                // LDS to global
                c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
                                               c_block_buf,
                                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                                               c_grid_buf);

                // move on nxdlperwave dimension
                if constexpr(nxdlperwave_forward_sweep &&
                             (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        nxdlperwave_forward_step);
                }
                else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        nxdlperwave_backward_step);
                }
            });

            // move on mxdlperwave dimension
            if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
            {
                c_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
            }
        });
    }
}
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,