mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
fix macro for exp2; fix warpgemm a/b in transposedC
This commit is contained in:
@@ -30,9 +30,9 @@ set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
# ... because they are auto-generated
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero -v --save-temps -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0 -v --save-temps -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
|
||||
# Allow comparing floating points directly in order to check sentinel values
|
||||
|
||||
@@ -231,7 +231,7 @@ struct FmhaFwdKernel
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead_ratio_qk,
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale * ck_tile::log2e_v<>),
|
||||
#else
|
||||
scale,
|
||||
@@ -320,7 +320,7 @@ struct FmhaFwdKernel
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead_ratio_qk,
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale * ck_tile::log2e_v<>),
|
||||
#else
|
||||
scale,
|
||||
|
||||
@@ -321,7 +321,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
@@ -333,7 +333,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
@@ -392,12 +392,12 @@ struct BlockFmhaPipelineQRKSVS
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
@@ -512,7 +512,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
|
||||
@@ -69,7 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
@@ -364,7 +364,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
@@ -376,7 +376,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
@@ -471,12 +471,12 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -499,7 +499,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
@@ -607,7 +607,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
|
||||
|
||||
@@ -305,7 +305,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>((y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> * type_convert<SaccDataType>((y));
|
||||
@@ -316,7 +316,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
@@ -375,12 +375,12 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -403,7 +403,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
|
||||
@@ -312,7 +312,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
@@ -324,7 +324,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_FMHA_FWD_FAST_EXP2
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
@@ -383,12 +383,12 @@ struct BlockFmhaPipelineQSKSVS
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -411,7 +411,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
@@ -503,7 +503,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_FMHA_FWD_FAST_EXP2
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
|
||||
@@ -309,8 +309,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -320,13 +320,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
|
||||
Reference in New Issue
Block a user