Clarifying the using of CK_TILE_HOST and CK_TILE_HOST_DEVICE trying to save compiling time

This commit is contained in:
Qianfeng Zhang
2025-11-03 08:39:43 +00:00
parent e31829384d
commit e40ab20b9e
7 changed files with 18 additions and 60 deletions

View File

@@ -18,7 +18,7 @@ struct NRepetitions2DEpilogue
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename ODramWindowTmp,
typename OAccTile,
@@ -71,7 +71,7 @@ struct MRepetitions2DEpilogue
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename ODramWindowTmp,
typename OAccTile,

View File

@@ -453,7 +453,7 @@ struct HstuAttentionFwdKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}

View File

@@ -104,7 +104,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
@@ -121,7 +121,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
@@ -516,7 +516,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile

View File

@@ -108,7 +108,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
@@ -125,7 +125,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
@@ -507,7 +507,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile

View File

@@ -108,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
@@ -125,7 +125,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
@@ -612,7 +612,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile

View File

@@ -108,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
@@ -125,7 +125,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
@@ -599,7 +599,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile

View File

@@ -69,7 +69,7 @@ struct HstuBlockMaskWithLocal
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
// i_y is the start offset of the current tile along the seqlen_q dimension
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
CK_TILE_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
// handle two special cases first
@@ -339,7 +339,7 @@ struct HstuBlockMaskNoLocal
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
// i_y is the start offset of the current tile along the seqlen_q dimension
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
CK_TILE_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
@@ -368,7 +368,7 @@ struct HstuBlockMaskNoLocal
};
}
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
CK_TILE_HOST_DEVICE bool IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
@@ -406,48 +406,6 @@ struct HstuBlockMaskNoLocal
};
};
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_col_id)
return true;
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = min(row, max_row_id);
col_id = min(col, max_col_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
bool res = ((row_id > col_id) || (row == col));
return res;
}
else
{
bool res = ((row_id != col_id) || (row == col));
return res;
};
};
// if the whole tile inside the masking area, no need for pixel-by-pixel checking
template <index_t TileWidth, index_t TileHeight>
CK_TILE_DEVICE constexpr bool IsFullTileInsideMask(index_t i_tile_top,