mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix hd32 error and boost performance
This commit is contained in:
@@ -448,7 +448,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# '32' : [FmhaBwdDQDKDVTileSize( 64, 64, 32, 64, 32, 64, 64, 32, 32, 1, 2, 1, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
"kr_ktr_vr"],
|
||||
|
||||
@@ -660,7 +660,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
}();
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer();
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem,
|
||||
decltype(pt_reg_tensor),
|
||||
decltype(pt_gemm)>(pt_reg_tensor, pt_gemm);
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
@@ -732,7 +734,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer();
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(dst_gemm)>(dst_reg_tensor, dst_gemm);
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
@@ -908,8 +912,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
}
|
||||
}();
|
||||
|
||||
pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer();
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(pt_gemm)>(
|
||||
pt_reg_tensor, pt_gemm);
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
@@ -965,7 +970,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto dst_gemm = cast_tile<GemmDataType>(dst);
|
||||
|
||||
dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer();
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(dst_gemm)>(dst_reg_tensor, dst_gemm);
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
store_tile(ds_lds_window, dst_gemm);
|
||||
|
||||
@@ -1508,6 +1508,116 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return ds_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem, typename PTOutTensor, typename PTInTensor>
|
||||
CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
|
||||
const PTInTensor& pt_in)
|
||||
{
|
||||
if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16)
|
||||
{
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
auto pt_warp_tensor =
|
||||
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
pt_warp_tensor.get_thread_buffer() = pt_in.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
pt_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
pt_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
pt_out.get_thread_buffer() = pt_in.get_thread_buffer();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, typename SGradTOutTensor, typename SGradTInTensor>
|
||||
CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
|
||||
const SGradTInTensor& dst_in)
|
||||
{
|
||||
if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16)
|
||||
{
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
auto dst_warp_tensor =
|
||||
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
dst_warp_tensor.get_thread_buffer() = dst_in.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dst_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
dst_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_out.get_thread_buffer() = dst_in.get_thread_buffer();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user