mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
[Navi3x-LWPCK-545] Block-wise GEMM + Real GEMM_WMMA_FP16 (#541)
* wmma_op + unit test
* add arch limitation to wmma test
* change arch limitation
* Refactor + Add all type unit test(int4 compile failed)
* Add f32_16x16x16_bf16 unit test
* tempsave
* tempsave
* tempsave
* runtime bug, cannot find symbol
* workaround for incorrect HIP warpSize return value
* debugging
* tempsave
* Correctness OK, waiting for optimization
* Tidy up + format
* temp save
* temp save, reproduce the v_bfi_b32 issue
* add inline asm for wmmaop test
* tidy up
* clean some debug purpose code
* discard some codes
* clang format
* clang format
* compiler issue fixed + increase tile size
[ROCm/composable_kernel commit: 919aeb1f52]
This commit is contained in:
@@ -97,6 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
|
||||
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
|
||||
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
||||
{
|
||||
__shared__ src_t p_shared[16 * 16 * 2];
|
||||
const int lIdx = threadIdx.x;
|
||||
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
|
||||
// b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
|
||||
@@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
||||
using src_vec = typename vector_type<src_t, 16>::type;
|
||||
src_vec a_frag = {};
|
||||
src_vec b_frag = {};
|
||||
|
||||
src_vec a_temp = {};
|
||||
src_vec b_temp = {};
|
||||
// initialize c fragment to 0
|
||||
using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
|
||||
acc_vec c_thread_buf_;
|
||||
@@ -111,21 +115,57 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
||||
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
|
||||
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
|
||||
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
|
||||
const int lane = lIdx % 16;
|
||||
const int lane = lIdx % 16;
|
||||
const int lane_lo = lIdx / 2;
|
||||
const int lane_hi = lIdx % 2;
|
||||
for(int ele = 0; ele < 8; ++ele)
|
||||
{
|
||||
a_temp[ele] = a[8 * lane_hi + 16 * lane_lo + ele];
|
||||
}
|
||||
|
||||
for(int ele = 0; ele < 8; ++ele)
|
||||
{
|
||||
b_temp[ele] = b[8 * lane_hi + 16 * lane_lo + ele];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(int ele = 0; ele < 8; ++ele)
|
||||
{
|
||||
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele];
|
||||
}
|
||||
|
||||
for(int ele = 0; ele < 8; ++ele)
|
||||
{
|
||||
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
|
||||
}
|
||||
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
|
||||
for(int ele = 0; ele < 16; ++ele)
|
||||
{
|
||||
b_frag[ele] = b[16 * lane + ele];
|
||||
b_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8 + 16 * 16];
|
||||
}
|
||||
// follow origin design
|
||||
for(int ele = 0; ele < 16; ++ele)
|
||||
{
|
||||
a_frag[ele] = a[16 * lane + ele];
|
||||
a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
|
||||
}
|
||||
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
|
||||
// sync threads, similar to mma_sync
|
||||
__syncthreads();
|
||||
// __syncthreads();
|
||||
builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
|
||||
// since only fp16_fp32 asm wmma implemented for experiment purpose, restrict test case to fp16
|
||||
// when enable this ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag,
|
||||
// c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType<float8_t>()(Number<0>{}));
|
||||
__syncthreads();
|
||||
// wait for results, similar to mma_sync
|
||||
static_for<0, 8, 1>{}([&](auto ele) {
|
||||
|
||||
Reference in New Issue
Block a user