mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
@@ -16,7 +16,9 @@ template <index_t BlockSize,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t BatchPerThread>
|
||||
index_t BatchPerThread,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
{
|
||||
index_t mMyThreadOffsetA = 0;
|
||||
@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
// copy B-sub to form B
|
||||
@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// loop over batch
|
||||
@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,9 @@ template <index_t BlockSize,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop>
|
||||
index_t KPerThreadLoop,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
struct MatrixIndex
|
||||
@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
@@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_0 + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_0 + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
bool even_loop = true;
|
||||
@@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_next + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_next + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
|
||||
@@ -30,6 +30,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
@@ -169,7 +171,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread>{};
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
@@ -30,6 +30,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
@@ -172,7 +174,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread>{};
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
@@ -30,6 +30,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
@@ -173,7 +175,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread>{};
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
|
||||
@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
// register
|
||||
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
|
||||
|
||||
@@ -26,6 +26,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop>{};
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
|
||||
@@ -27,6 +27,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop>{};
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
|
||||
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol, index_t DataPerRead>
|
||||
__device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
const Float* __restrict__ p_src,
|
||||
DstMatrix,
|
||||
Float* __restrict__ p_dst,
|
||||
Sequence<NRow, NCol>)
|
||||
Sequence<NRow, NCol>,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr auto src_mtx = SrcMatrix{};
|
||||
constexpr auto dst_mtx = DstMatrix{};
|
||||
|
||||
for(index_t i = 0; i < NRow; ++i)
|
||||
{
|
||||
for(index_t j = 0; j < NCol; ++j)
|
||||
for(index_t j = 0; j < NCol; j += DataPerRead)
|
||||
{
|
||||
const index_t src_index = src_mtx.Get1dIndex(i, j);
|
||||
const index_t dst_index = dst_mtx.Get1dIndex(i, j);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user