mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Fixes and updates
This commit is contained in:
@@ -32,6 +32,22 @@ extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, h
|
||||
extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
|
||||
/// @brief Write a host buffer to a file as a raw binary dump.
///
/// Intended for debugging: the batched/jagged forward examples dump the
/// device and host reference outputs so they can be diffed offline.
///
/// @tparam T            element type of the buffer (written byte-for-byte)
/// @param fileName      path of the file to create/overwrite
/// @param data          pointer to the first element; read-only, hence const
/// @param dataNumItems  number of T elements to write
///
/// On failure to open the file a diagnostic is printed and nothing is written.
template <typename T>
void dumpBufferToFile(const char* fileName, const T* data, size_t dataNumItems)
{
    std::ofstream outFile(fileName, std::ios::binary);
    if(outFile)
    {
        // Raw byte dump; cast to const char* because the buffer is never modified.
        outFile.write(reinterpret_cast<const char*>(data), dataNumItems * sizeof(T));
        // No explicit close(): ofstream's destructor flushes and closes (RAII).
        printf("Wrote output to file %s\n", fileName);
    }
    else
    {
        printf("Could not open file %s for writing\n", fileName);
    }
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
@@ -424,6 +440,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_dev.FromDevice(o_host.data());
|
||||
|
||||
dumpBufferToFile("output_dev.dat", o_host.data(), o_host.get_element_space_size());
|
||||
dumpBufferToFile("output_host.dat", o_host_ref.data(), o_host.get_element_space_size());
|
||||
|
||||
auto [rtol, atol] = get_elimit<InOutDataType>();
|
||||
|
||||
res = ck_tile::check_err(
|
||||
|
||||
@@ -118,8 +118,8 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
param.batch_stride_bias,
|
||||
param.batch_stride_o,
|
||||
param.num_targets_ptr,
|
||||
param.window_size,
|
||||
param.contextual_seqlen,
|
||||
param.window_size,
|
||||
param.min_full_attn_seqlen,
|
||||
param.p_drop,
|
||||
param.philox_seed,
|
||||
|
||||
@@ -383,7 +383,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
if(i_loop < num_loops - 1)
|
||||
if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len && i_loop < num_loops - 1)
|
||||
return false;
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
|
||||
@@ -107,8 +107,8 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
param.num_targets_ptr,
|
||||
param.window_size,
|
||||
param.contextual_seqlen,
|
||||
param.window_size,
|
||||
param.min_full_attn_seqlen,
|
||||
param.p_drop,
|
||||
param.philox_seed,
|
||||
|
||||
@@ -29,12 +29,13 @@ struct HstuBlockMasking
|
||||
|
||||
max_uih_len = seqlen_;
|
||||
|
||||
max_uih_len -= contextual_seqlen - 1;
|
||||
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
|
||||
max_uih_len -= num_target;
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
@@ -45,7 +46,7 @@ struct HstuBlockMasking
|
||||
}
|
||||
else
|
||||
{
|
||||
if(contextual_seqlen > 0 && (i_y < contextual_seqlen))
|
||||
if(i_y < contextual_seqlen)
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
|
||||
if constexpr(kUseCausal && !kUseLocal)
|
||||
@@ -101,10 +102,10 @@ struct HstuBlockMasking
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
else
|
||||
result = std::abs(row - col) <= max_attn_len;
|
||||
result = abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seqlen > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seqlen);
|
||||
result = (row >= max_uih_len - min_full_attn_seqlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -33,6 +33,9 @@ template <typename InOutDataType,
|
||||
bool kUseLocal>
|
||||
struct reference_hstu_attention
|
||||
{
|
||||
using HstuMask = HstuBlockMasking<kUseCausal, kUseLocal>;
|
||||
static constexpr bool kHasMask = kUseCausal || kUseLocal;
|
||||
|
||||
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
|
||||
@@ -100,8 +103,13 @@ struct reference_hstu_attention
|
||||
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
HstuBlockMasking<kUseCausal, kUseLocal> mask{
|
||||
max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target};
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{
|
||||
max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target};
|
||||
else
|
||||
return HstuMask{0, contextual_seqlen, 0, seqlen, num_target};
|
||||
}();
|
||||
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < seqlen; sq++)
|
||||
|
||||
Reference in New Issue
Block a user