[CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12 (#2808)

* [CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12

* fix gemm_splitk_two_stage

* revert .pre-commit-config.yaml
This commit is contained in:
linqunAMD
2025-09-11 22:27:33 +08:00
committed by GitHub
parent 0b9a638f26
commit 60d3e8f504
22 changed files with 439 additions and 192 deletions

View File

@@ -6,6 +6,7 @@
#include "run_gemm_example_common.hpp"
#include "gemm_splitk_two_stage_invoker.hpp"
template <template <typename PreType, typename WorkspaceType> typename GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
@@ -16,13 +17,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfigTwoStage<ck_tile::half_t, float>,
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, float>,
Invoker,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfigTwoStage<ck_tile::bf16_t, float>,
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, float>,
Invoker,
ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
}
@@ -42,7 +43,11 @@ int main(int argc, char* argv[])
try
{
return !run_gemm_example(arg_parser);
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigTwoStage_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigTwoStage>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -11,6 +11,12 @@ struct GemmConfigTwoStage : public GemmConfigComputeV3<PrecType_>
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
};
template <typename PrecType_, typename WorkspaceType_>
struct GemmConfigTwoStage_Wmma : public GemmConfigComputeV3_WMMA<PrecType_>
{
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
};
struct SplitKTwoStageInvoker
{
template <typename GemmConfig,
@@ -155,8 +161,7 @@ struct SplitKTwoStageInvoker
for(auto d : shape)
total_elements *= d;
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});

View File

@@ -4,6 +4,7 @@
#pragma once
#include <string>
#include <variant>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
@@ -173,7 +174,6 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#if CK_TILE_USE_WMMA
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -194,7 +194,6 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#endif
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase