Grouped conv fwd with direct load (#3082)

* Grouped conv fwd with direct load

* fix

* fix

* Add IsSupported check

* Fix

* fix inductor
This commit is contained in:
Bartłomiej Kocot
2025-10-29 09:54:42 +01:00
committed by GitHub
parent 3052d7c9e6
commit 66bae4306c
27 changed files with 2165 additions and 285 deletions

View File

@@ -120,21 +120,56 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault;
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_group_offset + a_n_offset,
karg.p_b_grid + b_group_offset,
p_ds_grid_grp,
karg.p_c_grid + e_group_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
block_2_ctile_map,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_m_n,
c_grid_desc_m_n);
// Dispatch on whether this GridwiseGemm instantiation was built with direct load
// (GridwiseGemm::DirectLoadEnabled is resolved at compile time).
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
// The direct-load call is only compiled when targeting gfx950. For every other
// target this branch is empty, so a direct-load instantiation issues no GEMM call there.
#if defined(__gfx950__)
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_group_offset + a_n_offset,
karg.p_b_grid + b_group_offset,
p_ds_grid_grp,
karg.p_c_grid + e_group_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
block_2_ctile_map,
// The A and B descriptors are passed through GridwiseGemm::TransformGrid instead of
// being handed to Run directly. What the transform does is defined in GridwiseGemm
// and is not visible here.
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
a_grid_desc_ak0_m_ak1),
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
GridwiseGemm::BK0Number,
GridwiseGemm::BK1Number>(
b_grid_desc_bk0_n_bk1),
ds_grid_desc_m_n,
c_grid_desc_m_n);
#endif
}
else
{
// Regular (non-direct-load) path, compiled for every target.
// NOTE(review): this call is token-for-token identical to the gfx950 call above;
// the two branches differ only in the architecture guard.
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_group_offset + a_n_offset,
karg.p_b_grid + b_group_offset,
p_ds_grid_grp,
karg.p_c_grid + e_group_offset + e_n_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
block_2_ctile_map,
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
a_grid_desc_ak0_m_ak1),
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
GridwiseGemm::BK0Number,
GridwiseGemm::BK1Number>(
b_grid_desc_bk0_n_bk1),
ds_grid_desc_m_n,
c_grid_desc_m_n);
}
}
#else
ignore = karg;
@@ -208,22 +243,58 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault;
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_group_offset + a_n_offset,
karg.p_b_grid + b_group_offset,
p_ds_grid_grp,
karg.p_c_grid + e_group_offset + e_n_offset,
p_shared_0,
p_shared_1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
block_2_ctile_map,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_m_n,
c_grid_desc_m_n);
// Dispatch on whether this GridwiseGemm instantiation was built with direct load
// (GridwiseGemm::DirectLoadEnabled is resolved at compile time). This is the
// Run_2Lds form, which takes two shared-memory pointers (p_shared_0, p_shared_1).
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
// The direct-load call is only compiled when targeting gfx950. For every other
// target this branch is empty, so a direct-load instantiation issues no GEMM call there.
#if defined(__gfx950__)
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_group_offset + a_n_offset,
karg.p_b_grid + b_group_offset,
p_ds_grid_grp,
karg.p_c_grid + e_group_offset + e_n_offset,
p_shared_0,
p_shared_1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
block_2_ctile_map,
// The A and B descriptors are passed through GridwiseGemm::TransformGrid instead of
// being handed to Run_2Lds directly. What the transform does is defined in
// GridwiseGemm and is not visible here.
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
a_grid_desc_ak0_m_ak1),
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
GridwiseGemm::BK0Number,
GridwiseGemm::BK1Number>(
b_grid_desc_bk0_n_bk1),
ds_grid_desc_m_n,
c_grid_desc_m_n);
#endif
}
else
{
// Regular (non-direct-load) path, compiled for every target.
// NOTE(review): this call is token-for-token identical to the gfx950 call above;
// the two branches differ only in the architecture guard.
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_group_offset + a_n_offset,
karg.p_b_grid + b_group_offset,
p_ds_grid_grp,
karg.p_c_grid + e_group_offset + e_n_offset,
p_shared_0,
p_shared_1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
block_2_ctile_map,
GridwiseGemm::template TransformGrid<decltype(a_grid_desc_ak0_m_ak1),
GridwiseGemm::AK0Number,
GridwiseGemm::AK1Number>(
a_grid_desc_ak0_m_ak1),
GridwiseGemm::template TransformGrid<decltype(b_grid_desc_bk0_n_bk1),
GridwiseGemm::BK0Number,
GridwiseGemm::BK1Number>(
b_grid_desc_bk0_n_bk1),
ds_grid_desc_m_n,
c_grid_desc_m_n);
}
}
#else
ignore = karg;
@@ -309,7 +380,8 @@ template <index_t NDimSpatial,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType>
typename BComputeDataType = AComputeDataType,
bool DirectLoad = false>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
@@ -480,6 +552,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
// Source vector widths (in elements) substituted for the configured
// A/BBlockTransferSrcScalarPerVector when DirectLoad is true (see the GridwiseGemmBase
// arguments). If the configured vector is exactly 8 bytes wide
// (ScalarPerVector * sizeof(T) == 8) it is replaced by a 4-byte-wide vector
// (4 / sizeof(T) elements); every other width is kept unchanged.
// Presumably the direct-load path cannot issue 8-byte transfers — TODO confirm
// against the gfx950 direct-load instruction widths.
// NOTE(review): for an 8-byte element type with ScalarPerVector == 1 the condition
// holds and 4 / sizeof(T) evaluates to 0; confirm such instances are never built
// with DirectLoad.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)
: ABlockTransferSrcScalarPerVector;
// Same rule applied to the B operand.
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
// Use appropriate gridwise gemm
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3<
@@ -511,7 +592,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
@@ -519,7 +600,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
@@ -533,7 +614,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BComputeDataType,
ADataType,
BDataType,
DoElementwiseBeforeCShuffle>;
DoElementwiseBeforeCShuffle,
DirectLoad>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -1376,6 +1458,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// check device
// Instances built with DirectLoad are reported as unsupported on every device whose
// name is not "gfx950". Non-DirectLoad instances skip this gate entirely.
if constexpr(DirectLoad)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if(get_device_name() == "gfx908")
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
@@ -1971,8 +2061,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
<< "<"
// Emit the instance name, then tag DirectLoad instances with a "_DirectLoad" suffix
// so they are distinguishable from regular instances in the type string.
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
if constexpr(DirectLoad) {
str << "_DirectLoad";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "