mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12 (#2808)
* [CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12 * fix gemm_splitk_two_stage * revert .pre-commit-config.yaml
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -36,9 +37,9 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
@@ -139,7 +140,10 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
|
||||
#include "run_grouped_convolution_bwd_data_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
template <typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_data_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -158,6 +162,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -166,6 +171,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -174,6 +180,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -185,6 +192,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -198,12 +206,12 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::half_t>(
|
||||
return run_grouped_conv_bwd_data_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::bf16_t>(
|
||||
return run_grouped_conv_bwd_data_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
@@ -212,4 +220,11 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Mfma>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -36,9 +37,9 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
@@ -141,7 +142,10 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
template <typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -160,6 +164,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -168,6 +173,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -176,6 +182,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -187,6 +194,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -200,12 +208,12 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::half_t>(
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::bf16_t>(
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
@@ -214,4 +222,11 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -35,9 +36,9 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
@@ -130,7 +131,10 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
|
||||
#include "run_grouped_convolution_fwd_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
template <typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -149,6 +153,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -157,6 +162,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -165,6 +171,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
@@ -176,6 +183,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -189,12 +197,12 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<ck_tile::half_t>(
|
||||
return run_grouped_conv_fwd_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<ck_tile::bf16_t>(
|
||||
return run_grouped_conv_fwd_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
@@ -203,4 +211,11 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_fwd_example<GemmWarpConfig_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_fwd_example<GemmWarpConfig_Mfma>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -12,6 +12,20 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
struct GemmWarpConfig_Mfma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GemmWarpConfig_Wmma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
@@ -126,7 +140,3 @@ auto create_args(int argc, char* argv[])
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -15,6 +16,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_data<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -36,6 +38,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
@@ -136,6 +139,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_data<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -15,6 +16,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -36,6 +38,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
@@ -136,6 +139,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -15,6 +16,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_fwd<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -36,6 +38,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
@@ -136,6 +139,7 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_fwd<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
|
||||
Reference in New Issue
Block a user