mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Split k f16 (#97)
* init for splitk f16
* a working prototype
* debug
* perf debug
* update example
* instances for mk kn
* add instances for all layers
* clean
* clean
* add tuning
* format
* add mn_padding into irregular tile
* clean
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: e221d11e51]
This commit is contained in:
@@ -44,6 +44,11 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
@@ -68,7 +73,7 @@ void profile_gemm_impl(int do_verification,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC,
|
||||
int KBatch = 1)
|
||||
int KBatch)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -181,7 +186,6 @@ void profile_gemm_impl(int do_verification,
|
||||
{
|
||||
if(KBatch > 1)
|
||||
{
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
@@ -214,44 +218,76 @@ void profile_gemm_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
if(KBatch > 1)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -78,7 +78,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -97,7 +98,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -116,7 +118,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -135,7 +138,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user