mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Merge commit '507d81c3af51b81f15b946a2a4bef7f594620292' into develop
This commit is contained in:
103
profiler/include/profiler/common.hpp
Normal file
103
profiler/include/profiler/common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -69,19 +69,19 @@ template <typename A0DataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout>
|
||||
bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideE,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideE,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
@@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a0_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b0_k_n, BLayout{}, StrideB);
|
||||
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
|
||||
|
||||
int total_gemm_needed =
|
||||
a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() +
|
||||
a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes();
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
|
||||
StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0);
|
||||
StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1);
|
||||
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
|
||||
|
||||
int total_gemm_needed =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
|
||||
d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes();
|
||||
@@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
|
||||
break;
|
||||
default:
|
||||
@@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
is_same_v<EDataType, int8_t>))
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-3;
|
||||
double atol = 5e-2;
|
||||
double rtol = get_rtol<EDataType>();
|
||||
double atol = get_atol<EDataType>();
|
||||
pass = pass & ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -99,6 +100,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
|
||||
StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC);
|
||||
|
||||
std::size_t total_gemm_needed =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
|
||||
int rotating_count = std::max(
|
||||
@@ -317,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
is_same_v<CDataType, f8_t>)
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-1;
|
||||
double atol = 1e-1;
|
||||
double rtol = get_rtol<CDataType>();
|
||||
double atol = get_atol<CDataType>();
|
||||
pass = pass & ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
|
||||
@@ -5,92 +5,11 @@
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
Reference in New Issue
Block a user