mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Stream-K CkProfiler Update (Replace CPU Validation with GPU Validation and Add Dynamic Grid Size Calculation for Stream-K GEMM Profiler) (#2333)
* Stream-K CkProfiler update * new grid size list based on SM count * clang * update for review * Update profile_gemm_universal_streamk.cpp --------- Co-authored-by: root <root@ctr-ubbsmc16.amd.com>
This commit is contained in:
committed by
GitHub
parent
a2f01141aa
commit
bfb33bc1e9
111
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Normal file → Executable file
111
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Normal file → Executable file
@@ -6,6 +6,7 @@
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
@@ -133,22 +134,62 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
{
|
||||
// Use GPU validation
|
||||
using ReferenceGemmInstanceGPU =
|
||||
ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeDataType,
|
||||
ComputeDataType>;
|
||||
|
||||
// Use CPU validation
|
||||
// Note: GPU validation is not supported for fp8 !!!
|
||||
using ReferenceGemmInstanceCPU = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeDataType>;
|
||||
auto ref_gemm_cpu = ReferenceGemmInstanceCPU{};
|
||||
auto ref_invoker_cpu = ref_gemm_cpu.MakeInvoker();
|
||||
auto ref_argument_cpu = ref_gemm_cpu.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
ref_invoker_cpu.Run(ref_argument_cpu);
|
||||
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
|
||||
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
|
||||
auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(ref_gemm_gpu.IsSupportedArgument(&ref_argument_gpu))
|
||||
{
|
||||
ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{nullptr, true});
|
||||
c_m_n_device_ref_buf.FromDevice(c_m_n_host_result.mData.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "GPU reference GEMM does not support this problem configuration so does "
|
||||
"CPU validation."
|
||||
<< std::endl;
|
||||
|
||||
// Use CPU validation
|
||||
|
||||
using ReferenceGemmInstanceCPU =
|
||||
ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeDataType>;
|
||||
auto ref_gemm_cpu = ReferenceGemmInstanceCPU{};
|
||||
auto ref_invoker_cpu = ref_gemm_cpu.MakeInvoker();
|
||||
auto ref_argument_cpu = ref_gemm_cpu.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
ref_invoker_cpu.Run(ref_argument_cpu);
|
||||
}
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
@@ -158,10 +199,48 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
|
||||
float best_grid_size = 0;
|
||||
float best_streamk_sel = 0;
|
||||
|
||||
// Get number of SMs on the current GPU
|
||||
int device_id;
|
||||
hipError_t err = hipGetDevice(&device_id);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
std::cerr << "hipGetDevice failed: " << hipGetErrorString(err) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
hipDeviceProp_t props;
|
||||
err = hipGetDeviceProperties(&props, device_id);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
std::cerr << "hipGetDeviceProperties failed: " << hipGetErrorString(err) << std::endl;
|
||||
return false;
|
||||
}
|
||||
int num_sms = props.multiProcessorCount;
|
||||
|
||||
// Generate grid sizes based on SM count with multipliers
|
||||
std::vector<float> multipliers = {0.2f, 0.4f, 0.6f, 0.8f, 1.0f, 1.2f, 1.4f, 1.6f, 2.0f};
|
||||
std::vector<int> grid_size_list;
|
||||
|
||||
for(float mult : multipliers)
|
||||
{
|
||||
int grid_size = static_cast<int>(num_sms * mult);
|
||||
if(grid_size > 0)
|
||||
{
|
||||
grid_size_list.push_back(grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Number of SMs: " << num_sms << std::endl;
|
||||
std::cout << "Grid sizes to test: ";
|
||||
for(auto gs : grid_size_list)
|
||||
{
|
||||
std::cout << gs << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::vector<int> grid_size_list = {38, 76, 114, 152, 190, 228, 266, 304, 342, 380};
|
||||
std::vector<int> streamk_sel_list = {
|
||||
0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP,
|
||||
// 2:2-tile Stream-K + DP
|
||||
|
||||
Reference in New Issue
Block a user