[CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12 (#2808)

* [CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12

* fix gemm_splitk_two_stage

* revert .pre-commit-config.yaml
This commit is contained in:
linqunAMD
2025-09-11 22:27:33 +08:00
committed by GitHub
parent 0b9a638f26
commit 60d3e8f504
22 changed files with 439 additions and 192 deletions

View File

@@ -13,6 +13,7 @@
#include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -36,9 +37,9 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -139,7 +140,10 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
#include "run_grouped_convolution_bwd_data_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_data_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
@@ -158,6 +162,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -166,6 +171,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -174,6 +180,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -185,6 +192,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
}
}
template <typename GemmWarpConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -198,12 +206,12 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::half_t>(
return run_grouped_conv_bwd_data_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::bf16_t>(
return run_grouped_conv_bwd_data_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
@@ -212,4 +220,11 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); }
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Wmma>(argc, argv);
#else
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Mfma>(argc, argv);
#endif
}

View File

@@ -13,6 +13,7 @@
#include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -36,9 +37,9 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -141,7 +142,10 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
#include "run_grouped_convolution_bwd_weight_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_weight_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
@@ -160,6 +164,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -168,6 +173,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -176,6 +182,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -187,6 +194,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
}
}
template <typename GemmWarpConfig>
int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -200,12 +208,12 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::half_t>(
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::bf16_t>(
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
@@ -214,4 +222,11 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
#else
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
#endif
}

View File

@@ -13,6 +13,7 @@
#include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -35,9 +36,9 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -130,7 +131,10 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
#include "run_grouped_convolution_fwd_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_fwd_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
@@ -149,6 +153,7 @@ int run_grouped_conv_fwd_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -157,6 +162,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -165,6 +171,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
@@ -176,6 +183,7 @@ int run_grouped_conv_fwd_example_prec_type(
}
}
template <typename GemmWarpConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -189,12 +197,12 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<ck_tile::half_t>(
return run_grouped_conv_fwd_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<ck_tile::bf16_t>(
return run_grouped_conv_fwd_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
@@ -203,4 +211,11 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<GemmWarpConfig_Wmma>(argc, argv);
#else
return !run_grouped_conv_fwd_example<GemmWarpConfig_Mfma>(argc, argv);
#endif
}

View File

@@ -12,6 +12,20 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
const ck_tile::index_t kbatch,
@@ -126,7 +140,3 @@ auto create_args(int argc, char* argv[])
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
const ck_tile::stream_config& s);

View File

@@ -3,6 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -15,6 +16,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_repeat)
{
float ave_time = grouped_conv_bwd_data<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -36,6 +38,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
}
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
@@ -136,6 +139,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,

View File

@@ -3,6 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -15,6 +16,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
int n_repeat)
{
float ave_time = grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -36,6 +38,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
}
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
@@ -136,6 +139,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,

View File

@@ -3,6 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -15,6 +16,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
int n_repeat)
{
float ave_time = grouped_conv_fwd<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -36,6 +38,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
}
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
@@ -136,6 +139,7 @@ int run_grouped_conv_fwd_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,