Support Wave32 in CK_TILE - Part 1 (#2594)

* Support wave32/wave64 in CK_TILE - Part 1

* remove blocksize in kernel launch

* fix build error

* fix clang format

* fix clang format 2

* fix clang format 3

* fix fmha build error

* fix fmha build 2

* fix fmha build 3

* fix build error 4

* address review comment

* update change log

* replace KernelBlockSize with kBlockSize

* fix CI fail

* fix clang format

* address review comment and rebase code.

* fix universal test fail

---------

Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
linqunAMD
2025-08-19 01:08:31 +08:00
committed by GitHub
parent 26d3300930
commit 9fcc1ee9fd
113 changed files with 610 additions and 531 deletions

View File

@@ -91,13 +91,13 @@ struct FlatmmKernel
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
@@ -127,7 +127,7 @@ struct FlatmmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)

View File

@@ -237,15 +237,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) / 2;
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) / 4;
return TileShape::WarpTile::at(I2) * scale / 4;
}
}