mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Update HstuBlockMaskWithLocal::GetTileRangeAlongX, add comments and test cases for causal == false
This commit is contained in:
@@ -58,82 +58,106 @@ struct HstuBlockMaskWithLocal
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
// handle two special cases first
|
||||
if(!is_tile_in_first_split)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
|
||||
if constexpr(!kUseCausal)
|
||||
{
|
||||
if(i_y >= contextual_seqlen)
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
// tile is partially or completely in [max_uih_len-min_full_attn_seqlen,
|
||||
// max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
{
|
||||
index_t x_start = max(0, i_y - max_attn_len);
|
||||
return ck_tile::make_tuple(0, seqlen);
|
||||
}
|
||||
else // tile is completely inside [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
// is_tile_in_first_split is true, either min_full_attn_seqlen is 0 or tile is
|
||||
// in [0, max_uih_len-min_full_attn_seqlen)
|
||||
if constexpr(!kUseCausal)
|
||||
{
|
||||
if(i_y >= contextual_seqlen + max_attn_len)
|
||||
{
|
||||
// some rows of the tile are in [contextual_seqlen+max_attn_len, max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
index_t x_start_aligned = x_start - x_start % XTile;
|
||||
|
||||
// some rows of the tile are in [max_uih_len-max_attn_len, max_uih_len)
|
||||
if(i_y + YTile > max_uih_len - max_attn_len)
|
||||
{
|
||||
return ck_tile::make_tuple(x_start_aligned, seqlen);
|
||||
}
|
||||
else
|
||||
else // whole tile in [contextual_seqlen+max_attn_len, max_uih_len
|
||||
// -max_attn_len)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile + max_attn_len, seqlen);
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
return ck_tile::make_tuple(x_start_aligned, x_end);
|
||||
};
|
||||
}
|
||||
else
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
index_t x_end = seqlen;
|
||||
index_t x_start = max_uih_len - max_attn_len;
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // for i_y < contextual_seqlen + max_attn_len
|
||||
{
|
||||
if(i_y + YTile > max_uih_len)
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
index_t x_end = min(max(i_y + YTile + max_attn_len, max_uih_len), seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = max(i_y + YTile + max_attn_len, max_uih_len);
|
||||
index_t x_end = min(i_y + YTile + max_attn_len, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
else // kUseCausal && kUseLocal
|
||||
{
|
||||
if(i_y >= contextual_seqlen)
|
||||
if(i_y >= contextual_seqlen + max_attn_len)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
|
||||
// some rows of the tile are in [contextual_seqlen+max_attn_len, max_uih_len)
|
||||
if(i_y < max_uih_len)
|
||||
{
|
||||
index_t x_start = max(0, i_y - max_attn_len);
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
else
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = max_uih_len - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // for i_y < contextual_seqlen + max_attn_len
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
|
||||
if(i_y + YTile > max_uih_len)
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile, max_uih_len), seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
};
|
||||
index_t x_end = min(i_y + YTile, seqlen);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -50,5 +50,7 @@ for dtype in "fp16" "bf16"; do
|
||||
## jagged causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale
|
||||
|
||||
## jagged no-causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale
|
||||
set +x
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user