mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
This reverts commit 9553c67ab25cb25bf4b6e4d359937413e1f7fd6a.
[ROCm/composable_kernel commit: 1b846143c6]
This commit is contained in:
@@ -4,22 +4,14 @@
|
||||
#include <type_traits>
|
||||
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::half_t>)
|
||||
{
|
||||
constexpr const char* DataTypeToString() {
|
||||
if constexpr (std::is_same_v<T, ck_tile::half_t>) {
|
||||
return "fp16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
|
||||
{
|
||||
} else if constexpr (std::is_same_v<T, ck_tile::fp8_t>) {
|
||||
return "fp8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
|
||||
{
|
||||
} else if constexpr (std::is_same_v<T, ck_tile::bf8_t>) {
|
||||
return "bf8";
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
@@ -120,9 +112,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time = flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -130,15 +121,18 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>() << " M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename PrecType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_flatmm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -153,7 +147,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
@@ -188,7 +182,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
c_rslt_host.SetZero();
|
||||
|
||||
// do pre-shuffle
|
||||
std::string mfma = arg_parser.get_str("prec");
|
||||
std::string mfma = arg_parser.get_str("prec");
|
||||
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
|
||||
ck_tile::index_t mfma_type = 1;
|
||||
#else
|
||||
@@ -199,18 +193,18 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
bool pass = true;
|
||||
@@ -225,9 +219,8 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
a_host, b_origin_host, c_ref_host);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
c_ref_host,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
@@ -284,9 +277,8 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
c_gpu_ref_host,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
|
||||
@@ -80,7 +80,7 @@ __device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32");
|
||||
|
||||
// buffer atomic-add i32
|
||||
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
@@ -88,7 +88,7 @@ __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
@@ -96,15 +96,15 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
__device__ double
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64(
|
||||
double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");
|
||||
|
||||
// memory coherency bit for buffer store/load instruction
|
||||
// check ISA manual for each GFX target
|
||||
@@ -827,7 +827,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
|
||||
2559
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
2559
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,11 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#if __clang_major__ == 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
Reference in New Issue
Block a user