mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
* Optimize GEMM on MI200/300: 1. Add new blockwise gemm pipeline 2. Add irregular splitk instances * clang format + typo fix * Fix a bug * initial commit * Add more instances to irregular splitk * blkgemm pipeline v1~4 prototype * Sanity Checked. Known issue: 1. Poor performance of splitk 2. Register spill on blkgemmpipeline v3 * Sanity and Performance fix: 1. fix a bug related to sanity in grouped b2c mapping 2. fix a bug related to sanity and performance in splitk offset * Sanity and API update: 1. Remove prefetch stage 2. Fix valid check bug 3. Add first gemm_universal instance into ckProfiler * Add NN instances for gemm universal * 1. Add NT instances for gemm_universal 2. Fix a bug about Kpadding in gemm_universal * Fix a bug regarding padding Odd K number * remove kernel print * Fix KPadding bug... * Update safety check * another try to fix kpadding.. * Sanity checked * new instances.. * clang format+typo fix * remove clang format script's change * Add non-hotloop compile option * 1. Add fp16xfp8 example 2. pull packed convert f8 from pr1150 * Some miscs.. opt and fix * Add pipeline description docs * Split universal gemm instance library to cut profiler compiling time * uncomment cmakefile * Fix a bug caused by blockwise_gemm_pipe_v2 * reduce default splitk to 1 * Add 224x256x64 tile size * update, including: 1. Experiment pipeline 5~7 2. Optimization for pipeline 4 3. Organized instance library * temp save * temp save * Permuted lds layout, sanity and function checked * clang format * Move OOB check from RunRead to RunWrite, for better software pipeline. 
TODO: agpr spill when NN layout * clang format * A/B splitpipe scheduler for v3 * Fix two bugs * bug fix * fix a bug in oob check * Example for mixed fp16_fp8 gemm * Clean experimental code blocks * Add mixed precision gemm into profiler * temp save * optimize m/n major lds layout * Add RRR GEMM mixed precision instances * Optimize f8 matrix transpose * Add test_gemm_universal * A/B split schedule for blkpip v5 * Take ds_read2 into iglp scheduling scheme * format * fixed cmake * Add llvm-option into CI cmake flag --------- Co-authored-by: Jing Zhang <jizhan@amd.com>
105 lines
3.1 KiB
C++
105 lines
3.1 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck/utility/common_header.hpp"
|
|
#include "ck/tensor_description/tensor_adaptor.hpp"
|
|
|
|
namespace ck {
|
|
|
|
// Scheduling mode selector for blockwise GEMM pipelines.
// NOTE(review): the concrete meaning of each mode is defined by the pipeline
// implementations that consume this enum, which are not visible in this header.
enum struct BlockGemmPipelineScheduler
{
    Intrawave = 0,
    Interwave = 1
};
|
|
|
|
// Tail-handling variant selector for blockwise GEMM pipelines.
// NOTE(review): presumably describes the work left over after the pipelined
// main loop; confirm against the pipeline implementations that switch on it.
enum struct TailNumber
{
    // Single- / double-buffer pipelines.
    Odd  = 0,
    Even = 1,

    // Long-prefetch pipelines (up to 8).
    One   = 2,
    Two   = 3,
    Three = 4,
    Four  = 5,
    Five  = 6,
    Six   = 7,
    Seven = 8,

    // Unroll stages > prefetch stages: the loop count is a multiple of the
    // unroll stages.
    Empty = 9,

    // Unroll stages <= prefetch stages: the loop count is a multiple of the
    // unroll stages plus the prefetch stages.
    Full = 10
};
|
|
template <index_t BlockSize,
|
|
index_t MPerBlock,
|
|
index_t NPerBlock,
|
|
index_t KPerBlock,
|
|
index_t ABufferLoadWidth,
|
|
index_t BBufferLoadWidth,
|
|
index_t ALDSWriteWidth,
|
|
index_t BLDSWriteWidth,
|
|
index_t ALDSReadWidth,
|
|
index_t BLDSReadWidth,
|
|
index_t MRepeat,
|
|
index_t NRepeat,
|
|
index_t MPerXDL,
|
|
index_t NPerXDL,
|
|
index_t KPerXDL>
|
|
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
|
|
{
|
|
static constexpr index_t WaveSize = 64;
|
|
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
|
|
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
|
|
|
|
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
|
|
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
|
|
|
|
static constexpr index_t A_Buffer_Load_Inst_Num =
|
|
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
|
|
static constexpr index_t B_Buffer_Load_Inst_Num =
|
|
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
|
|
|
|
static constexpr index_t A_LDS_Write_Inst_Num =
|
|
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
|
|
static constexpr index_t B_LDS_Write_Inst_Num =
|
|
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
|
|
|
|
static constexpr index_t A_LDS_Read_Inst_Num =
|
|
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
|
|
static constexpr index_t B_LDS_Read_Inst_Num =
|
|
WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
|
|
|
|
static constexpr index_t C_MFMA_Inst_Num =
|
|
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
|
|
|
static constexpr auto Print()
|
|
{
|
|
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
|
|
BlockSize,
|
|
WaveSize,
|
|
MPerBlock,
|
|
NPerBlock,
|
|
KPerBlock,
|
|
MPerXDL,
|
|
NPerXDL,
|
|
KPerXDL);
|
|
|
|
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
|
|
"%d, %d\n C MFMA inst: %d\n",
|
|
A_Buffer_Load_Inst_Num,
|
|
B_Buffer_Load_Inst_Num,
|
|
A_LDS_Write_Inst_Num,
|
|
B_LDS_Write_Inst_Num,
|
|
A_LDS_Read_Inst_Num,
|
|
B_LDS_Read_Inst_Num,
|
|
C_MFMA_Inst_Num);
|
|
}
|
|
};
|
|
|
|
} // namespace ck
|