Integrate universal gemm with conv bwd data and add SplitK (#1315)

* Integrate universal gemm with conv bwd data

* Fix multi d kernel

* Add splitK support

* instances refactor

* instances refactor

* refactor

* fixeS

* fixes

* 16x16 instnaces

* Fixes

* Fix

* Fix

* Fix

* Fix

* Fix

* Fixes

* fix

* fix

[ROCm/composable_kernel commit: 4094ad158a]
This commit is contained in:
Bartłomiej Kocot
2025-04-28 23:54:49 +02:00
committed by GitHub
parent 02ef8bcfb1
commit 05f9b2dde3
69 changed files with 2262 additions and 349 deletions

View File

@@ -68,8 +68,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
const bool time_kernel = std::stoi(argv[7]);
const int num_dim_spatial = std::stoi(argv[8]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
{
print_helper_msg();
return 1;
@@ -77,6 +77,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
@@ -110,7 +112,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
OutDataType,
WeiDataType,
InDataType>(
do_verification, init_method, do_log, time_kernel, params);
do_verification, init_method, do_log, time_kernel, params, split_k);
return pass ? 0 : 1;
};