mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
CK Tile FA Training kernels (#1286)
* FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 <danyao12> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -141,6 +141,36 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max(-x + i_x + 1, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min(i_x + XTile - 1 + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
@@ -160,14 +190,14 @@ struct GenericAttentionMask
|
||||
}
|
||||
else
|
||||
{
|
||||
return i_x >= x_end;
|
||||
return i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
@@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max(-x + i_x + 1, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min(i_x + XTile - 1 + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
@@ -283,13 +343,13 @@ struct SimplifiedGenericAttentionMask
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||
|
||||
return i_x < x_start || i_x >= x_end;
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
@@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user