mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12 (#2808)
* [CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12
* fix gemm_splitk_two_stage
* revert .pre-commit-config.yaml
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "run_gemm_example_common.hpp"
|
||||
#include "gemm_splitk_two_stage_invoker.hpp"
|
||||
|
||||
template <template <typename PreType, typename WorkspaceType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
@@ -16,13 +17,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfigTwoStage<ck_tile::half_t, float>,
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, float>,
|
||||
Invoker,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfigTwoStage<ck_tile::bf16_t, float>,
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, float>,
|
||||
Invoker,
|
||||
ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
@@ -42,7 +43,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
return !run_gemm_example(arg_parser);
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigTwoStage_Wmma>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigTwoStage>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -11,6 +11,12 @@ struct GemmConfigTwoStage : public GemmConfigComputeV3<PrecType_>
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
|
||||
};
|
||||
|
||||
// WMMA counterpart of GemmConfigTwoStage: derives from GemmConfigComputeV3_WMMA<PrecType_>
// and exposes WorkspaceType_ (passed through ck_tile::remove_cvref_t) as WorkspaceType.
template <typename PrecType_, typename WorkspaceType_>
|
||||
struct GemmConfigTwoStage_Wmma : public GemmConfigComputeV3_WMMA<PrecType_>
|
||||
{
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
|
||||
};
|
||||
|
||||
struct SplitKTwoStageInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
@@ -155,8 +161,7 @@ struct SplitKTwoStageInvoker
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
@@ -173,7 +174,6 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
@@ -194,7 +194,6 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
Reference in New Issue
Block a user