This commit is contained in:
Jing Zhang
2019-03-28 20:00:31 -05:00
parent 2058bec8cf
commit 5fbf4f33d3

View File

@@ -1,6 +1,8 @@
#pragma once
#include "threadwise_gemm.hip.hpp"
// Declaration only (no definition is visible in this file) of `__to_local`:
// takes a plain `void*` and returns a `void*` qualified with address_space(3).
// In this file it is used to build the address operand that is handed to the
// `ds_read2_b64` inline assembly.
// NOTE(review): address_space(3) is presumably the AMDGPU LDS (local/shared)
// address space and `[[hc]]` presumably marks the function as callable from
// HCC device code -- confirm against the toolchain documentation.
extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
@@ -387,9 +389,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
threadwise_matrix_copy(
a_block_mtx,
//MPerLevel1Cluster = 4
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA,
a_thread_mtx,
//MPerThreadSubC = 4
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
}
@@ -398,11 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
const float4* loc = (const float4 *)(p_a_block + src_index);
//const float4* loc = (const float4 *)(p_a_block + src_index);
float4* reg = (float4 *)(p_a_thread + dst_index);
reg[0] = loc[0];
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
//reg[0] = loc[0];
//reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
// Load reg[0] and reg[1] from the address of p_a_block[src_index] using
// hand-written DS read instructions instead of plain float4 assignments:
//   %0 (reg[0]) <- ds_read2_b64 %2 offset1:1
//   %1 (reg[1]) <- ds_read2_b64 %2 offset0:16 offset1:17
// and then wait with `s_waitcnt lgkmcnt(0)` before leaving the asm statement.
//
// The outputs are declared early-clobber ("=&v"): the first instruction writes
// %0 while the address in %2 is still needed by the second instruction, so the
// register allocator must not be allowed to place an output in the same
// register(s) as %2. With a plain "=v" constraint that overlap is permitted
// and the second load could read from a clobbered address.
//
// NOTE(review): the offsets and the reg[1] destination are hard-coded. They
// replace `reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]`, so they are only
// equivalent for one specific MPerThreadSubC / MPerLevel1Cluster configuration.
// Per the ISA docs ds_read2_b64 offsets are presumably in units of 8 bytes
// (offset0:16 -> byte offset 128) -- confirm this matches the tile parameters.
asm volatile("\n \
ds_read2_b64 %0, %2 offset1:1 \n \
ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
s_waitcnt lgkmcnt(0)"
: "=&v"(reg[0]), "=&v"(reg[1])
: "v"(__to_local((void *)&p_a_block[src_index]))
);
}
#endif