mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
remove FmhaBwdTilePartitioner
This commit is contained in:
@@ -104,8 +104,7 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdKTilePartitioner<{F_bn0}>,
|
||||
fmha_bwd_pipeline_{F_idx},
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
|
||||
@@ -517,8 +516,7 @@ using fmha_bwd_dot_do_o_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdQTilePartitioner</* BlockSize = */ 64>,
|
||||
fmha_bwd_dot_do_o_{F_idx}>;
|
||||
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
|
||||
|
||||
using dot_do_o_trait_{F_idx} =
|
||||
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
@@ -641,8 +639,7 @@ using fmha_bwd_convert_dq_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<ck_tile::FmhaBwdQTilePartitioner<{F_bm0}>,
|
||||
fmha_bwd_convert_dq_{F_idx}>;
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
|
||||
|
||||
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
|
||||
|
||||
@@ -23,13 +23,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename FmhaPipeline_,
|
||||
typename KGradEpiloguePipeline_,
|
||||
typename VGradEpiloguePipeline_>
|
||||
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
|
||||
struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
|
||||
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
|
||||
@@ -536,7 +532,17 @@ struct FmhaBwdDQDKDVKernel
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
|
||||
return dim3(
|
||||
batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
{
|
||||
const index_t i_block = blockIdx.z;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.x;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -554,7 +560,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
|
||||
const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
|
||||
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
|
||||
|
||||
@@ -1037,10 +1043,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
|
||||
template <typename FmhaBwdOGradDotO_>
|
||||
struct FmhaBwdOGradDotOKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
|
||||
@@ -1189,7 +1194,16 @@ struct FmhaBwdOGradDotOKernel
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -1199,7 +1213,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
|
||||
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
|
||||
|
||||
@@ -1286,10 +1300,9 @@ struct FmhaBwdOGradDotOKernel
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaBwdConvertQGrad_>
|
||||
template <typename FmhaBwdConvertQGrad_>
|
||||
struct FmhaBwdConvertQGradKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
|
||||
@@ -1439,7 +1452,16 @@ struct FmhaBwdConvertQGradKernel
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -1449,7 +1471,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
|
||||
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
|
||||
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t kN0>
|
||||
struct FmhaBwdKTilePartitioner
|
||||
{
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.z;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.x;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
template <ck_tile::index_t kM0>
|
||||
struct FmhaBwdQTilePartitioner
|
||||
{
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user