mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Fix in kernel and forward dispatch for jagged mode
This commit is contained in:
@@ -518,7 +518,7 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
batch_offset_o = query_start * kargs.seq_stride_o;
|
||||
|
||||
- kargs.seqlen = kargs.seq_offsets_ptr[1] - kargs.seq_offsets_ptr[0];
+ kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
|
||||
@@ -116,7 +116,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
- HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v);
+ HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user