mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] FA bwd repair (#1502)
* fix fa bwd * revert kernelBlockSize in gemm_kernel.hpp
This commit is contained in:
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
#if 1 // coalesce reading for each blocks
|
||||
constexpr index_t M1 = KernelBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
@@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#else // coalesce reading for each warps
|
||||
constexpr index_t M0 = KernelBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = kMPerBlock / (M2 * M0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
@@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
#if 1 // coalesce reading for each blocks
|
||||
constexpr index_t N1 = KernelBlockSize / get_warp_size();
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
@@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#else // coalesce reading for each warps
|
||||
constexpr index_t N0 = KernelBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N1 = kNPerBlock / (N2 * N0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
|
||||
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -23,10 +23,10 @@ struct BlockGemmPipelineProblem
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
|
||||
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
|
||||
|
||||
Reference in New Issue
Block a user