mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK][Examples] Extending support for rdna3/4 part 3:
-example_gemm_xdl_int8 -example_gemm_xdl_fp8 -example_gemm_xdl_fp8_bf8 -example_gemm_xdl_fp16_fp8 -example_gemm_add_add_fastgelu_xdl_int8 -example_grouped_gemm_xdl_int8 -example_grouped_conv_bwd_weight_xdl_bf16 -example_cgemm_xdl_fp32 -example_cgemm_xdl_int8 fixing cmdlines for: -example_22_cgemm -example_24_batched_gemm -example_batched_gemm_xdl_fp16int4_b_scale_v3 Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
This commit is contained in:
committed by
Michał Kulikowski
parent
7259b9c4db
commit
2444c44895
@@ -57,4 +57,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run_batched_gemm_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -218,35 +218,37 @@ bool run_batched_gemm_example(int argc, char* argv[])
|
||||
|
||||
problem_size.batch_count = 2;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4 || argc == 8)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 8)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
if(argc == 8)
|
||||
{
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("optinal\n");
|
||||
printf("arg4-7: M = %d N = %d K = %d Batch = %d\n",
|
||||
problem_size.M,
|
||||
problem_size.N,
|
||||
problem_size.K,
|
||||
problem_size.batch_count);
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("optional\n");
|
||||
printf("arg4-7: M, N, K, Batch\n");
|
||||
exit(1);
|
||||
}
|
||||
printf("M = %d N = %d K = %d Batch = %d\n",
|
||||
problem_size.M,
|
||||
problem_size.N,
|
||||
problem_size.K,
|
||||
problem_size.batch_count);
|
||||
|
||||
problem_size.stride_A = problem_size.K;
|
||||
problem_size.stride_B = problem_size.K;
|
||||
|
||||
@@ -523,6 +523,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
@@ -535,30 +540,30 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
|
||||
problem_size.batch_count = 2;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
// use default case
|
||||
}
|
||||
else if(argc >= 7)
|
||||
else if(argc == 4 || argc >= 7)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
if(argc >= 8)
|
||||
if(argc >= 7)
|
||||
{
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
if(argc >= 9)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[8]);
|
||||
if(argc >= 8)
|
||||
{
|
||||
problem_size.batch_count = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
if(argc >= 9)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[8]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -566,6 +571,9 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4-6: problem size (M, N, K)\n");
|
||||
printf("arg7: batch count\n");
|
||||
printf("arg8: KBatch\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user