Support Wave32 in CK_TILE - Part 1 (#2594)

* Support wave32/wave64 in CK_TILE - Part 1

* remove blocksize in kernel launch

* fix build error

* fix clang format

* fix clang format 2

* fix clang format 3

* fix fmha build error

* fix fmha build 2

* fix fmha build 3

* fix build error 4

* address review comment

* update change log

* replace KernelBlockSize with kBlockSize

* fix CI failure

* fix clang format

* address review comment and rebase code.

* fix universal test failure

---------

Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
linqunAMD
2025-08-19 01:08:31 +08:00
committed by GitHub
parent 26d3300930
commit 9fcc1ee9fd
113 changed files with 610 additions and 531 deletions

View File

@@ -64,6 +64,7 @@ struct BatchedGemmKernel
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
@@ -121,9 +122,16 @@ struct BatchedGemmKernel
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
CK_TILE_HOST static auto BlockSize() -> dim3
{
return dim3(UniversalGemmKernel::KernelBlockSize);
if(ck_tile::is_wave32())
{
return dim3(UniversalGemmKernel::kBlockSize / 2);
}
else
{
return dim3(UniversalGemmKernel::kBlockSize);
}
}
CK_TILE_HOST static constexpr BatchedGemmKernelArgs

View File

@@ -113,6 +113,7 @@ struct GemmKernel
static constexpr index_t NumATensor = 1;
static constexpr index_t NumBTensor = 1;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
CK_TILE_HOST static auto GetName() -> const std::string
{

View File

@@ -86,6 +86,7 @@ struct GemmKernelMultiD
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;

View File

@@ -128,7 +128,7 @@ struct GroupedGemmKernel
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -155,7 +155,7 @@ struct GroupedGemmKernel
return group_count * sizeof(GemmTransKernelArg);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
@@ -166,10 +166,10 @@ struct GroupedGemmKernel
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
const auto kernel = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>;
int occupancy;
HIP_CHECK_ERROR(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}

View File

@@ -196,7 +196,7 @@ struct UniversalGemmKernel
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
struct has_persistent_kernel
@@ -275,15 +275,26 @@ struct UniversalGemmKernel
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static auto BlockSize()
{
if(ck_tile::is_wave32())
{
return dim3(kBlockSize / 2);
}
else
{
return dim3(kBlockSize);
}
}
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
@@ -371,7 +382,9 @@ struct UniversalGemmKernel
}
}
bool AsTesnorIsValid = {true};
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
: GemmPipeline::template GetVectorSizeA<false>();
bool AsTesnorIsValid = {true};
static_for<0, NumATensor, 1>{}([&](auto index) {
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
@@ -387,7 +400,7 @@ struct UniversalGemmKernel
}
AsTesnorIsValid = false;
}
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
if(kargs.K % vectorSizeA != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -407,7 +420,7 @@ struct UniversalGemmKernel
}
AsTesnorIsValid = false;
}
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
if(kargs.M % vectorSizeA != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -418,7 +431,9 @@ struct UniversalGemmKernel
}
});
bool BsTesnorIsValid = {true};
bool BsTesnorIsValid = {true};
const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
: GemmPipeline::template GetVectorSizeB<false>();
static_for<0, NumBTensor, 1>{}([&](auto index) {
using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
@@ -432,7 +447,7 @@ struct UniversalGemmKernel
}
BsTesnorIsValid = false;
}
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
if(kargs.N % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -454,7 +469,7 @@ struct UniversalGemmKernel
}
BsTesnorIsValid = false;
}
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
if(kargs.K % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{