mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK TILE ENGINE] Adding GEMM Preshuffle to CK Tile Engine (#2712)
* Partial Progress : Completed ListBlob * Additional changes in Listbob * Partial Progress : Generate Blobs Completed * Partial Progress : Added Host side code for Preshuffle * Working code for Preshuffle before Cleanup * Partial Progress : Cleanup * Partial Progress : Datatype Validation * Partial Progress : Warptiles for preshuffle changed from hardcoding to take from config * Partial Progress : Cleanup * Partial Progress : Code Cleanup * Partial Progress : Passing all valid tiles failing for unsupported tiles * Partial Progress : Working code, testing pending for edge cases * Partial Progress for testing * Completed Code * kBlockPerCu as tunable parameter from config * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update tile_engine/ops/gemm_preshuffle/README.md Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Partial Progress : Working listkernels * Partial Progress : Cleanup Working listkernels * Partial Progress : Single instance * Partial Progress : Working single instance code * Partial Progress : Working generate individual instance code * Partial Progress : Working rewamped code for given config file needed validation and edge case testing * Partial Progress : Working Code, testing pending * Removing LOGS file * Working code * Minor changes to GEMM Preshuffle : Restructured * Minor Changes in Preshuffle * Changes to Jenkins File * Changes to Jenkins file to consider new architecture * Changes to Jenkins file for fixing CI --------- Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
committed by
GitHub
parent
1894a0dbc3
commit
f6ba94fb5c
275
tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp
Normal file
275
tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp
Normal file
@@ -0,0 +1,275 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm_preshuffle.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
public:
|
||||
static GemmProfiler& instance(Setting setting)
|
||||
{
|
||||
static GemmProfiler instance{setting};
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Overload for single kernel benchmarking
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::function<float(const ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>
|
||||
kernel_func,
|
||||
const std::tuple<int, int, int>& warp_tile_dims)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&,
|
||||
const ck_tile::stream_config&)>>
|
||||
callables;
|
||||
|
||||
callables.push_back(
|
||||
[kernel_func](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {
|
||||
float time = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
});
|
||||
|
||||
benchmark(gemm_problem, callables, warp_tile_dims);
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables,
|
||||
const std::tuple<int, int, int>& warp_tile_dims)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
gemm_problem.stride_a_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
|
||||
gemm_problem.stride_b_ = ck_tile::get_default_stride(
|
||||
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
|
||||
gemm_problem.stride_c_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
if(setting_.init_method_ == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
|
||||
}
|
||||
else if(setting_.init_method_ == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(setting_.init_method_ == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
// Reference Verification
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
c_m_n_ref.SetZero();
|
||||
|
||||
if(setting_.verify_)
|
||||
{
|
||||
gemm_host_reference(setting_.verify_,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_ref,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
gemm_problem.stride_c_);
|
||||
}
|
||||
|
||||
// Kerenl Execution
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
for(const auto& callable : callables)
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(warp_tile_dims);
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_problem.split_k_,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
gemm_problem.stride_c_,
|
||||
};
|
||||
|
||||
auto kernel_run_result = callable(gemm_args,
|
||||
ck_tile::stream_config{nullptr,
|
||||
true,
|
||||
setting_.log_,
|
||||
setting_.n_warmup_,
|
||||
setting_.n_repeat_,
|
||||
setting_.is_gpu_timer_,
|
||||
setting_.flush_cache_,
|
||||
setting_.rotating_count_});
|
||||
|
||||
process_result(
|
||||
gemm_problem, c_m_n_dev_buf, c_m_n_ref, c_m_n_dev_result, kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmProblem& gemm_problem,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
// compute performance metric
|
||||
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
|
||||
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
|
||||
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
|
||||
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
|
||||
|
||||
// update
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0 && !setting_.json_output_)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
// verify result
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_ref);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
// clear tensor
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
|
||||
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
||||
kernel_instances_.end(),
|
||||
[metric](const auto& a, const auto& b) {
|
||||
return PerformanceResult::compare(
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
if(setting_.json_output_)
|
||||
{
|
||||
// Output clean JSON only
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
}
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file.tellp() == 0)
|
||||
{
|
||||
file << "rocm_version,device_name,"
|
||||
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
||||
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
||||
<< "structured_sparsity," << "name,"
|
||||
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
||||
}
|
||||
|
||||
const auto& problem = kernel_instance.problem_;
|
||||
const auto& name = kernel_instance.name_;
|
||||
const auto& perf = kernel_instance.perf_result_;
|
||||
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
||||
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
||||
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
|
||||
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
|
||||
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
|
||||
<< "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GemmProfiler() { kernel_instances_.clear(); }
|
||||
GemmProfiler(Setting setting) : setting_(setting) {}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
Reference in New Issue
Block a user