mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Fp8 block scale quantization for fmha fwd (#3330)
* add block scale parameters to kernel * add block scale to kernel * add smoke test * format * Revert "format" This reverts commit356c3c9706. * only format my code * format py * fix auto not allowd in function prototype * change instance tttt to ttff * fix structured binding issue * change s_acc elementwise op * async pipeline add block scale * add quantation P using shift exp2 * precompute (m - shift) once per row * change blk scale seqstrt ptr name * fix some name * fix for deduction guide * fix some comments * add P scale to qr_ksvs_pipeline * add comment to idx_identity * change the method of calculating descale block index * unify naming style: use block_scale_ as name prefix * unify naming style * update the CHANGELOG.md * Add FP8 block scale quantization support for FMHA forward kernel --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> [ROCm/composable_kernel commit:dd0b4294af]
This commit is contained in:
@@ -37,6 +37,13 @@ struct scales
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
|
||||
template <typename OtherScale>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const
|
||||
{
|
||||
auto new_scale = lhs_ * other;
|
||||
return scales<std::decay_t<decltype(new_scale)>>(new_scale);
|
||||
}
|
||||
|
||||
private:
|
||||
Scale lhs_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user