mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
refactor inline asm
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "inline_asm.hpp"
|
||||
|
||||
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
|
||||
__device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
const Float* __restrict__ p_src,
|
||||
@@ -21,18 +23,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
#else
|
||||
static_assert(NCol == 4, "only for NCol == 4");
|
||||
|
||||
using vector_t = typename vector_type<Float, 4>::MemoryType;
|
||||
|
||||
for(index_t i = 0; i < NRow; ++i)
|
||||
{
|
||||
const index_t src_index = src_mtx.Get1dIndex(i, 0);
|
||||
const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
|
||||
Float4 *reg_p = (Float4 *)&p_dst[dst_index];
|
||||
Float4 *loc_p = (Float4 *)&p_src[src_index];
|
||||
|
||||
ds_read_b128(reg_p[0], (void *)&loc_p[0]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -70,25 +72,20 @@ __device__ void threadwise_gemm(MatrixA,
|
||||
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t i = 0; i < M; ++i)
|
||||
for(index_t i = 0; i < M; i+=4)
|
||||
{
|
||||
for(index_t j = 0; j < N; ++j)
|
||||
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
|
||||
const Float4 *a_vec = (const Float4 *)&p_a_thread[aindex];
|
||||
|
||||
for(index_t j = 0; j < N; j+=4)
|
||||
{
|
||||
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
|
||||
const index_t bindex = b_mtx.Get1dIndex(k, j);
|
||||
const index_t cindex = c_mtx.Get1dIndex(i, j);
|
||||
|
||||
#if 0
|
||||
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
|
||||
#elif 1
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %1, %2 \n \
|
||||
"
|
||||
: "=v"(p_c_thread[cindex])
|
||||
: "v"(p_a_thread[aindex]),
|
||||
"v"(p_b_thread[bindex]),
|
||||
"0"(p_c_thread[cindex]));
|
||||
#endif
|
||||
const Float4 *b_vec = (const Float4 *)&p_b_thread[bindex];
|
||||
Float4 *c_vec = (Float4 *)&p_c_thread[cindex];
|
||||
|
||||
outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user