Flatmm merge (#2168)

* sync with function interface of cshuffleepiloge,fix flatmm build fail

* move code from solin/flatmm which add mfma16*16*32fp8 and optimize flatmm

---------

Co-authored-by: solin <bingzhou@amd.com>
This commit is contained in:
BingYuan.Zhou
2025-05-08 12:59:57 +08:00
committed by GitHub
parent c7b8e86e34
commit 6a3960c1e1
11 changed files with 552 additions and 192 deletions

View File

@@ -321,7 +321,7 @@ struct FlatmmKernel
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
return make_naive_tensor_view<address_space_enum::global>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
@@ -330,7 +330,7 @@ struct FlatmmKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
return make_naive_tensor_view<address_space_enum::global>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
@@ -426,7 +426,6 @@ struct FlatmmKernel
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
CDataType* c_ptr,
@@ -438,7 +437,8 @@ struct FlatmmKernel
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -453,9 +453,8 @@ struct FlatmmKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile, smem_ptr);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr);
}
CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
@@ -475,21 +474,12 @@ struct FlatmmKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
if(kargs.k_batch == 1)
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
else
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunFlatmm<memory_operation_enum::atomic_add>(
a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
}
}
};