From 7df39478193d4fd4eb84ce5c287c0f7452f67a1f Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 6 Mar 2024 15:59:21 +0000 Subject: [PATCH] fix macro for exp2; fix warpgemm a/b in transposedC --- example/ck_tile/01_fmha/CMakeLists.txt | 4 ++-- .../ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 4 ++-- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 12 ++++++------ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 14 +++++++------- .../pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp | 10 +++++----- .../fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 12 ++++++------ .../ops/gemm/warp/warp_gemm_attribute_mfma.hpp | 12 ++++++------ 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 2fded9f0e1..f5434e7016 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -30,9 +30,9 @@ set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero -v --save-temps -Wno-gnu-line-marker) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0 -v --save-temps -Wno-gnu-line-marker) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() # Allow comparing floating points directly in order to check sentinel values diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 7385058688..1fe6415453 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -231,7 +231,7 @@ struct FmhaFwdKernel hdim_q, hdim_v, nhead_ratio_qk, -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale * ck_tile::log2e_v<>), #else scale, @@ -320,7 +320,7 @@ struct FmhaFwdKernel hdim_q, hdim_v, nhead_ratio_qk, -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale * ck_tile::log2e_v<>), #else scale, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index cbfee530e6..c2953fc2ea 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -321,7 +321,7 @@ struct BlockFmhaPipelineQRKSVS { tile_elementwise_inout( [&](auto& x, const auto& y) { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 x = scale * x + type_convert(bias_element_func(y)); #else x = scale * x + log2e_v * @@ -333,7 +333,7 @@ struct BlockFmhaPipelineQRKSVS } else { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif } @@ -392,12 +392,12 @@ struct BlockFmhaPipelineQRKSVS constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVS constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { if constexpr(kHasBias) { @@ -512,7 +512,7 @@ struct BlockFmhaPipelineQRKSVS constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 921fb5f2b8..76c20bfe46 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -69,7 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; #endif @@ -364,7 +364,7 @@ struct BlockFmhaPipelineQRKSVSAsync { tile_elementwise_inout( [&](auto& x, const auto& y) { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 x = scale * x + type_convert(bias_element_func(y)); #else x = scale * x + log2e_v * @@ -376,7 +376,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif } @@ -471,12 +471,12 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -499,7 +499,7 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { if constexpr(kHasBias) { @@ -607,7 +607,7 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 87180c4ed4..0643e7c0d9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -305,7 +305,7 @@ struct BlockFmhaPipelineQRKSVSFp8 { tile_elementwise_inout( [&](auto& x, const auto& y) { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 x = scale * x + type_convert((y)); #else x = scale * x + log2e_v * type_convert((y)); @@ -316,7 +316,7 @@ struct BlockFmhaPipelineQRKSVSFp8 } else { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif } @@ -375,12 +375,12 @@ struct BlockFmhaPipelineQRKSVSFp8 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -403,7 +403,7 @@ struct BlockFmhaPipelineQRKSVSFp8 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { if constexpr(kHasBias) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index a4f687bcee..e7fa19449b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -312,7 +312,7 @@ struct BlockFmhaPipelineQSKSVS { tile_elementwise_inout( [&](auto& x, const auto& y) { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 x = scale * x + type_convert(bias_element_func(y)); #else x = scale * x + log2e_v * @@ -324,7 +324,7 @@ struct BlockFmhaPipelineQSKSVS } else { -#if !CK_FMHA_FWD_FAST_EXP2 +#if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif } @@ -383,12 +383,12 @@ struct BlockFmhaPipelineQSKSVS constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -411,7 +411,7 @@ struct BlockFmhaPipelineQSKSVS constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { if constexpr(kHasBias) { @@ -503,7 +503,7 @@ struct BlockFmhaPipelineQSKSVS constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_FMHA_FWD_FAST_EXP2 +#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(kHasBias) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 79ed1b6e15..420870e61d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -309,8 +309,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution // swap A and B, value and type static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - a_vec.template get_as()[iKIter], - b_vec.template get_as()[iKIter]); + b_vec.template get_as()[iKIter], + a_vec.template get_as()[iKIter]); }); } @@ -320,13 +320,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution constexpr auto I0 = number<0>{}; // swap A and B, value and type - auto c_vec = Impl{}(a_vec.template get_as()[I0], - b_vec.template get_as()[I0]); + auto c_vec = Impl{}(b_vec.template get_as()[I0], + a_vec.template get_as()[I0]); static_for<1, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - a_vec.template get_as()[iKIter], - b_vec.template get_as()[iKIter]); + b_vec.template get_as()[iKIter], + a_vec.template get_as()[iKIter]); }); return c_vec;