[CK_TILE] FA bwd repair (#1502)

* fix fa bwd

* revert kernelBlockSize in gemm_kernel.hpp
This commit is contained in:
Dan Yao
2024-09-11 01:45:32 +08:00
committed by GitHub
parent cf08df6b5e
commit d09572e8c2
6 changed files with 28 additions and 28 deletions

View File

@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;

View File

@@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
@@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
constexpr index_t M1 = KernelBlockSize / get_warp_size();
constexpr index_t M1 = kBlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = kMPerBlock / (M2 * M1);
@@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
#else // coalesce reading for each warps
constexpr index_t M0 = KernelBlockSize / get_warp_size();
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M1 = kMPerBlock / (M2 * M0);
return make_static_tile_distribution(
@@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
@@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
constexpr index_t N1 = KernelBlockSize / get_warp_size();
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
@@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
#else // coalesce reading for each warps
constexpr index_t N0 = KernelBlockSize / get_warp_size();
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N1 = kNPerBlock / (N2 * N0);
return make_static_tile_distribution(

View File

@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;

View File

@@ -23,10 +23,10 @@ struct BlockGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;