Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)

* Support bf16/fp8/bf8 data types for ck_tile/gemm

* Remove commented-out code.

* Address code review comments and enable universal_gemm for all supported data types.

* Resolve merge conflicts.

* Fix the memory pipeline compilation error and merge with the new CShuffle changes.

* Finish the feature; all tests pass.

* Fix the pipeline and add a benchmark script for the other data types.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>

[ROCm/composable_kernel commit: ab5d027866]
Author: kylasa
Date: 2025-02-06 14:07:38 -08:00
Committed by: Sam Wu
Parent: f5d3690565
Commit: 0aee5c2d16
21 changed files with 598 additions and 88 deletions

View File

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core/numeric/vector_type.hpp"
@@ -8,16 +8,75 @@
 namespace ck_tile {

-CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
+template <typename T, typename ComputeType>
+CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
 {
-    return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
+    return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
 }

 CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
 {
     bf16x2_t rtn;
-    rtn[0] = add_bf16_t(a[0], b[0]);
-    rtn[1] = add_bf16_t(a[1], b[1]);
+    rtn[0] = add<bf16_t, float>(a[0], b[0]);
+    rtn[1] = add<bf16_t, float>(a[1], b[1]);
     return rtn;
 }
+
+CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
+{
+    bf16x4_t rtn;
+    rtn[0] = add<bf16_t, float>(a[0], b[0]);
+    rtn[1] = add<bf16_t, float>(a[1], b[1]);
+    rtn[2] = add<bf16_t, float>(a[2], b[2]);
+    rtn[3] = add<bf16_t, float>(a[3], b[3]);
+    return rtn;
+}
+
+CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
+{
+    fp8x4_t rtn;
+    rtn[0] = add<fp8_t, float>(a[0], b[0]);
+    rtn[1] = add<fp8_t, float>(a[1], b[1]);
+    rtn[2] = add<fp8_t, float>(a[2], b[2]);
+    rtn[3] = add<fp8_t, float>(a[3], b[3]);
+    return rtn;
+}
+
+CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
+{
+    fp8x8_t rtn;
+    rtn[0] = add<fp8_t, float>(a[0], b[0]);
+    rtn[1] = add<fp8_t, float>(a[1], b[1]);
+    rtn[2] = add<fp8_t, float>(a[2], b[2]);
+    rtn[3] = add<fp8_t, float>(a[3], b[3]);
+    rtn[4] = add<fp8_t, float>(a[4], b[4]);
+    rtn[5] = add<fp8_t, float>(a[5], b[5]);
+    rtn[6] = add<fp8_t, float>(a[6], b[6]);
+    rtn[7] = add<fp8_t, float>(a[7], b[7]);
+    return rtn;
+}
+
+CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
+{
+    bf8x4_t rtn;
+    rtn[0] = add<bf8_t, float>(a[0], b[0]);
+    rtn[1] = add<bf8_t, float>(a[1], b[1]);
+    rtn[2] = add<bf8_t, float>(a[2], b[2]);
+    rtn[3] = add<bf8_t, float>(a[3], b[3]);
+    return rtn;
+}
+
+CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
+{
+    bf8x8_t rtn;
+    rtn[0] = add<bf8_t, float>(a[0], b[0]);
+    rtn[1] = add<bf8_t, float>(a[1], b[1]);
+    rtn[2] = add<bf8_t, float>(a[2], b[2]);
+    rtn[3] = add<bf8_t, float>(a[3], b[3]);
+    rtn[4] = add<bf8_t, float>(a[4], b[4]);
+    rtn[5] = add<bf8_t, float>(a[5], b[5]);
+    rtn[6] = add<bf8_t, float>(a[6], b[6]);
+    rtn[7] = add<bf8_t, float>(a[7], b[7]);
+    return rtn;
+}
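
For reference, every helper above has the same widen-add-narrow shape: fp8, bf8, and bf16 have no native addition in this path, so each element is converted to the ComputeType (float here), added, and converted back. A minimal host-side sketch of that shape, using a toy bf16 type invented for the example (to_float/from_float stand in for type_convert):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Toy bf16: the top 16 bits of an IEEE float. Invented for this sketch so
    // the widen-add-narrow pattern of add<T, ComputeType> is runnable on host.
    struct toy_bf16 { uint16_t bits; };

    float to_float(toy_bf16 v)
    {
        uint32_t u = static_cast<uint32_t>(v.bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    toy_bf16 from_float(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return {static_cast<uint16_t>(u >> 16)}; // truncating narrow, for brevity
    }

    // Same shape as add<bf16_t, float> above: widen, add in float, narrow back.
    toy_bf16 add_toy(toy_bf16 a, toy_bf16 b)
    {
        return from_float(to_float(a) + to_float(b));
    }

    int main()
    {
        toy_bf16 c = add_toy(from_float(1.5f), from_float(2.5f));
        std::printf("%g\n", to_float(c)); // prints 4
    }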
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
     } while(cur_v.u32 != old_v);
 }

+template <>
+CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
+{
+    // Union to treat the pointer as either bf16x4_t* or uint64_t*:
+    union U64BF164_ADDR
+    {
+        uint64_t* u64_a;
+        bf16x4_t* bf164_a;
+    };
+    // Union to treat the data as either bf16x4_t or a 64-bit integer:
+    union U64BF164
+    {
+        uint64_t u64;
+        bf16x4_t bf164;
+    };
+
+    U64BF164_ADDR addr;
+    addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
+
+    // First read (non-atomic) of the old value
+    U64BF164 cur_v;
+    cur_v.u64 = *addr.u64_a;
+
+    U64BF164 new_v_union;
+    uint64_t old_v, new_v;
+
+    do
+    {
+        // old 64 bits
+        old_v = cur_v.u64;
+        // Add elementwise in bf16
+        new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
+        new_v = new_v_union.u64;
+        // Attempt the 64-bit CAS
+        cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
+    } while(cur_v.u64 != old_v);
+}
+
+template <>
+CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
+{
+    union U32FP84_ADDR
+    {
+        uint32_t* u32_a;
+        fp8x4_t* fp84_a;
+    };
+    union U32FP84
+    {
+        uint32_t u32;
+        fp8x4_t fp84;
+    };
+
+    U32FP84_ADDR dword_addr;
+    U32FP84 cur_v;
+    U32FP84 new_;
+    uint32_t old_v, new_v;
+
+    dword_addr.fp84_a = p_dst;
+    cur_v.u32 = *dword_addr.u32_a;
+
+    do
+    {
+        old_v = cur_v.u32;
+        new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
+        new_v = new_.u32;
+        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+}
+
+template <>
+CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
+{
+    union U32BF84_ADDR
+    {
+        uint32_t* u32_a;
+        bf8x4_t* bf84_a;
+    };
+    union U32BF84
+    {
+        uint32_t u32;
+        bf8x4_t bf84;
+    };
+
+    U32BF84_ADDR dword_addr;
+    U32BF84 cur_v;
+    U32BF84 new_;
+    uint32_t old_v, new_v;
+
+    dword_addr.bf84_a = p_dst;
+    cur_v.u32 = *dword_addr.u32_a;
+
+    do
+    {
+        old_v = cur_v.u32;
+        new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
+        new_v = new_.u32;
+        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+}
+
+//
+// Atomic add for fp8x8_t
+//
+template <>
+CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
+{
+    // Union for addressing 64 bits as either fp8x8_t or a 64-bit integer.
+    union U64FP88_ADDR
+    {
+        uint64_t* u64_a; // pointer to 64-bit integer
+        fp8x8_t* fp88_a; // pointer to fp8x8_t
+    };
+    union U64FP88
+    {
+        uint64_t u64;
+        fp8x8_t fp88;
+    };
+
+    U64FP88_ADDR dword_addr;
+    U64FP88 cur_v;
+    U64FP88 new_v_union;
+    uint64_t old_v, new_v;
+
+    // Point to the destination as both fp8x8_t* and uint64_t*.
+    dword_addr.fp88_a = p_dst;
+    // Initial read of 64 bits from memory
+    cur_v.u64 = *dword_addr.u64_a;
+
+    do
+    {
+        old_v = cur_v.u64;
+        // Add each fp8 element using the add_fp8x8_t(...) routine
+        new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
+        new_v = new_v_union.u64;
+        // Attempt the 64-bit CAS
+        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
+    } while(cur_v.u64 != old_v);
+}
+
+//
+// Atomic add for bf8x8_t
+//
+template <>
+CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
+{
+    union U64BF88_ADDR
+    {
+        uint64_t* u64_a;
+        bf8x8_t* bf88_a;
+    };
+    union U64BF88
+    {
+        uint64_t u64;
+        bf8x8_t bf88;
+    };
+
+    U64BF88_ADDR dword_addr;
+    U64BF88 cur_v;
+    U64BF88 new_v_union;
+    uint64_t old_v, new_v;
+
+    dword_addr.bf88_a = p_dst;
+    // Read the original 64 bits
+    cur_v.u64 = *dword_addr.u64_a;
+
+    do
+    {
+        old_v = cur_v.u64;
+        // Add each bf8 element using the add_bf8x8_t(...) routine
+        new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
+        new_v = new_v_union.u64;
+        // 64-bit CAS loop
+        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
+    } while(cur_v.u64 != old_v);
+}

 template <typename T, index_t N>
 CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
 {
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
         (std::is_same<T, uint32_t>::value && (N == 1)) ||
         (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
         (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
-        (std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
-        "wrong! not implemented");
+        (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
+        (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
+        (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
+        "The granularity of the thread buffer is unsupported on the hardware!");

     constexpr auto I0 = number<0>{};
     constexpr auto I1 = number<1>{};
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
         }
         else if constexpr(N == 4)
         {
-            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
-            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
-                       x.template get_as<bf16x2_t>()[I1]);
+            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
         }
+        else if constexpr(N == 8)
+        {
+            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
+            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
+                       x.template get_as<bf16x4_t>()[I1]);
+        }
     }
+    else if constexpr(std::is_same<T, fp8_t>::value)
+    {
+        if constexpr(N == 4)
+        {
+            atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
+        }
+        if constexpr(N == 8)
+        {
+            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
+        }
+        if constexpr(N == 16)
+        {
+            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
+            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1,
+                       x.template get_as<fp8x8_t>()[I1]);
+        }
+    }
+    else if constexpr(std::is_same<T, bf8_t>::value)
+    {
+        if constexpr(N == 4)
+        {
+            atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
+        }
+        if constexpr(N == 8)
+        {
+            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
+        }
+        if constexpr(N == 16)
+        {
+            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
+            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1,
+                       x.template get_as<bf8x8_t>()[I1]);
+        }
+    }
 }
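
All of the atomic_add specializations above share one pattern: the hardware exposes no sub-word atomic add for these element types, so each specialization snapshots the containing 32- or 64-bit word, adds elementwise in float, and publishes the result with atomicCAS, retrying if another thread updated the word in between. A minimal device-side sketch of that pattern, written for a packed pair of 16-bit integers so it stays self-contained (the helper name and integer payload are invented for the example):

    // HIP/CUDA sketch of the CAS-emulation pattern: read the containing word,
    // compute the new value, and attempt a compare-and-swap; loop while
    // another thread won the race.
    __device__ void atomic_add_packed_u16x2(uint32_t* word, uint16_t lo, uint16_t hi)
    {
        uint32_t observed = *word; // initial non-atomic read
        uint32_t assumed;
        do
        {
            assumed = observed;
            // Unpack, add elementwise, repack.
            uint16_t a = static_cast<uint16_t>(assumed & 0xFFFFu);
            uint16_t b = static_cast<uint16_t>(assumed >> 16);
            uint32_t desired =
                static_cast<uint32_t>(static_cast<uint16_t>(a + lo)) |
                (static_cast<uint32_t>(static_cast<uint16_t>(b + hi)) << 16);
            // atomicCAS returns what was in memory; if it matches our
            // snapshot, the swap took effect and the loop exits.
            observed = atomicCAS(word, assumed, desired);
        } while(observed != assumed);
    }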

View File

@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"

View File

@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
 double get_relative_threshold(const int number_of_accumulations = 1)
 {
     using F8 = ck_tile::fp8_t;
+    using BF8 = ck_tile::bf8_t;
     using F16 = ck_tile::half_t;
     using BF16 = ck_tile::bf16_t;
     using F32 = float;
     using I8 = int8_t;
     using I32 = int32_t;

-    static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled ComputeDataType for setting up the relative threshold!");

     double compute_error = 0;
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
         compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the relative threshold!");

     double output_error = 0;
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     }

     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the relative threshold!");

     double acc_error = 0;
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
 double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
 {
     using F8 = ck_tile::fp8_t;
+    using BF8 = ck_tile::bf8_t;
     using F16 = ck_tile::half_t;
     using BF16 = ck_tile::bf16_t;
     using F32 = float;
     using I8 = int8_t;
     using I32 = int32_t;

-    static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

     auto expo = std::log2(std::abs(max_possible_num));
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
         compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");

     double output_error = 0;
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     }

     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");

     double acc_error = 0;
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
     }

     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }
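
For context on the thresholds being extended above: each type's contribution is a half-ULP relative error, 2^(-mant) * 0.5, where mant is the mantissa width taken from numeric_traits. A host-side sketch of that relationship; the mantissa widths below are stated assumptions for illustration (10 for fp16, 7 for bf16, 3 for an e4m3 fp8, 2 for an e5m2 bf8):

    #include <cmath>
    #include <cstdio>

    // Half-ULP relative error for a floating-point format with `mant`
    // fraction bits: 2^(-mant) * 0.5, mirroring the compute_error formula
    // in the diff above.
    double half_ulp_rel_error(int mant) { return std::pow(2.0, -mant) * 0.5; }

    int main()
    {
        std::printf("fp16: %g\n", half_ulp_rel_error(10)); // ~4.9e-4
        std::printf("bf16: %g\n", half_ulp_rel_error(7));  // ~3.9e-3
        std::printf("fp8 : %g\n", half_ulp_rel_error(3));  // 0.0625
        std::printf("bf8 : %g\n", half_ulp_rel_error(2));  // 0.125
    }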

View File

@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
             int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                               ? col * strideB + k
                               : k * strideB + col;
-            acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
+            acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
+                   ck_tile::type_convert<AccDataType>(B[b_index]);
         }

         int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                           ? row * strideC + col
                           : col * strideC + row;
-        C[c_index] = acc;
+        C[c_index] = ck_tile::type_convert<CDataType>(acc);
     }
 }
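
The static_cast to ck_tile::type_convert change matters for the byte-backed fp8/bf8 types, where a plain cast would at best reinterpret the stored bits rather than decode the value. A toy illustration with an invented byte-backed format (unsigned 4.4 fixed point, purely for the sketch; decode/encode play the role type_convert plays for fp8_t/bf8_t):

    #include <cstdint>
    #include <cstdio>

    // Stand-in for a byte-backed low-precision type. A cast on the raw byte
    // would just reinterpret the payload; an explicit decode/encode pair is
    // required, which is what type_convert provides for fp8_t/bf8_t.
    struct toy_q44 { uint8_t bits; };

    float decode(toy_q44 v) { return static_cast<float>(v.bits) / 16.f; }
    toy_q44 encode(float v) { return {static_cast<uint8_t>(v * 16.f + 0.5f)}; }

    int main()
    {
        toy_q44 a = encode(1.5f), b = encode(2.25f);
        float acc = decode(a) * decode(b); // multiply in float, like AccDataType
        toy_q44 c = encode(acc);           // narrow once at the end, like CDataType
        std::printf("%.4f -> stored 0x%02x\n", acc, static_cast<unsigned>(c.bits));
    }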

View File

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

View File

@@ -77,6 +77,7 @@ struct CShuffleEpilogue
      *
      * @return The vector store size for C tensor.
      */
+    template <typename ODataType>
     CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
     {
         constexpr index_t MaxVectorStoreSize = 16;
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
             TileDistributionEncodingPattern2D<kBlockSize,
                                               kMPerIteration,
                                               kNPerIteration,
-                                              GetVectorSizeC(),
+                                              GetVectorSizeC<ODataType>(),
                                               tile_distribution_pattern::thread_raked>;

         constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
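
Templating GetVectorSizeC on ODataType lets the epilogue derive the store width from the element size rather than assuming one output type: with the 16-byte cap shown above, that is 8 elements for fp16/bf16 but 16 for fp8/bf8. A hedged sketch of that derivation (the real function presumably also clamps against the tile shape, which is elided here):

    #include <cstdint>

    using index_t = int32_t;

    // Sketch: the widest store is capped at 16 bytes, so the element count
    // per store depends on sizeof(ODataType). The real GetVectorSizeC also
    // accounts for the per-thread tile extent, elided in this sketch.
    template <typename ODataType>
    constexpr index_t vector_size_c_sketch()
    {
        constexpr index_t MaxVectorStoreSize = 16; // bytes, as in the diff above
        return MaxVectorStoreSize / static_cast<index_t>(sizeof(ODataType));
    }

    static_assert(vector_size_c_sketch<uint16_t>() == 8, "fp16/bf16-sized");
    static_assert(vector_size_c_sketch<uint8_t>() == 16, "fp8/bf8-sized");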

View File

@@ -159,7 +159,7 @@ struct GemmKernel
     CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
-        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+        if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                      is_any_of<CDataType, fp16_t, bf16_t>::value)
         {
             if(kargs.k_batch != 1)
@@ -240,7 +240,7 @@ struct GemmKernel
                       << std::endl;
             return false;
         }
-        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
+        if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
         {
             std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
             return false;
@@ -255,7 +255,7 @@ struct GemmKernel
                       << std::endl;
             return false;
         }
-        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
+        if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
         {
             std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
             return false;
@@ -321,7 +321,7 @@ struct GemmKernel
             c_ptr,
             make_tuple(kargs.M, kargs.N),
             make_tuple(kargs.stride_C, 1),
-            number<EpiloguePipeline::GetVectorSizeC()>{},
+            number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
             number<1>{});
     }
     else
@@ -519,7 +519,7 @@ struct GemmKernel
     {
         // Do not compile in case where we have unsupported
         // VectorSizeC & data type configuration.
-        if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+        if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                        is_any_of<CDataType, fp16_t, bf16_t>::value))
         {
             RunGemm<memory_operation_enum::atomic_add>(
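
A side effect of templating GetVectorSizeC is visible at every call site in this file: EpiloguePipeline is a dependent type, so C++ requires the template keyword to tell the parser that GetVectorSizeC names a member template rather than the start of a less-than comparison. A minimal reproduction with invented types:

    // Minimal reproduction of why the `template` disambiguator is needed.
    struct Epilogue
    {
        template <typename T>
        static constexpr int GetVectorSizeC() { return 16 / sizeof(T); }
    };

    template <typename EpiloguePipeline, typename CDataType>
    constexpr int vector_size()
    {
        // Without `template`, `EpiloguePipeline::GetVectorSizeC < CDataType`
        // would parse as a comparison, since EpiloguePipeline is dependent.
        return EpiloguePipeline::template GetVectorSizeC<CDataType>();
    }

    static_assert(vector_size<Epilogue, float>() == 4, "16 bytes / 4-byte float");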