mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Support Wave32 in CK_TILE - Part 1 (#2594)
* Support wave32/wave64 in CK_TILE - Part 1 * remove blocksize in kernel launch * fix build error * fix clang format * fix clang format 2 * fix clang format 3 * fix fmha build error * fix fmha build 2 * fix fmha build 3 * fix build error 4 * address review comment * update change log * replace KernelBlockSize with kBlockSize * fix CI fail * fix clang format * address review comment and rebase code. * fix universal test fail --------- Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -91,13 +91,13 @@ struct FlatmmKernel
|
||||
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -127,7 +127,7 @@ struct FlatmmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
|
||||
Reference in New Issue
Block a user