mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
inline
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
#include "threadwise_gemm.hip.hpp"
|
||||
|
||||
// Declaration only (the definition is not in this file): a C-linkage helper
// that takes a generic pointer `p` and returns a pointer qualified with
// address_space(3). In this file its result is passed as the address operand
// of the ds_read2_b64 inline assembly.
// NOTE(review): address space 3 is presumably the AMDGPU local (LDS) address
// space, and [[hc]] presumably marks an HCC device-side function -- confirm
// against the toolchain documentation.
extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];
|
||||
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
@@ -387,9 +389,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
//MPerLevel1Cluster = 4
|
||||
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
//MPerThreadSubC = 4
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
@@ -398,11 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
|
||||
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
|
||||
|
||||
const float4* loc = (const float4 *)(p_a_block + src_index);
|
||||
//const float4* loc = (const float4 *)(p_a_block + src_index);
|
||||
float4* reg = (float4 *)(p_a_thread + dst_index);
|
||||
|
||||
reg[0] = loc[0];
|
||||
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
|
||||
//reg[0] = loc[0];
|
||||
//reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
|
||||
asm volatile("\n \
|
||||
ds_read2_b64 %0, %2 offset1:1 \n \
|
||||
ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(reg[0]), "=v"(reg[1])
|
||||
: "v"(__to_local((void *)&p_a_block[src_index]))
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user