Replace the integer max_seqlen with the float scale_p as the kernel/pipeline parameter

This commit is contained in:
Qianfeng Zhang
2025-08-01 07:49:57 +00:00
parent de71d3359c
commit 7c9032d2cf
2 changed files with 18 additions and 22 deletions

View File

@@ -91,8 +91,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t seq_stride_o;
ck_tile::index_t num_head;
float scale_s;
ck_tile::index_t max_seqlen;
float scale_s; // scaling value exerted on the immediate Q@K result
float scale_p; // scaling value exerted on the SiLU result
ck_tile::index_t contextual_seqlen;
};
@@ -124,8 +124,8 @@ struct HstuAttentionFwdKernel
ck_tile::index_t seqlen;
ck_tile::index_t num_head;
float scale_s;
ck_tile::index_t max_seqlen;
float scale_s; // scaling value exerted on the immediate Q@K result
float scale_p; // scaling value exerted on the SiLU result
ck_tile::index_t contextual_seqlen;
};
@@ -257,11 +257,11 @@ struct HstuAttentionFwdKernel
seq_stride_o,
num_head,
-scale_s,
seqlen, // max_seqlen
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
{}, // placeholder for dropout
1.0f / static_cast<float>(seqlen), // scale_p: reciprocal of seqlen (which serves as max_seqlen here)
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
{}, // placeholder for dropout
};
if constexpr(kHasLocalMask)
@@ -404,7 +404,7 @@ struct HstuAttentionFwdKernel
-1, // seqlen will be updated by another pointer
num_head,
-scale_s,
max_seqlen,
1.0f / static_cast<float>(max_seqlen),
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
@@ -850,7 +850,7 @@ struct HstuAttentionFwdKernel
bias_dram_window,
mask,
kargs.scale_s,
kargs.max_seqlen,
kargs.scale_p,
smem_ptr,
dropout);
}();

View File

@@ -137,8 +137,8 @@ struct HstuAttentionFwdPipelineQRKSVS
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
HstuMask mask,
float scale_s,
index_t max_seqlen,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
void* smem_ptr,
DropoutType& dropout) const
{
@@ -569,12 +569,8 @@ struct HstuAttentionFwdPipelineQRKSVS
});
} while(seqlen_k_curr < seqlen_k_end);
tile_elementwise_inout(
[&](auto& x) {
x = x * type_convert<GemmAccDataType>(
__builtin_amdgcn_rcpf(static_cast<float>(max_seqlen)));
},
o_acc);
tile_elementwise_inout([&](auto& x) { x = x * type_convert<GemmAccDataType>(scale_p); },
o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
@@ -591,8 +587,8 @@ struct HstuAttentionFwdPipelineQRKSVS
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
HstuMask mask,
float scale_s,
int max_seqlen,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
void* smem_ptr,
DropoutType& dropout) const
{
@@ -609,7 +605,7 @@ struct HstuAttentionFwdPipelineQRKSVS
identity{},
mask,
scale_s,
max_seqlen,
scale_p,
smem_ptr,
dropout);
}