merge from dteng_flatmm_opt

This commit is contained in:
lalala-sh
2025-07-16 10:12:19 +00:00
parent 3499fe67ff
commit fb76450e63
9 changed files with 1324 additions and 267 deletions

View File

@@ -146,10 +146,14 @@ struct FlatmmKernel
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
{
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
return FlatmmPipeline::GetSmemSize();
}
struct SplitKBatchOffset
{
@@ -560,7 +564,8 @@ struct FlatmmKernel
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -580,15 +585,16 @@ struct FlatmmKernel
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr);
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
// Run Epilogue Pipeline
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}
@@ -607,7 +613,8 @@ struct FlatmmKernel
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
@@ -618,7 +625,8 @@ struct FlatmmKernel
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,