mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] FA bwd kernels optimization (#1397)
* tmp save
* fix batch deterministic bugs
* fix group deterministic bugs
* codegen update
* reorder files
* bias support
* hd256 bias support
* bwd smoke test update
* simplify convert dq
* fix hd256 dropout scratch
* do{}while() -> while(){}
* comments
* remove FmhaBwdTilePartitioner
* save clear_tile
* refactor dropout
* code cleanup
* code cleanup
* comments
* fix epilogue problem
* fix fwd dropout
* group convert_dq opt
* fix dq alignment
* Do not store storerandval in bwd for flash attention integration
* fix hd32 error and boost performance
* revert
* Remove duplicated WarpGemm definitions in the policy file
* dropout patch for mrepeat 16*16
* code sync up
* dq_acc stride
* dq_acc stride stuff
* codegen update
* fwd dropout revert
* fix hd128 scratches and boost performance
* receipt 3 for simplified smoke test
* more strides for fa integration
* fix hd64 scratches and boost performance
* non-iglp pipeline for headdim padding cases
* dpad same as dvpad for flash attention integration
* unpadded lse&d for group mode
* Support unpad layout for group lse
* Support unpad lse layout for splitkv
* Fix stride for splitkv kernel
* fix unpadded lse issue in fwd splitkv
* comment
* solve lds read&write conflicts
* rename
* bias rename
* tile index revert
---------
Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
This commit is contained in:
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
|
||||
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
|
||||
|
||||
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a>(a_vec)
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b>(b_vec)
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@ template <typename AType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
@@ -58,8 +65,15 @@ template <typename AType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC>
|
||||
using WarpGemmMfmaDispatcher = typename impl::
|
||||
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type;
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user