Merge commit '1b846143c669b7bcbba4fda5f9165bb270f88ea1' into develop

This commit is contained in:
assistant-librarian[bot]
2025-05-22 23:06:33 +00:00
parent 82f179f094
commit b0752bd085
5 changed files with 2607 additions and 51 deletions

View File

@@ -4,22 +4,14 @@
#include <type_traits>
/// Maps a ck_tile element type to the short name printed in benchmark
/// log lines ("fp16" / "fp8" / "bf8"); any other type reports "unknown".
/// Evaluated entirely at compile time via `if constexpr`.
template <typename T>
constexpr const char* DataTypeToString() {
    if constexpr (std::is_same_v<T, ck_tile::half_t>) {
        return "fp16";
    } else if constexpr (std::is_same_v<T, ck_tile::fp8_t>) {
        return "fp8";
    } else if constexpr (std::is_same_v<T, ck_tile::bf8_t>) {
        return "bf8";
    } else {
        return "unknown";
    }
}
@@ -120,9 +112,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time =
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
float ave_time = flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -130,15 +121,18 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>() << " M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
template <typename PrecType,
typename ALayout,
typename BLayout,
typename CLayout>
int run_flatmm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -153,7 +147,7 @@ int run_flatmm_example_with_layouts(int argc,
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
@@ -188,7 +182,7 @@ int run_flatmm_example_with_layouts(int argc,
c_rslt_host.SetZero();
// do pre-shuffle
std::string mfma = arg_parser.get_str("prec");
std::string mfma = arg_parser.get_str("prec");
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
ck_tile::index_t mfma_type = 1;
#else
@@ -199,18 +193,18 @@ int run_flatmm_example_with_layouts(int argc,
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
@@ -225,9 +219,8 @@ int run_flatmm_example_with_layouts(int argc,
a_host, b_origin_host, c_ref_host);
const float max_accumulated_value =
*std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_rslt_host,
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_rslt_host,
c_ref_host,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
@@ -284,9 +277,8 @@ int run_flatmm_example_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float max_accumulated_value =
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_rslt_host,
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_rslt_host,
c_gpu_ref_host,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),

View File

@@ -80,7 +80,7 @@ __device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32");
// buffer atomic-add i32
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
@@ -88,7 +88,7 @@ __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32");
// buffer atomic-add fp32
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
@@ -96,15 +96,15 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32");
// buffer atomic-max fp64
__device__ double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64(
double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");
// memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target
@@ -827,7 +827,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
#ifndef __HIPCC_RTC__
template <typename T, index_t NumElemsPerThread>

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/utility.hpp"

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,11 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#if __clang_major__ == 20
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#else
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#endif
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"