mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Grouped conv fwd with direct load (#3082)
* Grouped conv fwd with direct load * fix * fix * Add IsSupported check * Fix * fix inductor
This commit is contained in:
@@ -120,21 +120,56 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault;
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_group_offset + a_n_offset,
|
||||
karg.p_b_grid + b_group_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_group_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
block_2_ctile_map,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_m_n,
|
||||
c_grid_desc_m_n);
|
||||
if constexpr(GridwiseGemm::DirectLoadEnabled)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_group_offset + a_n_offset,
|
||||
karg.p_b_grid + b_group_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_group_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
block_2_ctile_map,
|
||||
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
|
||||
GridwiseGemm::BK0Number,
|
||||
GridwiseGemm::BK1Number>(
|
||||
b_grid_desc_bk0_n_bk1),
|
||||
ds_grid_desc_m_n,
|
||||
c_grid_desc_m_n);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_group_offset + a_n_offset,
|
||||
karg.p_b_grid + b_group_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_group_offset + e_n_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
block_2_ctile_map,
|
||||
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
|
||||
GridwiseGemm::BK0Number,
|
||||
GridwiseGemm::BK1Number>(
|
||||
b_grid_desc_bk0_n_bk1),
|
||||
ds_grid_desc_m_n,
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -208,22 +243,58 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault;
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_group_offset + a_n_offset,
|
||||
karg.p_b_grid + b_group_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_group_offset + e_n_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
block_2_ctile_map,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_m_n,
|
||||
c_grid_desc_m_n);
|
||||
if constexpr(GridwiseGemm::DirectLoadEnabled)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_group_offset + a_n_offset,
|
||||
karg.p_b_grid + b_group_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_group_offset + e_n_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
block_2_ctile_map,
|
||||
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
|
||||
GridwiseGemm::BK0Number,
|
||||
GridwiseGemm::BK1Number>(
|
||||
b_grid_desc_bk0_n_bk1),
|
||||
ds_grid_desc_m_n,
|
||||
c_grid_desc_m_n);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_group_offset + a_n_offset,
|
||||
karg.p_b_grid + b_group_offset,
|
||||
p_ds_grid_grp,
|
||||
karg.p_c_grid + e_group_offset + e_n_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op,
|
||||
block_2_ctile_map,
|
||||
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::AK0Number,
|
||||
GridwiseGemm::AK1Number>(
|
||||
a_grid_desc_ak0_m_ak1),
|
||||
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
|
||||
GridwiseGemm::BK0Number,
|
||||
GridwiseGemm::BK1Number>(
|
||||
b_grid_desc_bk0_n_bk1),
|
||||
ds_grid_desc_m_n,
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -309,7 +380,8 @@ template <index_t NDimSpatial,
|
||||
ADataType>()), // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
typename BComputeDataType = AComputeDataType>
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
bool DirectLoad = false>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -480,6 +552,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
using DsGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
|
||||
|
||||
// Source vector width (in elements) handed to the gridwise GEMM for the A block
// transfer when DirectLoad is true. A request that is exactly 8 bytes wide
// (ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8) is narrowed to
// 4 / sizeof(ADataType) elements, i.e. 4 bytes; any other width is kept unchanged.
// NOTE(review): presumably 8-byte direct loads are not supported on the target — confirm.
// NOTE(review): for sizeof(ADataType) > 4 the narrowed value is 0 (integer division);
// verify that DirectLoad is never instantiated with such a type.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
|
||||
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
|
||||
? 4 / sizeof(ADataType)
|
||||
: ABlockTransferSrcScalarPerVector;
|
||||
// Source vector width (in elements) handed to the gridwise GEMM for the B block
// transfer when DirectLoad is true. A request that is exactly 8 bytes wide
// (BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8) is narrowed to
// 4 / sizeof(BDataType) elements, i.e. 4 bytes; any other width is kept unchanged.
// NOTE(review): presumably 8-byte direct loads are not supported on the target — confirm.
// NOTE(review): for sizeof(BDataType) > 4 the narrowed value is 0 (integer division);
// verify that DirectLoad is never instantiated with such a type.
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
|
||||
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
|
||||
? 4 / sizeof(BDataType)
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
@@ -511,7 +592,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
@@ -519,7 +600,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
@@ -533,7 +614,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BComputeDataType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DoElementwiseBeforeCShuffle>;
|
||||
DoElementwiseBeforeCShuffle,
|
||||
DirectLoad>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
@@ -1376,6 +1458,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
// check device
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_device_name() == "gfx908")
|
||||
{
|
||||
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
|
||||
@@ -1971,8 +2061,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
<< "<"
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
|
||||
Reference in New Issue
Block a user