Stream-K CkProfiler Update (Replace CPU Validation with GPU Validation and Add Dynamic Grid Size Calculation for Stream-K GEMM Profiler) (#2333)

* Stream-K CkProfiler Update

* new grid-size list derived from the device's SM count

* clang

* update for review

* Update profile_gemm_universal_streamk.cpp

---------

Co-authored-by: root <root@ctr-ubbsmc16.amd.com>
This commit is contained in:
Muhammed Emin Ozturk
2025-06-18 07:49:22 -07:00
committed by GitHub
parent a2f01141aa
commit bfb33bc1e9
2 changed files with 97 additions and 18 deletions

View File

@@ -6,6 +6,7 @@
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include <hip/hip_runtime.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -133,22 +134,62 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
// Run reference GEMM
if(do_verification)
{
// Use GPU validation
using ReferenceGemmInstanceGPU =
ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType,
ComputeDataType>;
// Use CPU validation
// Note: GPU validation is not supported for fp8 !!!
using ReferenceGemmInstanceCPU = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType>;
auto ref_gemm_cpu = ReferenceGemmInstanceCPU{};
auto ref_invoker_cpu = ref_gemm_cpu.MakeInvoker();
auto ref_argument_cpu = ref_gemm_cpu.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker_cpu.Run(ref_argument_cpu);
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
M,
N,
K,
a_element_op,
b_element_op,
c_element_op);
if(ref_gemm_gpu.IsSupportedArgument(&ref_argument_gpu))
{
ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{nullptr, true});
c_m_n_device_ref_buf.FromDevice(c_m_n_host_result.mData.data());
}
else
{
std::cerr << "GPU reference GEMM does not support this problem configuration so does "
"CPU validation."
<< std::endl;
// Use CPU validation
using ReferenceGemmInstanceCPU =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType>;
auto ref_gemm_cpu = ReferenceGemmInstanceCPU{};
auto ref_invoker_cpu = ref_gemm_cpu.MakeInvoker();
auto ref_argument_cpu = ref_gemm_cpu.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker_cpu.Run(ref_argument_cpu);
}
}
std::string best_op_name;
@@ -158,10 +199,48 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
float best_grid_size = 0;
float best_streamk_sel = 0;
// Get number of SMs on the current GPU
int device_id;
hipError_t err = hipGetDevice(&device_id);
if(err != hipSuccess)
{
std::cerr << "hipGetDevice failed: " << hipGetErrorString(err) << std::endl;
return false;
}
hipDeviceProp_t props;
err = hipGetDeviceProperties(&props, device_id);
if(err != hipSuccess)
{
std::cerr << "hipGetDeviceProperties failed: " << hipGetErrorString(err) << std::endl;
return false;
}
int num_sms = props.multiProcessorCount;
// Generate grid sizes based on SM count with multipliers
std::vector<float> multipliers = {0.2f, 0.4f, 0.6f, 0.8f, 1.0f, 1.2f, 1.4f, 1.6f, 2.0f};
std::vector<int> grid_size_list;
for(float mult : multipliers)
{
int grid_size = static_cast<int>(num_sms * mult);
if(grid_size > 0)
{
grid_size_list.push_back(grid_size);
}
}
std::cout << "Number of SMs: " << num_sms << std::endl;
std::cout << "Grid sizes to test: ";
for(auto gs : grid_size_list)
{
std::cout << gs << " ";
}
std::cout << std::endl;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> grid_size_list = {38, 76, 114, 152, 190, 228, 266, 304, 342, 380};
std::vector<int> streamk_sel_list = {
0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP,
// 2:2-tile Stream-K + DP