Ck tile gemm padding dim (#1516)

* Support the N dimension padding

* Finished the padding feature for different dimension of K
This commit is contained in:
Thomas Ning
2024-09-18 11:32:29 -07:00
committed by GitHub
parent e84adec3ba
commit 694c300145
4 changed files with 33 additions and 13 deletions

View File

@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(

View File

@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
};
} // namespace ck_tile