mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM * Updates * Fixes * rebase * fix * Fix * fixes * support for batched gemm
This commit is contained in:
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
|
||||
@@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
|
||||
@@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
|
||||
@@ -13,6 +13,8 @@ namespace ck_tile {
|
||||
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
#if 0
|
||||
// 2d
|
||||
template <typename Problem>
|
||||
@@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
index_t smem_size = 0;
|
||||
smem_size += smem_size_a + smem_size_b;
|
||||
constexpr index_t smem_size = smem_size_a + smem_size_b;
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
@@ -485,13 +486,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
|
||||
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
|
||||
@@ -444,6 +444,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user