mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Navi21 gemm (#197)
* start adding navi21 GEMM
* navi_gemm_km_kn_mn_fp32 compiles and passes one test.
* rename variables and functions in gridwise_gemm_dlops_v1r3
* add other 3 layouts; format instance
* adding more tuning parameters
add tuning parameters for other 3 layouts
* add gemm_dlops_f16
* tmp
* add dependence of DeviceGemm::IsSupportedArg() on arch
* minor changes
* minor changes
* minor changes
* minor changes
* minor changes
* minor changes
* minor changes
* push gemm_dlops into profiler
* minor changes
* if using xdl or dlops is moved into profiler_gemm_impl
* minor changes
* minor changes
* remove is_xdl from profile_gemm_impl
* make IsSupportedArg dependent on arch for other device_gemm
* minor changes
* minor changes
* fix a bug in f_generate_tensor_value
* add 64x64x64 for gemm_dlops_int8
* add 64x64x64 for gemm_dlops_int8
* comment out 3 layouts in gemm_dlops_int8; add 32x32x32 for gemm_dlops_int8; init A values to 1
* fix
* start fixing tuning parameters
* monir
* minor changes
* minor changes
* minor changes
* fixing
* adding example
* adding example
* adding example
* add gemm fp32 example
* clean up
* use 128x128x16 as MNK tile in navi21 gemm example
* bug fix
* fix test
* use new block c tile
* clean
* fix build
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: shaojiewang <wsjmessi@163.com>
[ROCm/composable_kernel commit: 40b59a63cc]
This commit is contained in:
@@ -44,14 +44,10 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
@@ -76,6 +72,21 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<Devic
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
@@ -127,7 +138,11 @@ void profile_gemm_impl(int do_verification,
|
||||
std::size_t num_thread = 1;
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
// case 0: break;
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
|
||||
@@ -176,6 +191,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -194,6 +212,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -212,6 +233,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -230,6 +254,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -252,6 +279,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -270,6 +300,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
@@ -291,6 +324,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -309,6 +345,9 @@ void profile_gemm_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -355,28 +394,40 @@ void profile_gemm_impl(int do_verification,
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,7 +576,8 @@ void profile_gemm_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "does not support this GEMM problem" << std::endl;
|
||||
std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user