mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Clarifying the using of CK_TILE_HOST and CK_TILE_HOST_DEVICE trying to save compiling time
This commit is contained in:
@@ -18,7 +18,7 @@ struct NRepetitions2DEpilogue
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
@@ -71,7 +71,7 @@ struct MRepetitions2DEpilogue
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
|
||||
@@ -453,7 +453,7 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
@@ -121,7 +121,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
@@ -516,7 +516,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
|
||||
@@ -108,7 +108,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
@@ -125,7 +125,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
@@ -507,7 +507,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
|
||||
@@ -108,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
@@ -125,7 +125,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
@@ -612,7 +612,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
|
||||
@@ -108,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
@@ -125,7 +125,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
@@ -599,7 +599,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
|
||||
@@ -69,7 +69,7 @@ struct HstuBlockMaskWithLocal
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
// handle two special cases first
|
||||
@@ -339,7 +339,7 @@ struct HstuBlockMaskNoLocal
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
@@ -368,7 +368,7 @@ struct HstuBlockMaskNoLocal
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_HOST_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
@@ -406,48 +406,6 @@ struct HstuBlockMaskNoLocal
|
||||
};
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_row_id);
|
||||
col_id = min(col, max_col_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considerred
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
bool res = ((row_id > col_id) || (row == col));
|
||||
|
||||
return res;
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = ((row_id != col_id) || (row == col));
|
||||
|
||||
return res;
|
||||
};
|
||||
};
|
||||
|
||||
// if the whole tile inside the masking area, no need for pixel-by-pixel checking
|
||||
template <index_t TileWidth, index_t TileHeight>
|
||||
CK_TILE_DEVICE constexpr bool IsFullTileInsideMask(index_t i_tile_top,
|
||||
|
||||
Reference in New Issue
Block a user