mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
[GEMM] Gemm universal device operation (#1154)
* Optimize GEMM on MI200/300:
1. Add new blockwise gemm pipeline
2. Add irregular splitk instances
* clang format + typo fix
* Fix a bug
* initial commit
* Add more instances to irregular splitk
* blkgemm pipeline v1~4 prototype
* Sanity Checked. Known issue:
1. Poor performance of splitk
2. Register spill on blkgemmpipeline v3
* Sanity and Performance fix:
1. fix a bug related to sanity in grouped b2c mapping
2. fix a bug related to sanity and performance in splitk offset
* Sanity and API update:
1. Remove prefetch stage
2. Fix valid check bug
3. Add first gemm_universal instance into ckProfiler
* Add NN instances for gemm universal
* 1. Add NT instances for gemm_universal
2. Fix a bug about Kpadding in gemm_universal
* Fix a bug regarding padding Odd K number
* remove kernel print
* Fix KPadding bug...
* Update safety check
* another try to fix kpadding..
* Sanity checked
* new instances..
* clang format+typo fix
* remove clang format script's change
* Add non-hotloop compile option
* 1. Add fp16xfp8 example
2. pull packed convert f8 from pr1150
* Some miscs.. opt and fix
* Add pipeline description docs
* Split universal gemm instance library to cut profiler compiling time
* uncomment cmakefile
* Fix a bug caused by blockwise_gemm_pipe_v2
* reduce default splitk to 1
* Add 224x256x64 tile size
* update, including:
1. Experiment pipeline 5~7
2. Optimization for pipeline 4
3. Organized instance library
* temp save
* temp save
* Permuted lds layout, sanity and function checked
* clang format
* Move OOB check from RunRead to RunWrite, for better software pipeline.
TODO: agpr spill when NN layout
* clangformat
* A/B splitpipe scheduler for v3
* Fix two bugs
* bug fix
* fix a bug in oob check
* Example for mixed fp16_fp8 gemm
* Clean experimental code blocks
* Add mixed precision gemm into profiler
* tempsave
* optimize m/n major lds layout
* Add RRR GEMM mixed precision instances
* Optimize f8 matrix transpose
* Add test_gemm_universal
* A/B split schedule for blkpip v5
* Take ds_read2 into iglp scheduling scheme
* format
* fixed cmake
* Add llvm-option into CI cmake flag
---------
Co-authored-by: Jing Zhang <jizhan@amd.com>
[ROCm/composable_kernel commit: f83e9701e9]
This commit is contained in:
104
include/ck/utility/blkgemmpipe_scheduler.hpp
Normal file
104
include/ck/utility/blkgemmpipe_scheduler.hpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Scheduling flavour selected for a block-level GEMM pipeline.
// The underlying values are spelled out explicitly; they match the implicit
// numbering produced by the declaration order.
enum class BlockGemmPipelineScheduler
{
    Intrawave = 0,
    Interwave = 1,
};
|
||||
|
||||
// Identifies which tail variant a pipeline has to execute.
// The enumerators fall into three groups; the explicit values below are the
// same ones the implicit numbering assigns in declaration order.
enum class TailNumber
{
    // Group 1: single / double buffer pipeline
    Odd  = 0,
    Even = 1,

    // Group 2: long prefetch pipeline, up to 8
    One   = 2,
    Two   = 3,
    Three = 4,
    Four  = 5,
    Five  = 6,
    Six   = 7,
    Seven = 8,

    // Group 3a: unroll stages > prefetch stages; the loop count is a multiple
    // of the unroll stages
    Empty = 9,

    // Group 3b: unroll stages <= prefetch stages; the loop count is a multiple
    // of the unroll stages plus the prefetch stages
    Full = 10,
};
|
||||
template <index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t ABufferLoadWidth,
|
||||
index_t BBufferLoadWidth,
|
||||
index_t ALDSWriteWidth,
|
||||
index_t BLDSWriteWidth,
|
||||
index_t ALDSReadWidth,
|
||||
index_t BLDSReadWidth,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t KPerXDL>
|
||||
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
|
||||
{
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
|
||||
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
|
||||
|
||||
static constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
|
||||
static constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
|
||||
|
||||
static constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
|
||||
static constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
|
||||
|
||||
static constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
|
||||
static constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
|
||||
|
||||
static constexpr index_t C_MFMA_Inst_Num =
|
||||
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
static constexpr auto Print()
|
||||
{
|
||||
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
|
||||
BlockSize,
|
||||
WaveSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
KPerXDL);
|
||||
|
||||
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
|
||||
"%d, %d\n C MFMA inst: %d\n",
|
||||
A_Buffer_Load_Inst_Num,
|
||||
B_Buffer_Load_Inst_Num,
|
||||
A_LDS_Write_Inst_Num,
|
||||
B_LDS_Write_Inst_Num,
|
||||
A_LDS_Read_Inst_Num,
|
||||
B_LDS_Read_Inst_Num,
|
||||
C_MFMA_Inst_Num);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -163,6 +163,13 @@ struct scalar_type<bf8_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bool>
|
||||
{
|
||||
using type = bool;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
{
|
||||
|
||||
@@ -10,10 +10,12 @@ namespace ck {
|
||||
__device__ void block_sync_lds()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
|
||||
@@ -162,4 +162,83 @@ struct transpose_vectors<int8_t, NX, NY>
|
||||
}
|
||||
};
|
||||
|
||||
// transpose f8 4x4
|
||||
__device__ void transpose_f8_4x4(const f8x4_t& x0,
|
||||
const f8x4_t& x1,
|
||||
const f8x4_t& x2,
|
||||
const f8x4_t& x3,
|
||||
f8x4_t& y0,
|
||||
f8x4_t& y1,
|
||||
f8x4_t& y2,
|
||||
f8x4_t& y3)
|
||||
{
|
||||
int32_t t0, t1;
|
||||
int32_t z0, z1, z2, z3;
|
||||
constexpr int32_t m0 = 0x05010400;
|
||||
constexpr int32_t m1 = 0x05040100;
|
||||
constexpr int32_t m2 = 0x07060302;
|
||||
constexpr int32_t m3 = 0x07030602;
|
||||
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits first)
|
||||
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
|
||||
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
|
||||
z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
|
||||
z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
|
||||
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
|
||||
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
|
||||
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
|
||||
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
|
||||
|
||||
y0 = bit_cast<f8x4_t>(z0);
|
||||
y1 = bit_cast<f8x4_t>(z1);
|
||||
y2 = bit_cast<f8x4_t>(z2);
|
||||
y3 = bit_cast<f8x4_t>(z3);
|
||||
}
|
||||
|
||||
template <index_t NX, index_t NY>
|
||||
struct transpose_vectors<f8_t, NX, NY>
|
||||
{
|
||||
// we got [NY * NX] amount of S data to be transposed
|
||||
static constexpr index_t s_per_x = NY;
|
||||
static constexpr index_t s_per_y = NX;
|
||||
|
||||
using S = f8_t;
|
||||
using VX = vector_type<f8_t, s_per_x>;
|
||||
using VY = vector_type<f8_t, s_per_y>;
|
||||
|
||||
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
|
||||
StaticallyIndexedArray<VY&, NY>& vy_tuple)
|
||||
{
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
|
||||
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
|
||||
|
||||
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 4>{}([&](auto iy) {
|
||||
static_for<0, NX, 4>{}([&](auto ix) {
|
||||
// reference to 4 f8 data from vx_tuple
|
||||
const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
|
||||
const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
|
||||
const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
|
||||
const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
|
||||
|
||||
// reference to 4 f8 data from vy_tuple
|
||||
auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
|
||||
auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
|
||||
auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
|
||||
auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
|
||||
|
||||
// transpose
|
||||
transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -43,6 +43,8 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
|
||||
Y y;
|
||||
|
||||
// auto t = reinterpret_cast<const Y*>(&x);
|
||||
// y = *t;
|
||||
__builtin_memcpy(&y, &x, sizeof(X));
|
||||
|
||||
return y;
|
||||
|
||||
Reference in New Issue
Block a user