mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Simplify the code in all host/device IsTokenPairInsideMask() implementations to reduce branching
This commit is contained in:
@@ -130,193 +130,99 @@ struct HstuBlockMaskWithLocal
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = max(row - contextual_seqlen + 1, 0);
|
||||
int col_id = max(col - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
(row_id - col_id <= max_attn_len));
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
};
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
int row_id = min(row, max_id);
|
||||
int col_id = min(col, max_id);
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
(row_id - col_id <= max_attn_len));
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
};
|
||||
}
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return false;
|
||||
return (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
}
|
||||
else
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = max(row - contextual_seqlen + 1, 0);
|
||||
int col_id = max(col - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return 1;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res =
|
||||
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = min(row, max_id);
|
||||
int col_id = min(col, max_id);
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res =
|
||||
(((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len));
|
||||
bool res = (((row_id > col_id) || (row == col)) &&
|
||||
((row_id - col_id <= max_attn_len) || in_min_full_scope));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(min_full_attn_seqlen > 0)
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) ||
|
||||
(row_id >= max_id - min_full_attn_seqlen)));
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
(abs(row_id - col_id) <= max_attn_len));
|
||||
bool res = (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -397,104 +303,84 @@ struct HstuBlockMaskNoLocal
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = max(row - contextual_seqlen + 1, 0);
|
||||
int col_id = max(col - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return true;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = min(row, max_id);
|
||||
int col_id = min(col, max_id);
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
};
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
return (row_id > col_id) || (row == col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
};
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = max(row - contextual_seqlen + 1, 0);
|
||||
int col_id = max(col - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_id);
|
||||
col_id = min(col_id, max_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_id)
|
||||
return 1;
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
bool res = ((row_id > col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = ((row_id != col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
int row_id = min(row, max_id);
|
||||
int col_id = min(col, max_id);
|
||||
row_id = min(row, max_id);
|
||||
col_id = min(col, max_id);
|
||||
};
|
||||
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
bool res = ((row_id > col_id) || (row == col));
|
||||
// use row_id/col_id to check the dist between two q/k token pair, token pairs on the
|
||||
// diagonal line are always considered
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
bool res = ((row_id > col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = ((row_id != col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
return static_cast<int>(res);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool res = ((row_id != col_id) || (row == col));
|
||||
|
||||
return static_cast<int>(res);
|
||||
};
|
||||
};
|
||||
|
||||
// if the whole tile inside the masking area, no need for pixel-by-pixel checking
|
||||
|
||||
Reference in New Issue
Block a user