Simulate TF32 with BF16x3 (#3142)

* tf32:bf16x3:use bf16x3 emulate tf32 gemm

* change blockwiseGemm to demo bf16x3

* temp push

* self review

* self review

* fix multi-device compile error

* bug fix

* code refactor

* limit to gfx950

* enhance gemm gfx942 threshold

* lower change from blockwise to warpwise

* refact codes

* refact codes

* error fix

* change threshold

* bug fix

* fix threshold error

* change host reference implement to same as device

* bug fix

* bug fix

* code refact

* fix clang-format fail

* code refine
This commit is contained in:
yinglu
2025-11-14 08:21:09 +08:00
committed by GitHub
parent f2cfc6b94e
commit 2a73eb3bc0
16 changed files with 419 additions and 49 deletions

View File

@@ -94,7 +94,8 @@ template <typename ALayout,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();

View File

@@ -134,7 +134,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
@@ -145,7 +146,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceGroupedGemm_Xdl;
GET_NXDL_PER_WAVE_IMPL
@@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = ADataType;
// GridwiseGemm
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<