mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Grouped Convolution Backward Data Direct Load (#6624)
## Proposed changes Add Grouped Convolution Backward Data with Direct Load into DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 device implementation. This enables direct global memory loading (bypassing LDS) for the backward data convolution path on gfx950, following the same pattern used in both backward weight and forward convolution. Direct load convolution backward data improves performance by avoiding LDS round-trips for certain configurations on gfx950, which supports a wider range of instructions. Currently correctness is checked only at usage point, but should be extended to a standalone UT in the future.
This commit is contained in:
@@ -33,8 +33,9 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool TransposeC = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool TransposeC = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlops_pipeline_base
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -389,7 +390,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
LdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
ALdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
@@ -399,7 +400,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
LdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
BLdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
|
||||
@@ -32,12 +32,13 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool DirectLoad = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool DirectLoad = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
// Supported for Direct Load and V1
|
||||
if constexpr(LdsScalarLoadToVgpr)
|
||||
if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr)
|
||||
{
|
||||
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
|
||||
}
|
||||
@@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>{};
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
@@ -747,7 +747,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
|
||||
{
|
||||
};
|
||||
@@ -772,7 +773,8 @@ template <index_t BlockSize,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
bool LdsScalarLoadToVgpr>
|
||||
bool ALdsScalarLoadToVgpr,
|
||||
bool BLdsScalarLoadToVgpr>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -793,7 +795,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -814,7 +817,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NRepeat,
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -837,7 +841,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NRepeat,
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>;
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
? 4 / sizeof(BDataType)
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
static constexpr bool ALdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
static constexpr bool BLdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
|
||||
// Note: Direct load use layout to create proper block and mmtile descriptor
|
||||
// TODO: Fix and verify RC layout for not direct load (currently it returns wrong results)
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor>,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor>,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
DirectLoad>;
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ template <typename ALayout,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
: public GridwiseGemm_xdl_cshuffle_base<
|
||||
ALayout,
|
||||
@@ -249,19 +251,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
|
||||
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
|
||||
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
|
||||
{
|
||||
|
||||
if constexpr(!DirectLoad)
|
||||
{
|
||||
return desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t K = desc.GetLength(I0) * desc.GetLength(I2);
|
||||
const index_t MN = desc.GetLength(I1);
|
||||
|
||||
const auto desc_unmerged = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)),
|
||||
make_pass_through_transform(MN),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto desc_permuted = transform_tensor_descriptor(
|
||||
desc_unmerged,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(MN, K0Number)),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)),
|
||||
make_pass_through_transform(MN),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave,
|
||||
index_t MNWaves,
|
||||
index_t MNPerXdl,
|
||||
bool IsKContinous,
|
||||
typename TileDesc_K0_MN_K1>
|
||||
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
if constexpr(DirectLoad && IsKContinous)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto desc = transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
@@ -270,7 +343,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave,
|
||||
MWaves,
|
||||
MPerXdl,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value>(
|
||||
ABlockDesc_AK0_M_AK1{});
|
||||
}
|
||||
|
||||
template <typename BBlockDesc_BK0_N_BK1>
|
||||
@@ -279,7 +356,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave,
|
||||
NWaves,
|
||||
NPerXdl,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value>(
|
||||
BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
struct Problem
|
||||
@@ -366,9 +447,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
|
||||
{
|
||||
@@ -389,9 +479,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
|
||||
{
|
||||
@@ -410,34 +509,35 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
// Disable vector load from lds to vgpr for direct load (backward weight store with continous M
|
||||
// or N dimension)
|
||||
static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
|
||||
using BlockwiseGemmPipe = remove_cvref_t<
|
||||
decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
// static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
|
||||
using BlockwiseGemmPipe = remove_cvref_t<
|
||||
decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad,
|
||||
LdsScalarLoadToVgpr>())>;
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>())>;
|
||||
|
||||
template <typename DeviceArch>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
|
||||
@@ -517,8 +617,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
|
||||
{
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
@@ -535,8 +636,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
|
||||
make_multi_index(static_cast<index_t>(blockIdx.x)));
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
@@ -570,23 +671,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
|
||||
ABlockTransferSrcScalarPerVector >
|
||||
(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -626,23 +723,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_b_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
|
||||
BBlockTransferSrcScalarPerVector >
|
||||
(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -750,8 +843,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
|
||||
{
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
@@ -771,7 +865,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
|
||||
make_multi_index(static_cast<index_t>(blockIdx.x)));
|
||||
make_multi_index(static_cast<index_t>(block_idx_x)));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
@@ -805,23 +899,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
|
||||
ABlockTransferSrcScalarPerVector >
|
||||
(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -861,23 +951,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_b_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
|
||||
BBlockTransferSrcScalarPerVector >
|
||||
(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user