Remove f8 pipeline, we should share the same pipeline even in f8

This commit is contained in:
rocking
2024-04-08 09:56:23 +00:00
parent f7d81364f3
commit 5c3fdeb0b8
5 changed files with 30 additions and 94 deletions

View File

@@ -166,13 +166,6 @@ struct FmhaFwdKernel
ck_tile::index_t mask_y, mask_x;
};
struct FmhaFwdFP8Kargs
{
float descale_qk; // q*k
float descale_sv; // s*v
// float * o_amax_ptr;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -188,8 +181,7 @@ struct FmhaFwdKernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kIsFp8, FmhaFwdFP8Kargs, FmhaFwdEmptyKargs<3>>
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -201,8 +193,7 @@ struct FmhaFwdKernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kIsFp8, FmhaFwdFP8Kargs, FmhaFwdEmptyKargs<3>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -251,9 +242,7 @@ struct FmhaFwdKernel
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
float descale_sv)
OAccElementFunction o_acc_element_func)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -288,7 +277,6 @@ struct FmhaFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -312,11 +300,6 @@ struct FmhaFwdKernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kIsFp8)
{
kargs.descale_qk = descale_qk;
kargs.descale_sv = descale_sv;
}
return kargs;
}
@@ -356,9 +339,7 @@ struct FmhaFwdKernel
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
float descale_sv)
OAccElementFunction o_acc_element_func)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -393,7 +374,6 @@ struct FmhaFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -414,11 +394,6 @@ struct FmhaFwdKernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(kIsFp8)
{
kargs.descale_qk = descale_qk;
kargs.descale_sv = descale_sv;
}
return kargs;
}
@@ -702,38 +677,22 @@ struct FmhaFwdKernel
}();
auto o_acc_tile = [&]() {
if constexpr(kIsFp8)
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
lse_dram_window,
mask,
kargs.scale,
kargs.descale_qk,
kargs.descale_sv,
smem_ptr);
}
else
{
return FmhaPipeline{}(q_dram_window,
identity{},
k_dram_window,
identity{},
v_dram_window,
identity{},
bias_dram_window,
identity{},
lse_dram_window,
identity{},
identity{},
kargs.p_compute_element_func,
kargs.o_acc_element_func,
mask,
kargs.scale,
smem_ptr);
}
return FmhaPipeline{}(q_dram_window,
identity{},
k_dram_window,
identity{},
v_dram_window,
identity{},
bias_dram_window,
identity{},
lse_dram_window,
identity{},
identity{},
kargs.p_compute_element_func,
kargs.o_acc_element_func,
mask,
kargs.scale,
smem_ptr);
}();
// O DRAM and O DRAM window

View File

@@ -9,6 +9,7 @@
namespace ck_tile {
// deprecated pipeline
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
struct BlockFmhaPipelineQRKSVSFp8