mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
fixed conflict
This commit is contained in:
@@ -190,8 +190,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t BlockSize = 256;
|
||||
#elif 1
|
||||
// 1x1, 14x14, Vega 10
|
||||
#elif 0
|
||||
// 1x1, 14x14, Vega 20
|
||||
constexpr index_t BPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
@@ -219,6 +219,36 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// 1x1, 14x14, Vega 20, hack CPerBlock = 1
|
||||
constexpr index_t BPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 1;
|
||||
|
||||
constexpr index_t BPerThread = 8;
|
||||
constexpr index_t KPerThread = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t GemmThreadPerColumnPerCluster = 8;
|
||||
constexpr index_t GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr index_t InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr index_t InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#endif
|
||||
|
||||
|
||||
@@ -477,9 +477,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread,
|
||||
__device__ void Run_asm(const FloatA* const __restrict__ p_a_block,
|
||||
const FloatB* const __restrict__ p_b_block,
|
||||
FloatC* const __restrict__ p_c_thread,
|
||||
Accumulator f_accum) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
@@ -519,11 +519,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1 && K == 1,
|
||||
"asm is not for this mtx shape");
|
||||
|
||||
const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;
|
||||
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
//#pragma unroll
|
||||
#if 0
|
||||
#pragma unroll
|
||||
// copy A-sub to form A
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
@@ -532,9 +539,65 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.NCol(p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
#elif 1
|
||||
// this produce right result
|
||||
using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
|
||||
|
||||
asm volatile(
|
||||
"\n \
|
||||
ds_read_b128 %0, %1 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
|
||||
: "v"(__to_local(
|
||||
(void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))));
|
||||
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vectorA_t*>(
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
|
||||
: "v"(__to_local((
|
||||
void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA))));
|
||||
#elif 0
|
||||
// this produce wrong result
|
||||
using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
|
||||
|
||||
asm volatile(
|
||||
"\n \
|
||||
ds_read_b128 %0, %2 \n \
|
||||
ds_read_b128 %1, %3 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))),
|
||||
"=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread +
|
||||
a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
|
||||
: "v"(__to_local(
|
||||
(void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))),
|
||||
"v"(__to_local((void*)(p_a_block +
|
||||
a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA))));
|
||||
#elif 1
|
||||
// this produce wrong result
|
||||
using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
|
||||
|
||||
asm volatile(
|
||||
"\n \
|
||||
ds_read_b128 %0, %1 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
|
||||
: "v"(__to_local((void*)(p_a_block_thread_offset))));
|
||||
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:16 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vectorA_t*>(
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
|
||||
: "v"(__to_local((void*)(p_a_block_thread_offset))));
|
||||
|
||||
#endif
|
||||
|
||||
//#pragma unroll
|
||||
// copy B-sub to form B
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#include "Array.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
@@ -238,7 +238,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#elif 1
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
|
||||
@@ -289,10 +289,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
|
||||
#else
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,10 +319,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
|
||||
#else
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
#elif 0
|
||||
static_assert(NCol == 4, "only for NCol == 4");
|
||||
|
||||
using vector_t = typename vector_type<Float, 4>::MemoryType;
|
||||
@@ -33,15 +33,21 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
const index_t src_index = src_mtx.Get1dIndex(i, 0);
|
||||
const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
|
||||
|
||||
#if 1
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
#if 0
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
|
||||
#elif 0
|
||||
asm volatile("\n \
|
||||
ds_read2_b64 %0, %1 offset1:1 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
|
||||
: "v"(__to_local((void*)(&p_src[src_index]))));
|
||||
#elif 1
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1, offset:0 \n \
|
||||
"
|
||||
: "=v"(*(reinterpret_cast<vector_t*>(p_dst+dst_index)))
|
||||
: "v"((uint32_t)(p_src + src_index)));
|
||||
ds_read_b128 %0, %1 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
|
||||
: "v"(__to_local((void*)(&p_src[src_index]))));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user