Merge commit '60d3e8f504edd25569811b25b4b876d0a504b3b8' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-11 15:11:42 +00:00
parent 269824c6bb
commit 9541fc3ef3
22 changed files with 439 additions and 192 deletions

View File

@@ -6,6 +6,7 @@
#include "run_gemm_example_common.hpp" #include "run_gemm_example_common.hpp"
#include "gemm_splitk_two_stage_invoker.hpp" #include "gemm_splitk_two_stage_invoker.hpp"
template <template <typename PreType, typename WorkspaceType> typename GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser) int run_gemm_example(ck_tile::ArgParser& arg_parser)
{ {
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
@@ -16,13 +17,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run_gemm_example_prec_type<GemmConfigTwoStage<ck_tile::half_t, float>, return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, float>,
Invoker, Invoker,
ck_tile::half_t>(a_layout, b_layout, arg_parser); ck_tile::half_t>(a_layout, b_layout, arg_parser);
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run_gemm_example_prec_type<GemmConfigTwoStage<ck_tile::bf16_t, float>, return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, float>,
Invoker, Invoker,
ck_tile::bf16_t>(a_layout, b_layout, arg_parser); ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
} }
@@ -42,7 +43,11 @@ int main(int argc, char* argv[])
try try
{ {
return !run_gemm_example(arg_parser); #if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigTwoStage_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigTwoStage>(arg_parser);
#endif
} }
catch(const std::runtime_error& e) catch(const std::runtime_error& e)
{ {

View File

@@ -11,6 +11,12 @@ struct GemmConfigTwoStage : public GemmConfigComputeV3<PrecType_>
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>; using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
}; };
template <typename PrecType_, typename WorkspaceType_>
struct GemmConfigTwoStage_Wmma : public GemmConfigComputeV3_WMMA<PrecType_>
{
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
};
struct SplitKTwoStageInvoker struct SplitKTwoStageInvoker
{ {
template <typename GemmConfig, template <typename GemmConfig,
@@ -155,8 +161,7 @@ struct SplitKTwoStageInvoker
for(auto d : shape) for(auto d : shape)
total_elements *= d; total_elements *= d;
constexpr ck_tile::index_t kBlockSize = const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});

View File

@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <variant>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
@@ -173,7 +174,6 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2; static constexpr int kBlockPerCu = 2;
}; };
#if CK_TILE_USE_WMMA
template <typename PrecType> template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{ {
@@ -194,7 +194,6 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
static constexpr int kBlockPerCu = 2; static constexpr int kBlockPerCu = 2;
}; };
#endif
template <typename PrecType> template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase struct GemmConfigComputeV4 : public GemmConfigBase

View File

