Hacking ck_tile fmha Dropout facility (#1344)

* Add NullBlockDropout to be used when kHasDropout is false

* Change to BlockDropout::Run() for forward to reduce conditional checkings

* Re-format files

---------

Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
Qianfeng
2024-06-19 10:37:22 +08:00
committed by GitHub
parent 8faec23cb4
commit 1973903f49
4 changed files with 79 additions and 46 deletions

View File

@@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS
static constexpr const char* name = "qr";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
@@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
DropoutType& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -246,7 +248,7 @@ struct BlockFmhaPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
@@ -486,7 +488,7 @@ struct BlockFmhaPipelineQRKSVS
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
}
@@ -618,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},

View File

@@ -112,6 +112,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr const char* name = "qr_async";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
@@ -151,7 +153,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
DropoutType& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -298,7 +300,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
@@ -571,7 +573,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
@@ -728,7 +730,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},