mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] FA bwd repair (#1502)
* fix fa bwd * revert kernelBlockSize in gemm_kernel.hpp
This commit is contained in:
@@ -29,9 +29,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
|
||||
Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kK0>,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
@@ -62,9 +62,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kVHeaddim,
|
||||
Problem::BlockFmhaShape::BlockTile::kK1>,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
@@ -94,9 +94,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
|
||||
Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kK2>,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK2>,
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
@@ -127,9 +127,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::BlockTile::kK3>,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
|
||||
@@ -159,9 +159,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
|
||||
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::BlockTile::kK4>,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user