Add support for fwd conv in gridwise implementation. Identical to run function for bwd data.

This commit is contained in:
kiefer
2025-08-20 08:56:53 +00:00
parent ccf696ad4e
commit 16920dee0f
2 changed files with 152 additions and 10 deletions

View File

@@ -663,6 +663,131 @@ struct GridwiseGemm_wmma_cshuffle_v3
karg.b_element_op,
karg.cde_element_op);
}
// Run method for convolution (grid descriptors are passed as arguments,
// not generated internally)
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(void* p_shared,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t num_k_per_block,
Argument& karg)
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
const index_t k_idx =
__builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
b_scale_struct,
karg.KBatch,
k_idx);
}
};
} // namespace ck

View File

@@ -162,11 +162,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
// Calculate grid size taking into account splitk (KBatch)
// 2D grid (x,z)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
}
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
// 3D grid (x,y,z)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
@@ -594,8 +603,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
__device__ __host__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
{
return generate_tuple(
[&](auto i) {
@@ -918,8 +929,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
KPack>())>;
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
__host__ __device__ static constexpr auto
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
{
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
de_grid_desc_m_n,
@@ -1180,6 +1193,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
// Note: arguments k_batch and k_id should be set if splitk is used
// with implicit gemm (no pointer shift but shift using tensor descriptors)
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -1205,7 +1220,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
BScaleStruct& b_scale_struct)
BScaleStruct& b_scale_struct,
const index_t k_batch = 1,
const index_t k_id = 0)
{
const auto as_grid_buf = generate_tuple(
[&](auto i) {
@@ -1253,7 +1270,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
if constexpr(NumATensor > 1)
{
const auto idx_as_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
[&](auto) { return make_multi_index(k_id, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
@@ -1307,7 +1324,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
true,
BlockwiseGemmPipe::GlobalBufferNum>(
as_grid_desc_ak0_m_ak1[I0],
make_multi_index(0, m_block_data_idx_on_grid, 0),
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1323,7 +1340,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
if constexpr(NumBTensor > 1)
{
const auto idx_bs_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
[&](auto) { return make_multi_index(k_id, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
@@ -1377,7 +1394,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
true,
BlockwiseGemmPipe::GlobalBufferNum>(
bs_grid_desc_bk0_n_bk1[I0],
make_multi_index(0, n_block_data_idx_on_grid, 0),
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1411,7 +1428,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
(KPerBlock * k_batch));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),