Something Khushbu can help with

This commit is contained in:
AviralGoelAMD
2025-07-16 16:30:44 +00:00
parent c1badfd30c
commit c1c30b1c18
3 changed files with 1455 additions and 4 deletions

View File

@@ -63,7 +63,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
@@ -90,7 +90,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
ck_tile::FlatmmPipelineAGmemBGmemCRegV2<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,

View File

@@ -590,6 +590,45 @@ struct FlatmmKernel
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr);
}
}
/**
 * @brief Per-tile flatmm driver that hands the pipeline two scratch buffers.
 *
 * Builds the tensor views, padded views and tile windows for the tile at
 * (block_idx_m, block_idx_n), invokes FlatmmPipeline with both scratch
 * pointers, and then passes the pipeline result to EpiloguePipeline.
 *
 * @param a_ptr               A operand pointer, forwarded to MakeGemmTensorViews.
 * @param b_flat_ptr          B operand pointer, forwarded to MakeGemmTensorViews.
 * @param ds_ptr              D tensor pointers, forwarded to MakeGemmTensorViews.
 * @param e_ptr               Output pointer, forwarded to MakeGemmTensorViews.
 * @param smem_ptr_ping       First scratch buffer; given to the pipeline and then
 *                            reused by the epilogue.
 * @param smem_ptr_pong       Second scratch buffer; given to the pipeline only.
 * @param kargs               Kernel arguments, forwarded to MakeGemmTensorViews.
 * @param splitk_batch_offset Forwarded to MakeGemmTensorViews; its splitted_k
 *                            member determines the pipeline loop count.
 * @param block_idx_m         Tile index handed to MakeGemmTileWindows.
 * @param block_idx_n         Tile index handed to MakeGemmTileWindows.
 *
 * NOTE(review): ds_ptr feeds the tensor views, yet the epilogue below receives
 * an empty tuple in place of D windows, whereas another epilogue call in this
 * struct passes a d_block_window. Confirm that dropping D here is intended.
 */
CK_TILE_DEVICE static void RunFlatmm2(const ADataType* a_ptr,
                                      const BDataType* b_flat_ptr,
                                      const std::array<const void*, NumDTensor>& ds_ptr,
                                      EDataType* e_ptr,
                                      void* smem_ptr_ping,
                                      void* smem_ptr_pong,
                                      const KernelArgs& kargs,
                                      const SplitKBatchOffset& splitk_batch_offset,
                                      const index_t block_idx_m,
                                      const index_t block_idx_n)
{
    // Placeholder handed to the epilogue instead of real D-tensor windows.
    constexpr auto no_ds_windows = ck_tile::make_tuple();

    // Tensor views -> padded views -> tile windows for this block.
    const auto& tensor_views = MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
        a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
    const auto& padded_views = MakeGemmPadViews(tensor_views);
    auto tile_windows        = MakeGemmTileWindows(padded_views, block_idx_m, block_idx_n);

    // Loop count handed to the pipeline, derived from splitk_batch_offset.splitted_k.
    const index_t loop_count = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

    // Main GEMM pipeline; it receives both scratch buffers.
    const auto& a_window      = tile_windows.at(I0);
    const auto& b_flat_window = tile_windows.at(I1);
    const auto& acc_tile      = FlatmmPipeline{}.template operator()(
        a_window, b_flat_window, loop_count, smem_ptr_ping, smem_ptr_pong);

    // Epilogue: consumes the pipeline result and reuses the ping buffer as scratch.
    // NOTE(review): the output window is taken from index I2; MakeGemmTileWindows
    // is not visible here, so confirm that I2 is the output (not the D) slot.
    auto& out_window = tile_windows.at(I2);

    // The reference/const qualifiers of these aliases are deliberate: they are the
    // explicit template arguments of the epilogue call and must match the locals.
    using OutWindowT = decltype(out_window);
    using AccTileT   = decltype(acc_tile);
    using NoDsT      = decltype(no_ds_windows);

    EpiloguePipeline{}.template operator()<OutWindowT, AccTileT, NoDsT>(
        out_window, acc_tile, no_ds_windows, smem_ptr_ping);
}
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
@@ -608,10 +647,11 @@ struct FlatmmKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
__shared__ char smem_ptr_pong[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
is_any_of<EDataType, fp16_t, bf16_t>::value) && FlatmmPipeline::DoubleSmemBuffer == false)
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<scheduler_type>(a_ptr,
@@ -624,6 +664,19 @@ struct FlatmmKernel
i_m,
i_n);
}
else
{
RunFlatmm2(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
};