mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
refactored implicit gemm v1r3
This commit is contained in:
@@ -170,9 +170,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
__device__ void Run_source(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
@@ -189,10 +189,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
// thread A, B for GEMM
|
||||
// A is transposed, b is not
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
@@ -480,6 +480,19 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
|
||||
{
|
||||
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
Run_amd_asm(p_a_block, p_b_block, p_c_thread);
|
||||
#else
|
||||
Run_source(p_a_block, p_b_block, p_c_thread);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
|
||||
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
|
||||
FloatC* __restrict__ p_c_block) const
|
||||
|
||||
Reference in New Issue
Block a user