mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Optimize grouped conv bwd wei split_k off calc
This commit is contained in:
@@ -55,13 +55,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
[[maybe_unused]] const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
@@ -77,18 +84,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
@@ -113,14 +127,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
[[maybe_unused]] const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
@@ -139,8 +160,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
@@ -148,10 +169,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
@@ -779,6 +807,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
e_in_transpose_desc_.GetLength(I1)}
|
||||
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
|
||||
ce_grid_desc_m_n_.GetLength(I1)};
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
|
||||
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check if there is KPading and we can divide N by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
split_k_stride_a_ =
|
||||
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
|
||||
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
@@ -864,6 +909,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
long_index_t c_space_size_bytes;
|
||||
|
||||
bool split_k_offset_a_hack_, split_k_offset_b_hack_;
|
||||
long_index_t split_k_stride_a_, split_k_stride_b_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -966,7 +1014,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_a_hack_,
|
||||
arg.split_k_offset_b_hack_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -982,7 +1034,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_a_hack_,
|
||||
arg.split_k_offset_b_hack_);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1886,14 +1942,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -84,7 +88,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -99,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = batch_count;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
|
||||
compute_ptr_offset_of_batch.GetAPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetBPtrOffset(0);
|
||||
@@ -634,6 +646,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
|
||||
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check if there is KPading and we can divide N by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
split_k_stride_a_ =
|
||||
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
|
||||
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
@@ -727,6 +756,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
long_index_t c_space_size_bytes;
|
||||
|
||||
bool split_k_offset_a_hack_, split_k_offset_b_hack_;
|
||||
long_index_t split_k_stride_a_, split_k_stride_b_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -873,7 +905,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_a_hack_,
|
||||
arg.split_k_offset_b_hack_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
|
||||
@@ -53,13 +53,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
@@ -74,15 +81,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -91,6 +101,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = num_k_per_block;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
|
||||
#endif // end of if (defined(__gfx9__)
|
||||
}
|
||||
|
||||
@@ -114,14 +129,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
@@ -140,8 +162,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
@@ -149,7 +171,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_a_hack,
|
||||
split_k_offset_b_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -158,6 +183,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = num_k_per_block;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_a_hack;
|
||||
ignore = split_k_offset_b_hack;
|
||||
#endif // end of if (defined(__gfx9__)
|
||||
}
|
||||
|
||||
@@ -594,6 +623,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
c_grid_desc_m_n_,
|
||||
GridwiseGemm64::CalculateMBlock(GemmM),
|
||||
GridwiseGemm64::CalculateNBlock(GemmN));
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
|
||||
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check if there is KPading and we can divide N by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
split_k_stride_a_ =
|
||||
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
|
||||
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -626,6 +672,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
long_index_t c_space_size_bytes;
|
||||
|
||||
bool split_k_offset_a_hack_, split_k_offset_b_hack_;
|
||||
long_index_t split_k_stride_a_, split_k_stride_b_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -715,7 +764,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_a_hack_,
|
||||
arg.split_k_offset_b_hack_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -731,7 +784,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_a_hack_,
|
||||
arg.split_k_offset_b_hack_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ template <typename ALayout,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
index_t ABlockLdsExtraMCustom,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
@@ -53,7 +53,7 @@ template <typename ALayout,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
index_t BBlockLdsExtraNCustom,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -673,12 +673,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
@@ -744,7 +750,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
|
||||
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -775,7 +781,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1035,12 +1041,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const bool split_k_offset_a_hack = false,
|
||||
const bool split_k_offset_b_hack = false)
|
||||
{
|
||||
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
@@ -1106,7 +1118,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
|
||||
make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1137,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
|
||||
@@ -646,6 +646,415 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_a_hack,
|
||||
bool split_k_offset_b_hack)
|
||||
{
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t k_batch_id = block_work_idx[I0];
|
||||
|
||||
const long_index_t split_k_offset_a =
|
||||
split_k_offset_a_hack ? k_batch_id * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b =
|
||||
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_space_size_divisor =
|
||||
split_k_offset_a_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
|
||||
const long_index_t b_space_size_divisor =
|
||||
split_k_offset_b_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid + split_k_offset_a,
|
||||
a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / a_space_size_divisor);
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + split_k_offset_b,
|
||||
b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / b_space_size_divisor);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
if(!c_block_cluster_adaptor.ValidCTileIndex(
|
||||
make_tuple(block_work_idx[I1], block_work_idx[I2]),
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
|
||||
constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
|
||||
constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatA,
|
||||
FloatAAdjusted,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
3,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
make_multi_index(
|
||||
split_k_offset_a_hack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatB,
|
||||
FloatBAdjusted,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
make_multi_index(
|
||||
split_k_offset_b_hack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[K0PerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[K0PerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
K1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
K1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(K1,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAAdjusted,
|
||||
FloatBAdjusted,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>{};
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
|
||||
b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
|
||||
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
b_b_k0_n_k1_block_desc,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
|
||||
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
|
||||
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
|
||||
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
|
||||
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
|
||||
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
|
||||
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
|
||||
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
|
||||
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
|
||||
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
|
||||
|
||||
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatC*>(p_shared),
|
||||
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
static_assert(M1 == MWave, "");
|
||||
static_assert(N1 == NWave, "");
|
||||
static_assert(M2 * M3 * M4 == MPerXdl, "");
|
||||
static_assert(N2 == NPerXdl, "");
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
c_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0), // freeze mblock
|
||||
make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
|
||||
M1,
|
||||
M2,
|
||||
M3,
|
||||
M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
|
||||
make_freeze_transform(I0), // freeze nblock
|
||||
make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
|
||||
N1,
|
||||
N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// LDS to global
|
||||
auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // index_t BlockSize,
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
FloatC, // typename SrcData,
|
||||
FloatC, // typename DstData,
|
||||
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun
|
||||
{c_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
|
||||
c_element_op};
|
||||
|
||||
constexpr auto mxdlperwave_forward_step =
|
||||
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
|
||||
constexpr auto nxdlperwave_forward_step =
|
||||
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
|
||||
constexpr auto nxdlperwave_backward_step =
|
||||
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);
|
||||
|
||||
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
|
||||
constexpr auto mxdlperwave = mxdlperwave_iter;
|
||||
|
||||
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
|
||||
constexpr bool nxdlperwave_forward_sweep =
|
||||
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
|
||||
|
||||
constexpr index_t nxdlperwave_value =
|
||||
nxdlperwave_forward_sweep
|
||||
? nxdlperwave_iter
|
||||
: (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
|
||||
|
||||
constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
|
||||
|
||||
// make sure it's safe to do ds_write
|
||||
block_sync_lds();
|
||||
|
||||
// VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_buf);
|
||||
|
||||
// make sure it's safe to do ds_read
|
||||
block_sync_lds();
|
||||
|
||||
// LDS to global
|
||||
c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
|
||||
// move on nxdlperwave dimension
|
||||
if constexpr(nxdlperwave_forward_sweep &&
|
||||
(nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
|
||||
{
|
||||
c_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
nxdlperwave_forward_step);
|
||||
}
|
||||
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
|
||||
{
|
||||
c_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
nxdlperwave_backward_step);
|
||||
}
|
||||
});
|
||||
|
||||
// move on mxdlperwave dimension
|
||||
if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
|
||||
{
|
||||
c_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
|
||||
Reference in New Issue
Block a user