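Tiny update in GetMaxVectorSize(): disable the 6-element (half/bf16) and 3-element (float) vector sizes until ck_tile supports buffer_load_dwordx3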

This commit is contained in:
Qianfeng Zhang
2025-12-14 04:26:30 +00:00
parent f79a29ac80
commit 1ab5e9da93

View File

@@ -60,10 +60,11 @@ struct HstuAttentionFwdPipelineProblem
{
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
{
// TODO: re-enable the 6-element case below once ck_tile supports buffer_load_dwordx3
// if constexpr(ElemPerThread % 6 == 0)
// return 6;
if constexpr(ElemPerThread % 8 == 0)
return 8;
else if constexpr(ElemPerThread % 6 == 0)
return 6;
else if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 2 == 0)
@@ -72,10 +73,11 @@ struct HstuAttentionFwdPipelineProblem
}
else if constexpr(std::is_same_v<DataType, float>)
{
// TODO: re-enable the 3-element case below once ck_tile supports buffer_load_dwordx3
// if constexpr(ElemPerThread % 3 == 0)
// return 3;
if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 3 == 0)
return 3;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;