mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4421 (commit 5bb5769)
[CK] Unify the grouped convolution gridwise Run() functions (#4421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation There are currently three different grouped convolution related Run() function overloads that exist in `gridwise_gemm_wmma_cshuffle_v3.hpp`. These are used for the different types of grouped convolution: Forward, Backward weights, and Backward data. The functions are very similar and should be unified to a single `Run()` function for all types of grouped convolution. ## Technical Details The three old `Run<>()` functions were replaced with a single unified function. The new `Run<>()` function is run from device implementations: - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 - DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 The DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor implementation uses a different `Run<>()` overload and was therefore not modified. ## Test Plan Run the following grouped convolution tests on `gfx1201`, as this architecture is WMMA-capable: - `test_grouped_convnd_fwd` - `test_grouped_convnd_bwd_weight` - `test_grouped_convnd_bwd_data` Compilation and testing were also executed on `gfx1100` to avoid CI problems. ## Test Result First part (unification of `Run<>()` function): All tests successful. Second part (integration of single `Run<>()` function as a direct call): All tests successful. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6f0ecf361e
commit
d8ee107a47
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/array.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
@@ -844,313 +846,126 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return Block2CTileMap{M, N, 4};
|
||||
}
|
||||
|
||||
// Run method for convolution for bwd_data (grid descriptors are passed as arguments,
|
||||
// not generated internally)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
// Grouped convolution regime
|
||||
enum class ConvRegime
|
||||
{
|
||||
BWD_DATA,
|
||||
BWD_WEIGHT,
|
||||
FORWARD
|
||||
};
|
||||
|
||||
// Unified Run<>() function for all regimes (bwd_data, bwd_weight, fwd)
|
||||
template <ConvRegime Regime,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMapExt,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, // Defined for bwd_data &
|
||||
// fwd convolution
|
||||
typename CEGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMapExt, // Defined for bwd_data convolution
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
typename ComputePtrOffsetOfN, // Defined for bwd_data & fwd convolution
|
||||
index_t NumGroupsToMerge, // Defined for bwd_weight convolution
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
bool CTranspose,
|
||||
InMemoryDataOperationEnum GlobalMemoryDataOperation,
|
||||
bool CTranspose, // Defined for bwd_data convolution
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const CEGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ce_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMapExt& block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
|
||||
const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
|
||||
const index_t num_k_per_block,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
|
||||
const index_t k_idx =
|
||||
__builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
|
||||
|
||||
// offset base pointer for each work-group
|
||||
// Resolve the current regime at compile time:
|
||||
constexpr bool is_bwd_data = (Regime == ConvRegime::BWD_DATA);
|
||||
constexpr bool is_bwd_weight = (Regime == ConvRegime::BWD_WEIGHT);
|
||||
constexpr bool is_fwd = (Regime == ConvRegime::FORWARD);
|
||||
|
||||
// ======== Index =========
|
||||
const auto g_idx = [&]() -> index_t {
|
||||
if constexpr(is_bwd_data || is_fwd)
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
}
|
||||
else
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
}
|
||||
}();
|
||||
|
||||
const index_t n_idx =
|
||||
(is_bwd_data || is_fwd) ? __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch) : 0;
|
||||
|
||||
// Using a lambda for better clang compliance than nested ternary operators
|
||||
const auto k_idx = [&]() -> index_t {
|
||||
if constexpr(is_bwd_data)
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) *
|
||||
num_k_per_block);
|
||||
}
|
||||
else if constexpr(is_bwd_weight)
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
|
||||
// ======== Offset ========
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
|
||||
const long_index_t b_batch_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
(!CTranspose) ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx))
|
||||
: 0;
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) =
|
||||
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
|
||||
});
|
||||
// b_n_offset
|
||||
const auto b_n_offset = [&]() -> long_index_t {
|
||||
if constexpr(is_bwd_data)
|
||||
{
|
||||
return CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx))
|
||||
: 0;
|
||||
}
|
||||
else if constexpr(is_fwd)
|
||||
{
|
||||
return amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) =
|
||||
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
|
||||
});
|
||||
|
||||
DsGridPointer p_ds_grid_grp;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
// Currently supporting one A and one B
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// AScale struct (Empty)
|
||||
using AScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = AScale{};
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
p_ds_grid_grp,
|
||||
karg.p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_idx,
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
|
||||
// Run method for convolution (grid descriptors are passed as arguments,
|
||||
// not generated internally)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
index_t NumGroupsToMerge,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
|
||||
});
|
||||
|
||||
const auto ds_grid_desc_m_n =
|
||||
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
|
||||
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// Scale structs (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = Scale{};
|
||||
auto a_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_idx,
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
|
||||
// Run method for convolution fwd (grid descriptors are passed as arguments,
|
||||
// not generated internally)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t num_k_per_block,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
|
||||
// offset base pointer for each work-group
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
const auto e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
|
||||
|
||||
// ======== Grid pointers ======== //
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
@@ -1167,12 +982,34 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
DsGridPointer p_ds_grid_grp;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
|
||||
ds_batch_offset[i] + ds_n_offset[i];
|
||||
if constexpr(is_bwd_data)
|
||||
{
|
||||
p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i];
|
||||
}
|
||||
else if constexpr(is_fwd)
|
||||
{
|
||||
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
|
||||
ds_batch_offset[i] + ds_n_offset[i];
|
||||
}
|
||||
});
|
||||
|
||||
// Currently supporting one A and one B
|
||||
// ======== Grid descriptors ======== //
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = [&]() {
|
||||
if constexpr(is_bwd_weight)
|
||||
{
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
|
||||
return MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
}
|
||||
}();
|
||||
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
@@ -1187,51 +1024,62 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
// ======== Tiling ======== //
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(ce_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
ce_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// ======== Remaining Run() arguments ======== //
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// AScale struct (Empty)
|
||||
using AScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = AScale{};
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
// Scale structs (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = Scale{};
|
||||
auto b_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
// p_ds_grid_
|
||||
const auto p_ds_grid_ = (is_bwd_data || is_fwd) ? p_ds_grid_grp : karg.p_ds_grid;
|
||||
|
||||
// p_e_grid_
|
||||
const auto p_e_grid_ = karg.p_e_grid + e_batch_offset + e_n_offset;
|
||||
|
||||
// Final arguments
|
||||
const index_t A_k_id = k_idx;
|
||||
const index_t B_k_id = k_idx;
|
||||
const index_t k_batch = (is_bwd_data || is_bwd_weight) ? karg.KBatch : 1;
|
||||
|
||||
// ======= Call the Run() function ======== //
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ce_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
p_ds_grid_grp,
|
||||
karg.p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_ds_grid_,
|
||||
p_e_grid_,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
ce_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
@@ -1240,7 +1088,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
epilogue_args,
|
||||
A_k_id,
|
||||
B_k_id,
|
||||
k_batch);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user