Merge commit '60d3e8f504edd25569811b25b4b876d0a504b3b8' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-11 15:11:42 +00:00
parent 269824c6bb
commit 9541fc3ef3
22 changed files with 439 additions and 192 deletions

View File

@@ -15,7 +15,8 @@
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -27,54 +28,19 @@ template <typename ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -105,7 +71,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
@@ -119,7 +86,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -131,7 +98,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -207,7 +175,11 @@ int main(int argc, char* argv[])
{
try
{
return !run_batched_gemm_example(argc, argv);
#if CK_TILE_USE_WMMA
return !run_batched_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
return !run_batched_gemm_example<GemmConfigV3>(argc, argv);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -15,25 +15,116 @@
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
struct GemmConfigMemory
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 64;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigV3
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV4
{
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV3_Wmma
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <typename DataType>
struct BatchedGemmTypeConfig;

View File

@@ -22,7 +22,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -64,7 +65,8 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
batch_stride_C,
batch_count};
float ave_time = batched_gemm<ADataType,
float ave_time = batched_gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -79,7 +81,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GemmConfig, typename ALayout, typename BLayout, typename CLayout>
int run_batched_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -170,7 +172,8 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
float ave_time = invoke_batched_gemm<ADataType,
float ave_time = invoke_batched_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
@@ -311,6 +314,7 @@ int run_batched_gemm_example_with_layouts(int argc,
return pass;
}
template <typename GemmConfig>
int run_batched_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -329,7 +333,7 @@ int run_batched_gemm_example(int argc, char* argv[])
// }
if(a_layout == "R" && b_layout == "C")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
return run_batched_gemm_example_with_layouts<GemmConfig>(argc, argv, Row{}, Col{}, Row{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work else if(a_layout == "C" && b_layout == "C")