@@ -15,7 +15,8 @@
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "batched_gemm.hpp" #include "batched_gemm.hpp"
template <typename ADataType, template <typename GemmConfig,
typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType, typename DsDataType,
typename AccDataType, typename AccDataType,
@@ -27,54 +28,19 @@ template <typename ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough> typename CDEElementWise = ck_tile::element_wise::PassThrough>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
// Memory friendly for Interwave scheduler constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4; constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = 1; constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false; constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false; constexpr bool kPadM = false;
constexpr bool kPadN = false; constexpr bool kPadN = false;
@@ -105,7 +71,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>; using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
@@ -119,7 +86,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value; constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType, using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -131,7 +98,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>; using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType, ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -207,7 +175,11 @@ int main(int argc, char* argv[])
{ {
try try
{ {
return !run_batched_gemm_example(argc, argv); #if CK_TILE_USE_WMMA
return !run_batched_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
return !run_batched_gemm_example<GemmConfigV3>(argc, argv);
#endif
} }
catch(const std::runtime_error& e) catch(const std::runtime_error& e)
{ {

View File

@@ -15,25 +15,116 @@
#define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT struct GemmConfigMemory
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 {
#endif // Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 64;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) static constexpr ck_tile::index_t M_Warp = 4;
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem static constexpr ck_tile::index_t N_Warp = 1;
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem static constexpr ck_tile::index_t K_Warp = 1;
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) static constexpr ck_tile::index_t M_Warp_Tile = 32;
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 static constexpr ck_tile::index_t N_Warp_Tile = 32;
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 static constexpr ck_tile::index_t K_Warp_Tile = 8;
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) static constexpr bool DoubleSmemBuffer = false;
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave };
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value" struct GemmConfigV3
#endif {
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV4
{
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV3_Wmma
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <typename DataType> template <typename DataType>
struct BatchedGemmTypeConfig; struct BatchedGemmTypeConfig;

View File

@@ -22,7 +22,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
} }
template <typename ADataType, template <typename GemmConfig,
typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType, typename DsDataType,
typename AccDataType, typename AccDataType,
@@ -64,7 +65,8 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
batch_stride_C, batch_stride_C,
batch_count}; batch_count};
float ave_time = batched_gemm<ADataType, float ave_time = batched_gemm<GemmConfig,
ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
AccDataType, AccDataType,
@@ -79,7 +81,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time; return ave_time;
} }
template <typename ALayout, typename BLayout, typename CLayout> template <typename GemmConfig, typename ALayout, typename BLayout, typename CLayout>
int run_batched_gemm_example_with_layouts(int argc, int run_batched_gemm_example_with_layouts(int argc,
char* argv[], char* argv[],
const ALayout a_layout = ALayout{}, const ALayout a_layout = ALayout{},
@@ -170,7 +172,8 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
float ave_time = invoke_batched_gemm<ADataType, float ave_time = invoke_batched_gemm<GemmConfig,
ADataType,
BDataType, BDataType,
ck_tile::tuple<>, ck_tile::tuple<>,
AccDataType, AccDataType,
@@ -311,6 +314,7 @@ int run_batched_gemm_example_with_layouts(int argc,
return pass; return pass;
} }
template <typename GemmConfig>
int run_batched_gemm_example(int argc, char* argv[]) int run_batched_gemm_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
@@ -329,7 +333,7 @@ int run_batched_gemm_example(int argc, char* argv[])
// } // }
if(a_layout == "R" && b_layout == "C") if(a_layout == "R" && b_layout == "C")
{ {
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_batched_gemm_example_with_layouts<GemmConfig>(argc, argv, Row{}, Col{}, Row{});
} }
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work else if(a_layout == "C" && b_layout == "C") // work else if(a_layout == "C" && b_layout == "C")

View File

@@ -353,5 +353,9 @@ int run_grouped_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv); return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
#endif
} }

View File

