Fix the exp2 macro name (CK_FMHA_FWD_FAST_EXP2 -> CK_TILE_FMHA_FWD_FAST_EXP2); fix the warp-GEMM A/B operand order in the transposed-C distribution

This commit is contained in:
carlushuang
2024-03-06 15:59:21 +00:00
parent 0e7df1999f
commit 7df3947819
7 changed files with 34 additions and 34 deletions

View File

@@ -231,7 +231,7 @@ struct FmhaFwdKernel
hdim_q,
hdim_v,
nhead_ratio_qk,
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale * ck_tile::log2e_v<>),
#else
scale,
@@ -320,7 +320,7 @@ struct FmhaFwdKernel
hdim_q,
hdim_v,
nhead_ratio_qk,
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale * ck_tile::log2e_v<>),
#else
scale,

View File

@@ -321,7 +321,7 @@ struct BlockFmhaPipelineQRKSVS
{
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
#else
x = scale * x + log2e_v<SaccDataType> *
@@ -333,7 +333,7 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
}
@@ -392,12 +392,12 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
{
@@ -512,7 +512,7 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);

View File

@@ -69,7 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
@@ -364,7 +364,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
#else
x = scale * x + log2e_v<SaccDataType> *
@@ -376,7 +376,7 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
}
@@ -471,12 +471,12 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -499,7 +499,7 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
{
@@ -607,7 +607,7 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);

View File

@@ -305,7 +305,7 @@ struct BlockFmhaPipelineQRKSVSFp8
{
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>((y));
#else
x = scale * x + log2e_v<SaccDataType> * type_convert<SaccDataType>((y));
@@ -316,7 +316,7 @@ struct BlockFmhaPipelineQRKSVSFp8
}
else
{
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
}
@@ -375,12 +375,12 @@ struct BlockFmhaPipelineQRKSVSFp8
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -403,7 +403,7 @@ struct BlockFmhaPipelineQRKSVSFp8
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
{

View File

@@ -312,7 +312,7 @@ struct BlockFmhaPipelineQSKSVS
{
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
#else
x = scale * x + log2e_v<SaccDataType> *
@@ -324,7 +324,7 @@ struct BlockFmhaPipelineQSKSVS
}
else
{
#if !CK_FMHA_FWD_FAST_EXP2
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
}
@@ -383,12 +383,12 @@ struct BlockFmhaPipelineQSKSVS
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -411,7 +411,7 @@ struct BlockFmhaPipelineQSKSVS
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
{
@@ -503,7 +503,7 @@ struct BlockFmhaPipelineQSKSVS
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_FMHA_FWD_FAST_EXP2
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);

View File

@@ -309,8 +309,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
@@ -320,13 +320,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
constexpr auto I0 = number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
b_vec.template get_as<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
a_vec.template get_as<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
return c_vec;