Grouped convolution backward data WMMA v3 implementation (#3460)

* Added device level implementation for bwd_data_wmma_v3.

* Added first instance of bwd_data_wmma_v3(f16).

* Add support for bwd data in gridwise implementation

Some changes are general for convolution and some are specific for bwd
data. We need to generalize them once we have fwd, bwd data and bwd
weight

* Initial device implementation of bwd data

* Remove unused template parameters in device impl

* Add one instance for different layout

initial check of device implementation

* Add tests for splitk and for different layouts

* Appended more instances to wmma_v3_f16.

* Added conv_2d bf16 wmma_v3 instances.

* Added conv_3d_bf16 wmma_v3_instances.

* Added conv_3d_f16_wmma_v3_instances.

* Added SplitN test cases for wmma.

* Conv3d_bwd_data_scale_wmma_v3 instances.

* Conv3d_bwd_data_bilinear_wmma_v3_instances

* Renaming the device level instances file to a common name, since it is defined for different DataTypes.

* Renaming the instances and fixing typo

* Added the test cases to regression test list

* NCHW support for wmma_v3

* Examples for bf16 and f16 bwd_data_wmma_v3

* Added transpose conditions for device impl

* fixing bugs

* Added the gemm_args array implementation

* WIP debug conv bwd

* fix splitk

* Grouped gemm fix

* Update CMakeLists with EOF

* Added more instances for tests

* Fixed the run time error in examples and removed 3d conv examples.

* Fixed a typo.

* Updated CMakeLists to remove the deleted 3d convolution files

* Added print error statements for unsupported argument

* Added the merge conflict related changes

* Fixed compilation error

* Fixed the InstanceFactory duplication error.

* Removed the print statements and added logs to Arg function

* All the merge conflict related errors resolved

* Added d_tensor tests.

* Added the missing example types of wmma_v3

* Merge error fix

* Corrected the instance name

* Reverted the bias relu change

* Revereted the transpose load local change

* Updated the regression test list with bwd_data_scale

* Revert "Revereted the transpose load local change"

This reverts commit 0b7281edb2bf008e407006690a00621174d9d19b.

* Revert "Merge error fix"

This reverts commit f3c85daa474b1b83d10c8a3ce077354e71d91a2b.

* Reverting the local change

* Added merge error fix

* Build error fix due to merge conflicts

* Added bias_relu example for wmma_v3

* Modified the main method in dtensor tests

* Updated the dtensor tests to pick all the shapes

* Updated the dtensor test shapes.

* Updated the mem operations in tests.

* Added reference func

* Fixed typos in device impl

* Added new header file and modified the include file for 3d tests

* Renamed the test file and added reference func call.

* clang format fix

* Added ignore params

* Modified device impl and tests

* Removed debug print statements and updated dtensor test shapes

* Fixing merge conflicts

* Fixing more merge conflicts

* Fixed copyrights

* Updated the tuned instances to bilinear and scale.

* Adding tuned instances to vanilla wmma_v3

* Removed all unused instances and modified test layouts.

* Cleaned up all instances, reverted back fwd fp16 instances and updated tuned fp16 instances.

* Fix clang format

* Updated tuned f16/generic instances

* Formatting the instances file

* Fixed copyrights and clang issues

* Nonsense commit to force git to force

* Removed the transpose instances

* Added verified generic instances

* Fixing namespace errors

* Added todo for failing shapes

* Formatting instance file

* Fix instance list formatting

* Removing unnecessary formats

* Renamed the common file

* Unification of xdl and wmma bwd_data tests

* Updated Cmake

* Added all layout types and deleted code.

* Updated Cmake to add the condition to all tests.

---------

Co-authored-by: Enrico Degregori <enrico@streamhpc.com>
Co-authored-by: Anton Gorenko <anton@streamhpc.com>
Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
ApoorvaKalyani
2025-12-30 16:25:08 +01:00
committed by GitHub
parent dae85ead64
commit 53a1e4f551
42 changed files with 4593 additions and 171 deletions

View File

@@ -775,6 +775,147 @@ struct GridwiseGemm_wmma_cshuffle_v3
return Block2CTileMap{problem.M, problem.N, 4};
}
// Run method for convolution for bwd_data (grid descriptors are passed as arguments,
// not generated internally)
// Grid-level entry point for grouped convolution bwd_data. All grid
// descriptors are precomputed on the host and passed in as arguments rather
// than generated inside the kernel. The method decodes the (group, N-split,
// split-K) work-group coordinates from the block index, applies the matching
// per-work-group pointer offsets (with the A/B roles swapped when CTranspose
// is set), wraps the single A/B descriptors into tuples, and forwards
// everything to the shared Base::Run GEMM pipeline.
//
// CTranspose: when true, the A and B batch/N pointer offsets are exchanged
//             below — presumably because the GEMM operands are swapped to
//             produce a transposed output; confirm against the caller.
// HasMainKBlockLoop / TailNum: compile-time K-loop shape selectors consumed
//             by Base::Run.
// EGlobalMemoryDataOperation: global-memory write mode forwarded to
//             Base::Run (e.g. relevant for split-K accumulation).
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMapExt,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
bool CTranspose,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMapExt& block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t num_k_per_block,
Argument& karg,
EpilogueArgument& epilogue_args)
{
// Decode work-group coordinates: blockIdx.y selects the convolution group;
// blockIdx.z packs the N-split index and the split-K batch index
// (z = n_idx * KBatch + kbatch_idx). readfirstlane keeps the values
// wavefront-uniform (scalar registers).
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
const index_t k_idx =
__builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
// offset base pointer for each work-group
// Under CTranspose, A takes B's per-group offset and vice versa.
const long_index_t a_batch_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
// N-split offsets: only one operand carries the N dimension, so the other
// side's offset is 0. Note both branches read GetAPtrOffset — presumably
// the N offset is always defined on the A side of the offset helper;
// confirm against ComputePtrOffsetOfN.
const long_index_t a_n_offset =
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
// Materialize typed, offset-adjusted base pointers for each A tensor.
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
// Same for each B tensor.
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
// D tensors get only the per-group offset (no N-split offset applied here).
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
// Currently supporting one A and one B
// Base::Run expects tuples of descriptors, so replicate the single A/B
// descriptor NumATensor/NumBTensor times.
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// Map this work-group's 1D block id to an (MBlock, NBlock) output tile;
// bail out early if the mapping falls outside the E grid.
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
// No A/B scaling in this path — pass the pipeline's Empty placeholders.
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
// Delegate to the shared blockwise GEMM pipeline with the adjusted
// pointers, replicated descriptors, and the split-K coordinates.
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
// NOTE(review): k_idx is passed twice — presumably one split-K
// offset per operand (A and B); confirm against Base::Run's
// parameter list.
k_idx,
k_idx,
karg.KBatch);
}
// Run method for convolution (grid descriptors are passed as arguments,
// not generated internally)
template <typename AGridDesc_AK0_M_K1,