mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Replace the integer max_seqlen with the float scale_p as the kernel/pipeline parameter
This commit is contained in:
@@ -91,8 +91,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
ck_tile::index_t max_seqlen;
|
||||
float scale_s; // scaling value exerted on the immediate Q@K result
|
||||
float scale_p; // scaling value exerted on the SiLU result
|
||||
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
};
|
||||
@@ -124,8 +124,8 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t seqlen;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
ck_tile::index_t max_seqlen;
|
||||
float scale_s; // scaling value exerted on the immediate Q@K result
|
||||
float scale_p; // scaling value exerted on the SiLU result
|
||||
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
};
|
||||
@@ -257,11 +257,11 @@ struct HstuAttentionFwdKernel
|
||||
seq_stride_o,
|
||||
num_head,
|
||||
-scale_s,
|
||||
seqlen, // max_seqlen
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
1.0f / static_cast<float>(seqlen), // scale_p = 1 / max_seqlen
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
@@ -404,7 +404,7 @@ struct HstuAttentionFwdKernel
|
||||
-1, // seqlen will be updated by another pointer
|
||||
num_head,
|
||||
-scale_s,
|
||||
max_seqlen,
|
||||
1.0f / static_cast<float>(max_seqlen),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
@@ -850,7 +850,7 @@ struct HstuAttentionFwdKernel
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.max_seqlen,
|
||||
kargs.scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}();
|
||||
|
||||
@@ -137,8 +137,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
HstuMask mask,
|
||||
float scale_s,
|
||||
index_t max_seqlen,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -569,12 +569,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
});
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x) {
|
||||
x = x * type_convert<GemmAccDataType>(
|
||||
__builtin_amdgcn_rcpf(static_cast<float>(max_seqlen)));
|
||||
},
|
||||
o_acc);
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<GemmAccDataType>(scale_p); },
|
||||
o_acc);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
@@ -591,8 +587,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
float scale_s,
|
||||
int max_seqlen,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -609,7 +605,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
max_seqlen,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user