now can build

This commit is contained in:
carlushuang
2024-03-04 20:45:51 +00:00
parent 112d521b09
commit a67473fff8
55 changed files with 829 additions and 534 deletions

View File

@@ -575,9 +575,8 @@ struct FmhaFwdKernel
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
/// following copy capture of the 'i_nhead'
/// if compiled in C++20
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
/// the following copy capture of 'i_nhead' once compiled as C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -189,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -208,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -346,12 +347,15 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -360,7 +364,7 @@ struct BlockFmhaPipelineQRKSVS
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -375,7 +379,7 @@ struct BlockFmhaPipelineQRKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -231,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
@@ -251,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -389,12 +390,15 @@ struct BlockFmhaPipelineQRKSVSAsync
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -403,7 +407,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -454,7 +458,7 @@ struct BlockFmhaPipelineQRKSVSAsync
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -181,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -329,12 +330,15 @@ struct BlockFmhaPipelineQRKSVSFp8
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -343,7 +347,7 @@ struct BlockFmhaPipelineQRKSVSFp8
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -358,7 +362,7 @@ struct BlockFmhaPipelineQRKSVSFp8
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
@@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -NumericLimits<SMPLComputeDataType>::Infinity());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -338,12 +338,15 @@ struct BlockFmhaPipelineQSKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -NumericLimits<SMPLComputeDataType>::Infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -352,7 +355,7 @@ struct BlockFmhaPipelineQSKSVS
s,
sequence<1>{},
f_max,
-NumericLimits<SMPLComputeDataType>::Infinity()); // m_local = rowmax(S{j})
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -367,7 +370,7 @@ struct BlockFmhaPipelineQSKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -NumericLimits<SMPLComputeDataType>::Infinity()
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -9,6 +9,11 @@
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
@@ -97,9 +102,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
@@ -222,9 +226,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
typename Problem::KDataType>,
2,
swizzle_factor>>{};
}
@@ -918,12 +921,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
auto warp_gemm = [&]() {
if constexpr(Problem::kIsFp8)
{
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {