Fix per-batch seqlen computation in the kernel and grid sizing (use max_seqlen) in the forward dispatch for jagged mode

This commit is contained in:
Qianfeng Zhang
2025-04-08 16:37:52 +00:00
parent dc2f72a09f
commit 561d490990
2 changed files with 2 additions and 2 deletions

View File

@@ -518,7 +518,7 @@ struct HstuAttentionFwdKernel
}
batch_offset_o = query_start * kargs.seq_stride_o;
- kargs.seqlen = kargs.seq_offsets_ptr[1] - kargs.seq_offsets_ptr[0];
+ kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier

View File

@@ -116,7 +116,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
}();
dim3 kGridSize =
- HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v);
+ HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;