mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
aggregate device macros in ck_tile config header (#1297)
[ROCm/composable_kernel commit: 06b891c5c2]
This commit is contained in:
@@ -3,6 +3,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Aggregate target macro: __gfx9__ is defined when compiling for any of
// gfx908, gfx90a, gfx940, gfx941 or gfx942. gfx900 and gfx906 are NOT in this
// list, so code that must also cover them has to test them explicitly.
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
// Aggregate target macro: __gfx94__ is defined for gfx940, gfx941 and gfx942
// (a subset of the targets that define __gfx9__ above).
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
// Aggregate target macro: __gfx103__ is defined for gfx1030, gfx1031, gfx1032,
// gfx1034, gfx1035 and gfx1036.
// NOTE(review): gfx1033 is not in the list -- confirm whether that is intentional.
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
// Aggregate target macro: __gfx11__ is defined for gfx1100, gfx1101, gfx1102
// and gfx1103.
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
@@ -109,15 +124,13 @@
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)) // for GPU code
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
@@ -137,13 +150,12 @@
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
|
||||
defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx1030__) // for GPU code
|
||||
#elif defined(__gfx103__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
|
||||
#elif defined(__gfx11__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t
|
||||
{
|
||||
static constexpr int exponent = 4;
|
||||
static constexpr int mantissa = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
@@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t
|
||||
{
|
||||
static constexpr int exponent = 5;
|
||||
static constexpr int mantissa = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
@@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
@@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
@@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
@@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
|
||||
}
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
@@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
@@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
@@ -656,7 +656,7 @@ struct numeric_traits<fp8_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 8;
|
||||
#else
|
||||
static constexpr int bias = 7;
|
||||
@@ -668,7 +668,7 @@ struct numeric_traits<bf8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 16;
|
||||
#else
|
||||
static constexpr int bias = 15; // IEEE
|
||||
|
||||
@@ -112,7 +112,7 @@ namespace impl {
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
// This API is designed to use the _pk_ series of functions
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user