mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Remove f8 pipeline, we should share the same pipeline even in f8
This commit is contained in:
@@ -166,13 +166,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_y, mask_x;
|
||||
};
|
||||
|
||||
struct FmhaFwdFP8Kargs
|
||||
{
|
||||
float descale_qk; // q*k
|
||||
float descale_sv; // s*v
|
||||
// float * o_amax_ptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -188,8 +181,7 @@ struct FmhaFwdKernel
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kIsFp8, FmhaFwdFP8Kargs, FmhaFwdEmptyKargs<3>>
|
||||
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -201,8 +193,7 @@ struct FmhaFwdKernel
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kIsFp8, FmhaFwdFP8Kargs, FmhaFwdEmptyKargs<3>>
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -251,9 +242,7 @@ struct FmhaFwdKernel
|
||||
// LSEElementFunction lse_element_func,
|
||||
// SAccElementFunction s_acc_element_func,
|
||||
PComputeElementFunction p_compute_element_func,
|
||||
OAccElementFunction o_acc_element_func,
|
||||
float descale_qk,
|
||||
float descale_sv)
|
||||
OAccElementFunction o_acc_element_func)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -288,7 +277,6 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8 args
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -312,11 +300,6 @@ struct FmhaFwdKernel
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kIsFp8)
|
||||
{
|
||||
kargs.descale_qk = descale_qk;
|
||||
kargs.descale_sv = descale_sv;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -356,9 +339,7 @@ struct FmhaFwdKernel
|
||||
// LSEElementFunction lse_element_func,
|
||||
// SAccElementFunction s_acc_element_func,
|
||||
PComputeElementFunction p_compute_element_func,
|
||||
OAccElementFunction o_acc_element_func,
|
||||
float descale_qk,
|
||||
float descale_sv)
|
||||
OAccElementFunction o_acc_element_func)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -393,7 +374,6 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8 args
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -414,11 +394,6 @@ struct FmhaFwdKernel
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
if constexpr(kIsFp8)
|
||||
{
|
||||
kargs.descale_qk = descale_qk;
|
||||
kargs.descale_sv = descale_sv;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -702,38 +677,22 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kIsFp8)
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
kargs.scale,
|
||||
kargs.descale_qk,
|
||||
kargs.descale_sv,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{},
|
||||
k_dram_window,
|
||||
identity{},
|
||||
v_dram_window,
|
||||
identity{},
|
||||
bias_dram_window,
|
||||
identity{},
|
||||
lse_dram_window,
|
||||
identity{},
|
||||
identity{},
|
||||
kargs.p_compute_element_func,
|
||||
kargs.o_acc_element_func,
|
||||
mask,
|
||||
kargs.scale,
|
||||
smem_ptr);
|
||||
}
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{},
|
||||
k_dram_window,
|
||||
identity{},
|
||||
v_dram_window,
|
||||
identity{},
|
||||
bias_dram_window,
|
||||
identity{},
|
||||
lse_dram_window,
|
||||
identity{},
|
||||
identity{},
|
||||
kargs.p_compute_element_func,
|
||||
kargs.o_acc_element_func,
|
||||
mask,
|
||||
kargs.scale,
|
||||
smem_ptr);
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// deprecated pipeline
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSFp8
|
||||
|
||||
Reference in New Issue
Block a user