Update host/device specifiers

This commit is contained in:
PoYen, Chen
2024-07-24 03:45:19 +00:00
parent 6f95239229
commit 5ea60715ea

View File

@@ -58,7 +58,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetQNumElemsPerRead()
CK_TILE_HOST_DEVICE static constexpr auto GetQNumElemsPerRead()
{
using DataType = typename Problem::QDataType;
@@ -121,7 +121,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetKnewNumElemsPerRead()
CK_TILE_HOST_DEVICE static constexpr auto GetKnewNumElemsPerRead()
{
using DataType = typename Problem::KDataType;
@@ -190,7 +190,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVnewDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeVnewDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
@@ -237,7 +237,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_DEVICE static constexpr auto GetRotaryCosSinTileSize()
CK_TILE_HOST_DEVICE static constexpr auto GetRotaryCosSinTileSize()
{
constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
@@ -245,14 +245,14 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
return make_tuple(number<height>{}, number<Problem::kK0>{});
}
else // Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED
else
{
return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
}
}
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
{
using DataType = std::conditional_t<IsRotaryCosSinForQ,
typename Problem::QDataType,