Let IsTokenPairInsideMask() return bool type

This commit is contained in:
Qianfeng Zhang
2025-10-15 08:50:48 +00:00
parent fdb89d3e2f
commit bbda3f6f1c
2 changed files with 10 additions and 10 deletions

View File

@@ -442,8 +442,8 @@ struct HstuAttentionFwdPipelineQRKSVS
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) *=
static_cast<CompDataType>(mask.IsTokenPairInsideMask(row, col));
if(!mask.IsTokenPairInsideMask(row, col))
pcomp_tile(i_j_idx) = static_cast<CompDataType>(0.0f);
});
});
}

View File

@@ -210,7 +210,7 @@ struct HstuBlockMaskWithLocal
}
};
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
@@ -226,7 +226,7 @@ struct HstuBlockMaskWithLocal
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return 1;
return true;
}
else
{
@@ -245,7 +245,7 @@ struct HstuBlockMaskWithLocal
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) || in_min_full_scope));
return static_cast<int>(res);
return res;
}
else
{
@@ -254,7 +254,7 @@ struct HstuBlockMaskWithLocal
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
return static_cast<int>(res);
return res;
}
};
@@ -383,7 +383,7 @@ struct HstuBlockMaskNoLocal
};
};
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
@@ -399,7 +399,7 @@ struct HstuBlockMaskNoLocal
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return 1;
return true;
}
else
{
@@ -415,13 +415,13 @@ struct HstuBlockMaskNoLocal
{
bool res = ((row_id > col_id) || (row == col));
return static_cast<int>(res);
return res;
}
else
{
bool res = ((row_id != col_id) || (row == col));
return static_cast<int>(res);
return res;
};
};