Update the seqlen_k_curr inside the first gemm loop

This commit is contained in:
Qianfeng Zhang
2025-04-25 13:59:48 +00:00
parent 7818cce1c3
commit 4a49119d98

View File

@@ -356,7 +356,7 @@ struct HstuAttentionFwdPipelineQRKSVS
set_tile_if(
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
-const auto col = seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{});
+const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
}
@@ -368,14 +368,15 @@ struct HstuAttentionFwdPipelineQRKSVS
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
-const auto col =
-seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{});
+const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
}
}
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
+seqlen_k_curr += kK1;
});
// load one k_tile for next iteration
@@ -415,11 +416,9 @@ struct HstuAttentionFwdPipelineQRKSVS
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
-randval_lds_ptr, seqlen_k_curr, pcomp_tiles[I0], null_randval_window);
+randval_lds_ptr, seqlen_k_curr - kN0, pcomp_tiles[I0], null_randval_window);
}
-seqlen_k_curr += kK1;
auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
@@ -467,13 +466,11 @@ struct HstuAttentionFwdPipelineQRKSVS
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr,
-seqlen_k_curr,
+seqlen_k_curr - kN0 + (i_k1 + 1) * kK1,
pcomp_tiles[number<i_k1 + 1>{}],
null_randval_window);
}
-seqlen_k_curr += kK1;
p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(tile_elementwise_in(