@@ -17,10 +17,6 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4 #define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
template <typename PrecType, ck_tile::index_t M_Warp_Tile> template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile() constexpr ck_tile::index_t get_k_warp_tile()
{ {
@@ -190,6 +186,29 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool kPadK = true; static constexpr bool kPadK = true;
}; };
template <typename PrecType>
struct GemmConfigComputeV4_Wmma : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType> template <typename PrecType>
struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
{ {

View File

@@ -17,7 +17,8 @@
#include "gemm_multi_d_fp16.hpp" #include "gemm_multi_d_fp16.hpp"
#include "utils.hpp" #include "utils.hpp"
template <typename ADataType, template <typename GemmConfig,
typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType, typename DsDataType,
typename AccDataType, typename AccDataType,
@@ -29,58 +30,22 @@ template <typename ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough> typename CDEElementWise = ck_tile::element_wise::PassThrough>
auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
// Memory friendly for Interwave scheduler constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4; constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = 1; constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false; constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
#endif constexpr bool kPadM = false;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) constexpr bool kPadN = false;
// Compute friendly for Intrawave scheduler constexpr bool kPadK = false;
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool TransposeC = false; constexpr bool TransposeC = false;
@@ -109,7 +74,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>; using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
@@ -123,7 +89,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value; constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType, using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -135,7 +101,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>; using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType, ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -203,4 +170,11 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
#include "run_gemm_multi_d_fp16_example.inc" #include "run_gemm_multi_d_fp16_example.inc"
int main(int argc, char* argv[]) { return !run_multiple_d_gemm_example(argc, argv); } int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_multiple_d_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
return !run_multiple_d_gemm_example<GemmConfigV3>(argc, argv);
#endif
}

View File

@@ -13,26 +13,6 @@
#define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
using ADataType = ck_tile::half_t; using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t; using BDataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t; using D0DataType = ck_tile::half_t;
@@ -41,6 +21,117 @@ using EDataType = ck_tile::half_t;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>; using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using AccDataType = float; using AccDataType = float;
struct GemmConfigMemory
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigV3
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV4
{
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV3_Wmma
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
@@ -68,7 +159,8 @@ auto create_args(int argc, char* argv[])
using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs<DsDataType::size()>; using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs<DsDataType::size()>;
template <typename ADataType, template <typename GemmConfig,
typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType, typename DsDataType,
typename AccDataType, typename AccDataType,

View File

@@ -5,7 +5,8 @@
#include <cstddef> #include <cstddef>
#include "ck_tile/utility/json_dump.hpp" #include "ck_tile/utility/json_dump.hpp"
template <typename ADataType, template <typename GemmConfig,
typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType, typename DsDataType,
typename AccDataType, typename AccDataType,
@@ -43,7 +44,8 @@ float invoke_gemm_multi_d(const void* a_m_k_dev_buf,
StrideDs, StrideDs,
StrideE}); StrideE});
float ave_time = gemm_multi_d<ADataType, float ave_time = gemm_multi_d<GemmConfig,
ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
AccDataType, AccDataType,
@@ -58,7 +60,8 @@ float invoke_gemm_multi_d(const void* a_m_k_dev_buf,
return ave_time; return ave_time;
} }
template <typename ALayout, template <typename GemmConfig,
typename ALayout,
typename BLayout, typename BLayout,
typename D0Layout, typename D0Layout,
typename D1Layout, typename D1Layout,
@@ -136,7 +139,8 @@ int run_multiple_d_gemm_example_with_layouts(int argc,
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1}; std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
float ave_time = invoke_gemm_multi_d<ADataType, float ave_time = invoke_gemm_multi_d<GemmConfig,
ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
AccDataType, AccDataType,
@@ -239,6 +243,7 @@ int run_multiple_d_gemm_example_with_layouts(int argc,
return pass; return pass;
} }
template <typename GemmConfig>
int run_multiple_d_gemm_example(int argc, char* argv[]) int run_multiple_d_gemm_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
@@ -256,7 +261,7 @@ int run_multiple_d_gemm_example(int argc, char* argv[])
if(a_layout == "R" && b_layout == "C" && ds_layout == "R") if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
{ {
return run_multiple_d_gemm_example_with_layouts( return run_multiple_d_gemm_example_with_layouts<GemmConfig>(
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
} }
else else

View File

@@ -13,6 +13,7 @@
#include "grouped_convolution_utils.hpp" #include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename AccDataType, typename AccDataType,
@@ -36,9 +37,9 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8; constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8; constexpr ck_tile::index_t VectorSizeB = 8;
@@ -139,7 +140,10 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
#include "run_grouped_convolution_bwd_data_example.inc" #include "run_grouped_convolution_bwd_data_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType> template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_data_example_prec_type( int run_grouped_conv_bwd_data_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{ {
@@ -158,6 +162,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{ {
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{}, return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -166,6 +171,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{ {
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{}, return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -174,6 +180,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{ {
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{}, return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -185,6 +192,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
} }
} }
template <typename GemmWarpConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[]) int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
@@ -198,12 +206,12 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::half_t>( return run_grouped_conv_bwd_data_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv); in_layout, wei_layout, out_layout, argc, argv);
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::bf16_t>( return run_grouped_conv_bwd_data_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv); in_layout, wei_layout, out_layout, argc, argv);
} }
else else
@@ -212,4 +220,11 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
} }
} }
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); } int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Wmma>(argc, argv);
#else
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Mfma>(argc, argv);
#endif
}

View File

@@ -13,6 +13,7 @@
#include "grouped_convolution_utils.hpp" #include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename AccDataType, typename AccDataType,
@@ -36,9 +37,9 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8; constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8; constexpr ck_tile::index_t VectorSizeB = 8;
@@ -141,7 +142,10 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
#include "run_grouped_convolution_bwd_weight_example.inc" #include "run_grouped_convolution_bwd_weight_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType> template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_weight_example_prec_type( int run_grouped_conv_bwd_weight_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{ {
@@ -160,6 +164,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{ {
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{}, return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -168,6 +173,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{ {
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{}, return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -176,6 +182,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{ {
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{}, return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -187,6 +194,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(
} }
} }
template <typename GemmWarpConfig>
int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
@@ -200,12 +208,12 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::half_t>( return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv); in_layout, wei_layout, out_layout, argc, argv);
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::bf16_t>( return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv); in_layout, wei_layout, out_layout, argc, argv);
} }
else else
@@ -214,4 +222,11 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
} }
} }
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
#else
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
#endif
}

View File

