mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Grouped Conv Bwd Weight Direct Load (#3648)
* Grouped Conv Bwd Weight Direct Load * Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp * Implement group merging for bwd_weight and add instances * Link direct load instances * builder fixes * fix * fixes * fix --------- Co-authored-by: Graner, Johannes <johannes.graner@amd.com>
This commit is contained in:
@@ -30,7 +30,8 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
bool TransposeC = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlops_pipeline_base
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -385,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
LdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
@@ -395,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
LdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
|
||||
@@ -32,9 +32,15 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
// Supported for Direct Load and V1
|
||||
if constexpr(LdsScalarLoadToVgpr)
|
||||
{
|
||||
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
|
||||
}
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
@@ -58,7 +64,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
@@ -758,7 +758,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
index_t KPacks,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
|
||||
{
|
||||
};
|
||||
@@ -781,9 +782,9 @@ template <index_t BlockSize,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
index_t KPack,
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
bool LdsScalarLoadToVgpr>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -803,7 +804,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -822,7 +824,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -843,7 +847,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
@@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
"Direct load transfer does not support datatypes conversion. Source and "
|
||||
"destination data types must be the same.");
|
||||
|
||||
static_assert(
|
||||
DstVectorDim == nDim - 1,
|
||||
"Direct load transfer requires the destination vector dimension to be the last one.");
|
||||
|
||||
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
|
||||
"When loading more than one element per thread at once, the contiguous "
|
||||
"dimension must be the same between source and destination.");
|
||||
|
||||
Reference in New Issue
Block a user