mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Support Wave32 in CK_TILE - Part 1 (#2594)
* Support wave32/wave64 in CK_TILE - Part 1
* remove blocksize in kernel launch
* fix build error
* fix clang format
* fix clang format 2
* fix clang format 3
* fix fmha build error
* fix fmha build 2
* fix fmha build 3
* fix build error 4
* address review comment
* update change log
* replace KernelBlockSize with kBlockSize
* fix CI fail
* fix clang format
* address review comment and rebase code.
* fix universal test fail

---------

Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -34,6 +34,8 @@ struct BatchedTransposeKernel
|
||||
|
||||
using Type = typename Problem::DataType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
struct BatchedTransposeKargs
|
||||
{
|
||||
const void* p_input;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -20,11 +20,10 @@ struct BatchedTransposeLdsProblem
|
||||
|
||||
static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{});
|
||||
static constexpr index_t kColWarps_ = NumWarps::at(number<1>{});
|
||||
static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_;
|
||||
static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kBlockSize = get_warp_size() * kRowWarps_ * kColWarps_;
|
||||
// warps per block
|
||||
static constexpr index_t kLeadNumWarps = kColWarps_;
|
||||
static constexpr index_t kSecondNumWarps = kRowWarps_;
|
||||
|
||||
Reference in New Issue
Block a user