mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Support Wave32 in CK_TILE - Part 1 (#2594)
* Support wave32/wave64 in CK_TILE - Part 1 * remove blocksize in kernel launch * fix build error * fix clang format * fix clang format 2 * fix clang format 3 * fix fmha build error * fix fmha build 2 * fix fmha build 3 * fix build error 4 * address review comment * update change log * replace KernelBlockSize with kBlockSize * fix CI fail * fix clang format * address review comment and rebase code. * fix universal test fail --------- Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -15,9 +15,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
template <int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
#endif
|
||||
__global__ void kentry(Args... args)
|
||||
{
|
||||
@@ -35,15 +35,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
//
|
||||
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
|
||||
//
|
||||
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, typename KernelImpl, typename... Args>
|
||||
CK_TILE_HOST auto
|
||||
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
|
||||
|
||||
const auto kernel = kentry<MinBlockPerCu, KernelImpl, Args...>;
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user