Merge commit '507d81c3af51b81f15b946a2a4bef7f594620292' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-03 20:14:18 +00:00
parent 7ce8c0cf8f
commit a8059a2e58
19 changed files with 777 additions and 172 deletions

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include "ck/utility/data_type.hpp"
namespace ck {
namespace profiler {
template <typename DataType, typename ComputeDataType = DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType, typename ComputeDataType = DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
} // namespace profiler
} // namespace ck

View File

@@ -69,19 +69,19 @@ template <typename A0DataType,
typename ALayout,
typename BLayout,
typename ELayout>
bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideE,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideE,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
bool pass = true;
@@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
// Update strides based on tensor properties if they are <= 0
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
if(current_stride <= 0)
{
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
{
return tensor.GetStrides()[0];
}
else
{
return tensor.GetStrides()[1];
}
}
return current_stride;
};
StrideA = get_stride(a0_m_k, ALayout{}, StrideA);
StrideB = get_stride(b0_k_n, BLayout{}, StrideB);
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
int total_gemm_needed =
a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() +
a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes();

View File

@@ -20,6 +20,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "profiler/common.hpp"
namespace ck {
namespace profiler {
@@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
// Update strides based on tensor properties if they are <= 0
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
if(current_stride <= 0)
{
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
{
return tensor.GetStrides()[0];
}
else
{
return tensor.GetStrides()[1];
}
}
return current_stride;
};
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0);
StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1);
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
int total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes();
@@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
break;
default:
@@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-3;
double atol = 5e-2;
double rtol = get_rtol<EDataType>();
double atol = get_atol<EDataType>();
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}

View File

@@ -20,6 +20,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "profiler/common.hpp"
namespace ck {
namespace profiler {
@@ -99,6 +100,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
// Update strides based on tensor properties if they are <= 0
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
if(current_stride <= 0)
{
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
{
return tensor.GetStrides()[0];
}
else
{
return tensor.GetStrides()[1];
}
}
return current_stride;
};
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC);
std::size_t total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
@@ -317,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
is_same_v<CDataType, f8_t>)
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
double rtol = get_rtol<CDataType>();
double atol = get_atol<CDataType>();
pass = pass & ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
}

View File

@@ -5,92 +5,11 @@
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "profiler/common.hpp"
namespace ck {
namespace profiler {
template <typename DataType>
inline constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,