NOTE(review): This file is a git patch (unified diff), not a .cpp translation unit —
the .cpp filename/language hint from the metadata join is wrong. The commentary below
sits BEFORE the first "diff --git" header; `git apply` and `patch(1)` skip leading
non-diff text, so this preamble does not affect patch application.

Purpose (grounded in the hunks below): thread a new split-K "KBatch" option through
the AB-scale GEMM example, device interface, and profiler.
 - example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp:
   adds `ck::index_t KBatch = 1;`, accepts argc==9 with KBatch in argv[8], documents
   it in the usage text ("arg8: KBatch (default: 1)"), and assigns
   `argument.KBatch = KBatch;` after MakeArgument.
 - include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp:
   adds `virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;` to the
   abstract DeviceGemmMultipleD_ABScale interface, plus an override in the concrete
   device op that casts the BaseArgument and assigns `arg.KBatch = KBatch;`.
 - profiler/include/profiler/profile_gemm_ab_scale_impl.hpp: adds an `int KBatch`
   parameter and calls `op_ptr->SetKBatch(argument_ptr.get(), KBatch);` per instance.
 - profiler/src/profile_gemm_ab_scale.cpp: accepts argc 15/16/19, parses KBatch from
   argv[15] when present (default 1 via `(argc > 15) ? std::stoi(argv[15]) : 1`), and
   shifts warm-up / iteration / rotating-buffer args to argv[16..18].

Review notes — resolve before merging:
 1. GARBLED PATCH TEXT: the body below was damaged in transit — newlines inside each
    file section were collapsed, and text between '<' and '>' was stripped (e.g.
    `std::unique_ptr MakeInvokerPointer()` has lost its template argument,
    `std::array{}` its element type/count, and the SetKBatch override implementation
    between `template ` and `(base_arg);` is missing entirely). This patch will NOT
    pass `git apply`; regenerate it from the original commits (blob indices
    5aa978fbf0..3b21f95119, abf49bdab2..073f4541b1, f17516a47d..3c511469f2,
    531872bbb9..c2889d5490) rather than hand-repairing the text.
 2. Adding a *pure* virtual SetKBatch to an already-published abstract interface is
    a source-breaking change for any out-of-tree subclass of
    DeviceGemmMultipleD_ABScale; consider a non-pure default implementation (no-op)
    for backward compatibility.
 3. The override casts BaseArgument* to the concrete Argument type — the cast kind
    was stripped (see note 1); if it is dynamic_cast, confirm a failed-cast/null
    check exists — TODO confirm against the original commit.
 4. The example mutates `argument.KBatch` after MakeArgument returns; this assumes
    no KBatch-derived state (e.g. k-splitting / grid setup) was already computed
    inside MakeArgument from the default value — TODO confirm.
 5. Profiler CLI: argc==15 (no KBatch) stays accepted for backward compatibility,
    but argc 17/18 are rejected — the three optional args must be all-or-nothing;
    the usage text lists "arg15: KBatch" above the "optional:" marker even though
    it is optional, which reads ambiguously.

diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 5aa978fbf0..3b21f95119 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -91,6 +91,8 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; + ck::index_t KBatch = 1; + if(argc == 1) { // use default case @@ -101,7 +103,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 8) + else if(argc == 8 || argc == 9) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -113,6 +115,11 @@ int main(int argc, char* argv[]) flush_cache = std::stoi(argv[7]); + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + StrideA = K; StrideB = K; StrideE = N; @@ -124,6 +131,7 @@ int main(int argc, char* argv[]) printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 6: M, N, K\n"); printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); exit(0); } @@ -233,9 +241,9 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{}, e_device_buf.GetDeviceBuffer(), @@ -251,6 +259,7 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, cde_element_op); + argument.KBatch = KBatch; if(!device_op.IsSupportedArgument(argument)) { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp 
b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp index abf49bdab2..073f4541b1 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -58,6 +58,8 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator CDEElementwiseOperation cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0; }; template (base_arg); + arg.KBatch = KBatch; + } + static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check diff --git a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp index f17516a47d..3c511469f2 100644 --- a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp @@ -47,6 +47,7 @@ bool profile_gemm_ab_scale_impl(int do_verification, int StrideA, int StrideB, int StrideE, + int KBatch, int n_warmup, int n_iter, uint64_t rotating = 0) @@ -238,6 +239,7 @@ bool profile_gemm_ab_scale_impl(int do_verification, a_element_op, b_element_op, c_element_op); + op_ptr->SetKBatch(argument_ptr.get(), KBatch); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/profiler/src/profile_gemm_ab_scale.cpp b/profiler/src/profile_gemm_ab_scale.cpp index 531872bbb9..c2889d5490 100644 --- a/profiler/src/profile_gemm_ab_scale.cpp +++ b/profiler/src/profile_gemm_ab_scale.cpp @@ -40,7 +40,7 @@ enum struct ScaleBlockTile int profile_gemm_ab_scale(int argc, char* argv[]) { - if(argc != 15 && argc != 18) + if(argc != 15 && argc != 16 && argc != 19) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " @@ -57,10 +57,11 @@ int profile_gemm_ab_scale(int argc, char* argv[]) printf("arg7: print tensor value (0: no; 1: 
yes)\n"); printf("arg8: time kernel (0=no, 1=yes)\n"); printf("arg9 to 14: M, N, K, StrideA, StrideB, StrideE\n"); + printf("arg15: KBatch (default: 1)\n"); printf("optional:\n"); - printf("arg15: number of warm-up cycles (default 1)\n"); - printf("arg16: number of iterations (default 10)\n"); - printf("arg17: memory for rotating buffer (default 0, size in MB)\n"); + printf("arg16: number of warm-up cycles (default 1)\n"); + printf("arg17: number of iterations (default 10)\n"); + printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); exit(1); } @@ -79,15 +80,16 @@ int profile_gemm_ab_scale(int argc, char* argv[]) const int StrideA = std::stoi(argv[12]); const int StrideB = std::stoi(argv[13]); const int StrideE = std::stoi(argv[14]); + const int KBatch = (argc > 15) ? std::stoi(argv[15]) : 1; int n_warmup = 1; int n_iter = 10; uint64_t rotating = 0; - if(argc == 18) + if(argc == 19) { - n_warmup = std::stoi(argv[15]); - n_iter = std::stoi(argv[16]); - rotating = std::stoull(argv[17]) * 1024 * 1024; + n_warmup = std::stoi(argv[16]); + n_iter = std::stoi(argv[17]); + rotating = std::stoull(argv[18]) * 1024 * 1024; } using F32 = float; @@ -149,6 +151,7 @@ int profile_gemm_ab_scale(int argc, char* argv[]) (StrideA < 0) ? DefaultStrideA : StrideA, (StrideB < 0) ? DefaultStrideB : StrideB, (StrideE < 0) ? DefaultStrideE : StrideE, + KBatch, n_warmup, n_iter, rotating);