mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Let IsTokenPairInsideMask() return bool type
This commit is contained in:
@@ -442,8 +442,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
pcomp_tile(i_j_idx) *=
|
||||
static_cast<CompDataType>(mask.IsTokenPairInsideMask(row, col));
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
pcomp_tile(i_j_idx) = static_cast<CompDataType>(0.0f);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ struct HstuBlockMaskWithLocal
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
@@ -226,7 +226,7 @@ struct HstuBlockMaskWithLocal
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return 1;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -245,7 +245,7 @@ struct HstuBlockMaskWithLocal
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
|
||||
return static_cast<int>(res);
|
||||
return res;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -254,7 +254,7 @@ struct HstuBlockMaskWithLocal
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
|
||||
return static_cast<int>(res);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -383,7 +383,7 @@ struct HstuBlockMaskNoLocal
|
||||
};
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
@@ -399,7 +399,7 @@ struct HstuBlockMaskNoLocal
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return 1;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -415,13 +415,13 @@ struct HstuBlockMaskNoLocal
|
||||
{
|
||||
bool res = ((row_id > col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
return res;
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = ((row_id != col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
return res;
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user