Simplify the codes in all host/device IsTokenPairInsideMask() trying to reduce branching

This commit is contained in:
Qianfeng Zhang
2025-06-23 14:13:55 +00:00
parent 63a47d7ec5
commit dc7e62a658

View File

@@ -130,193 +130,99 @@ struct HstuBlockMaskWithLocal
};
}
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = max(row - contextual_seqlen + 1, 0);
int col_id = max(col - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return true;
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(kUseCausal)
{
if(min_full_attn_seqlen > 0)
{
return (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id > col_id) || (row == col)) &&
(row_id - col_id <= max_attn_len));
};
}
else
{
if(min_full_attn_seqlen > 0)
{
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
};
}
return false;
}
else
{
int row_id = min(row, max_id);
int col_id = min(col, max_id);
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = min(row, max_id);
col_id = min(col, max_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(kUseCausal)
{
if(min_full_attn_seqlen > 0)
{
return (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id > col_id) || (row == col)) &&
(row_id - col_id <= max_attn_len));
};
}
else
{
if(min_full_attn_seqlen > 0)
{
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
}
else
{
return (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
};
}
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(kUseCausal)
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
return false;
return (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) || in_min_full_scope));
}
else
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
}
};
CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col)
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = max(row - contextual_seqlen + 1, 0);
int col_id = max(col - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return 1;
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(kUseCausal)
{
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res);
}
else
{
bool res =
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
return static_cast<int>(res);
};
}
else
{
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res);
}
else
{
bool res = (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
return static_cast<int>(res);
};
}
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = min(row, max_id);
int col_id = min(col, max_id);
row_id = min(row, max_id);
col_id = min(col, max_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(kUseCausal)
{
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(kUseCausal)
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
return static_cast<int>(res);
}
else
{
bool res =
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
bool res = (((row_id > col_id) || (row == col)) &&
((row_id - col_id <= max_attn_len) || in_min_full_scope));
return static_cast<int>(res);
};
}
else
{
if(min_full_attn_seqlen > 0)
{
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) ||
(row_id >= max_id - min_full_attn_seqlen)));
return static_cast<int>(res);
}
else
{
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
return static_cast<int>(res);
}
else
{
bool res = (((row_id != col_id) || (row == col)) &&
(abs(row_id - col_id) <= max_attn_len));
bool res = (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
return static_cast<int>(res);
};
}
return static_cast<int>(res);
}
};
@@ -397,104 +303,84 @@ struct HstuBlockMaskNoLocal
};
}
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = max(row - contextual_seqlen + 1, 0);
int col_id = max(col - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return true;
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
return (row_id > col_id) || (row == col);
}
else
{
return (row_id != col_id) || (row == col);
};
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = min(row, max_id);
int col_id = min(col, max_id);
row_id = min(row, max_id);
col_id = min(col, max_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
return (row_id > col_id) || (row == col);
}
else
{
return (row_id != col_id) || (row == col);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
return (row_id > col_id) || (row == col);
}
else
{
return (row_id != col_id) || (row == col);
};
};
CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col)
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
{
int row_id;
int col_id;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = max(row - contextual_seqlen + 1, 0);
int col_id = max(col - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, 0);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_id);
col_id = min(col_id, max_id);
if(row_id == 0 && col_id < max_id)
return 1;
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
bool res = ((row_id > col_id) || (row == col));
return static_cast<int>(res);
}
else
{
bool res = ((row_id != col_id) || (row == col));
return static_cast<int>(res);
};
}
else
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
int row_id = min(row, max_id);
int col_id = min(col, max_id);
row_id = min(row, max_id);
col_id = min(col, max_id);
};
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
bool res = ((row_id > col_id) || (row == col));
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
// diagonal line are always considerred
if constexpr(IsMasking)
{
bool res = ((row_id > col_id) || (row == col));
return static_cast<int>(res);
}
else
{
bool res = ((row_id != col_id) || (row == col));
return static_cast<int>(res);
};
return static_cast<int>(res);
}
else
{
bool res = ((row_id != col_id) || (row == col));
return static_cast<int>(res);
};
};
// if the whole tile inside the masking area, no need for pixel-by-pixel checking