@@ -13,6 +13,7 @@
#include "grouped_convolution_utils.hpp" #include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename AccDataType, typename AccDataType,
@@ -35,9 +36,9 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8; constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8; constexpr ck_tile::index_t VectorSizeB = 8;
@@ -130,7 +131,10 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
#include "run_grouped_convolution_fwd_example.inc" #include "run_grouped_convolution_fwd_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType> template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_fwd_example_prec_type( int run_grouped_conv_fwd_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{ {
@@ -149,6 +153,7 @@ int run_grouped_conv_fwd_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{ {
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{}, return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -157,6 +162,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{ {
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{}, return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -165,6 +171,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC") else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC")
{ {
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{}, return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType, InPrecType,
WeiPrecType, WeiPrecType,
OutPrecType>( OutPrecType>(
@@ -176,6 +183,7 @@ int run_grouped_conv_fwd_example_prec_type(
} }
} }
template <typename GemmWarpConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[]) int run_grouped_conv_fwd_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
@@ -189,12 +197,12 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run_grouped_conv_fwd_example_prec_type<ck_tile::half_t>( return run_grouped_conv_fwd_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv); in_layout, wei_layout, out_layout, argc, argv);
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run_grouped_conv_fwd_example_prec_type<ck_tile::bf16_t>( return run_grouped_conv_fwd_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv); in_layout, wei_layout, out_layout, argc, argv);
} }
else else
@@ -203,4 +211,11 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
} }
} }
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); } int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<GemmWarpConfig_Wmma>(argc, argv);
#else
return !run_grouped_conv_fwd_example<GemmWarpConfig_Mfma>(argc, argv);
#endif
}

View File

@@ -12,6 +12,20 @@
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp" #include "ck_tile/ops/grouped_convolution.hpp"
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType> template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK, auto calculate_rtol_atol(const ck_tile::index_t GemmK,
const ck_tile::index_t kbatch, const ck_tile::index_t kbatch,
@@ -126,7 +140,3 @@ auto create_args(int argc, char* argv[])
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
// host API
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
const ck_tile::stream_config& s);

View File

@@ -3,6 +3,7 @@
#pragma once #pragma once
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename AccDataType, typename AccDataType,
@@ -15,6 +16,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_repeat) int n_repeat)
{ {
float ave_time = grouped_conv_bwd_data<NDimSpatial, float ave_time = grouped_conv_bwd_data<NDimSpatial,
GemmWarpConfig,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,
@@ -36,6 +38,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
} }
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType = InDataType, typename WeiDataType = InDataType,
typename OutDataType = InDataType, typename OutDataType = InDataType,
@@ -136,6 +139,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial, invoke_grouped_conv_bwd_data<NDimSpatial,
GemmWarpConfig,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,

View File

@@ -3,6 +3,7 @@
#pragma once #pragma once
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename AccDataType, typename AccDataType,
@@ -15,6 +16,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
int n_repeat) int n_repeat)
{ {
float ave_time = grouped_conv_bwd_weight<NDimSpatial, float ave_time = grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,
@@ -36,6 +38,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
} }
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType = InDataType, typename WeiDataType = InDataType,
typename OutDataType = InDataType, typename OutDataType = InDataType,
@@ -136,6 +139,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_weight<NDimSpatial, invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,

View File

@@ -3,6 +3,7 @@
#pragma once #pragma once
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename AccDataType, typename AccDataType,
@@ -15,6 +16,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
int n_repeat) int n_repeat)
{ {
float ave_time = grouped_conv_fwd<NDimSpatial, float ave_time = grouped_conv_fwd<NDimSpatial,
GemmWarpConfig,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,
@@ -36,6 +38,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
} }
template <ck_tile::index_t NDimSpatial, template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType, typename InDataType,
typename WeiDataType = InDataType, typename WeiDataType = InDataType,
typename OutDataType = InDataType, typename OutDataType = InDataType,
@@ -136,6 +139,7 @@ int run_grouped_conv_fwd_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd<NDimSpatial, invoke_grouped_conv_fwd<NDimSpatial,
GemmWarpConfig,
InDataType, InDataType,
WeiDataType, WeiDataType,
AccDataType, AccDataType,

View File

@@ -25,6 +25,7 @@ struct ElementWiseKernel
{ {
return is_wave32() ? kBlockSize / 2 : kBlockSize; return is_wave32() ? kBlockSize / 2 : kBlockSize;
} }
template <typename... XDataType, typename Dims> template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(const Dims lens, CK_TILE_DEVICE void operator()(const Dims lens,
const Dims input_strides, const Dims input_strides,

View File

@@ -529,7 +529,10 @@ struct GroupedConvolutionBackwardDataKernel
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch); return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs) MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs)

View File

@@ -392,7 +392,10 @@ struct GroupedConvolutionBackwardWeightKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs) MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)

View File

@@ -398,7 +398,10 @@ struct GroupedConvolutionForwardKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)