mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Navi21 gemm (#197)
* start adding navi21 GEMM * navi_gemm_km_kn_mn_fp32 compiles and passes one test. * rename variables and functions in gridwise_gemm_dlops_v1r3 * add other 3 layouts; format instance * adding more tuning parameters add tuning parameters for other 3 layouts * add gemm_dlops_f16 * tmp * add dependence of DeviceGemm::IsSupportedArg() on arch * minor changes * minor changes * minor changes * minor changes * minor changes * minor changes * minor changes * push gemm_dlops into profiler * minor changes * if using xdl or dlops is moved into profiler_gemm_impl * minor changes * minor changes * remove is_xdl from profile_gemm_impl * make IsSupportedArg dependent on arch for other device_gemm * minor changes * minor changes * fix a bug in f_generate_tensor_value * add 64x64x64 for gemm_dlops_int8 * add 64x64x64 for gemm_dlops_int8 * comment out 3 layouts in gemm_dlops_int8; add 32x32x32 for gemm_dlops_int8; init A values to 1 * fix * start fixing tuning parameters * monir * minor changes * minor changes * minor changes * fixing * adding example * adding example * adding example * add gemm fp32 example * clean up * use 128x128x16 as MNK tile in navi21 gemm example * bug fix * fix test * use new block c tile * clean * fix build Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: shaojiewang <wsjmessi@163.com>
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
|
||||
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_contraction_dlops.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "threadwise_contraction_dl.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -41,7 +39,7 @@ template <index_t BlockSize,
|
||||
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
@@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
|
||||
__device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
@@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
"wrong!");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(BM0 == 2, "wrong");
|
||||
static_assert(BM0 == 2 && BN0 == 2, "wrong");
|
||||
}
|
||||
|
||||
@@ -226,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_contraction =
|
||||
ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
@@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v5r1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
template <typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user