mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)
* Support bf16/fp8/bf8 datatypes for ck_tile/gemm
* remove commented out code.
* Addressing code review comments and enabling universal_gemm for all the supported data types.
* Merge conflict resolution.
* Solve the memory pipeline compilation error. Merge with the new change of CShuffle
* finish the feature, pass the tests
* Fix the pipeline and add the benchmark script for other data types
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
[ROCm/composable_kernel commit: ab5d027866]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -8,16 +8,75 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
|
||||
template <typename T, typename ComputeType>
|
||||
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
|
||||
{
|
||||
return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
|
||||
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
|
||||
{
|
||||
bf16x2_t rtn;
|
||||
rtn[0] = add_bf16_t(a[0], b[0]);
|
||||
rtn[1] = add_bf16_t(a[1], b[1]);
|
||||
rtn[0] = add<bf16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
/// Element-wise sum of two bf16x4 vectors. Each lane is widened by
/// add<bf16_t, float> (compute in float, round back to bf16).
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t result;
#pragma unroll
    for(int lane = 0; lane < 4; ++lane)
    {
        result[lane] = add<bf16_t, float>(a[lane], b[lane]);
    }
    return result;
}
|
||||
|
||||
/// Element-wise sum of two fp8x4 vectors. Each lane is computed in float
/// via add<fp8_t, float> and converted back to fp8.
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t result;
#pragma unroll
    for(int lane = 0; lane < 4; ++lane)
    {
        result[lane] = add<fp8_t, float>(a[lane], b[lane]);
    }
    return result;
}
|
||||
|
||||
/// Element-wise sum of two fp8x8 vectors. Each lane is computed in float
/// via add<fp8_t, float> and converted back to fp8.
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t result;
#pragma unroll
    for(int lane = 0; lane < 8; ++lane)
    {
        result[lane] = add<fp8_t, float>(a[lane], b[lane]);
    }
    return result;
}
|
||||
|
||||
/// Element-wise sum of two bf8x4 vectors. Each lane is computed in float
/// via add<bf8_t, float> and converted back to bf8.
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t result;
#pragma unroll
    for(int lane = 0; lane < 4; ++lane)
    {
        result[lane] = add<bf8_t, float>(a[lane], b[lane]);
    }
    return result;
}
|
||||
|
||||
/// Element-wise sum of two bf8x8 vectors. Each lane is computed in float
/// via add<bf8_t, float> and converted back to bf8.
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t result;
#pragma unroll
    for(int lane = 0; lane < 8; ++lane)
    {
        result[lane] = add<bf8_t, float>(a[lane], b[lane]);
    }
    return result;
}
|
||||
|
||||
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
|
||||
{
|
||||
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
|
||||
union U64BF164_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf16x4_t* bf164_a;
|
||||
};
|
||||
|
||||
// Union to treat the data as either bf16x4_t or 64-bit integer
|
||||
union U64BF164
|
||||
{
|
||||
uint64_t u64;
|
||||
bf16x4_t bf164;
|
||||
};
|
||||
|
||||
U64BF164_ADDR addr;
|
||||
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
|
||||
|
||||
// First read (non-atomic) of the old value
|
||||
U64BF164 cur_v;
|
||||
cur_v.u64 = *addr.u64_a;
|
||||
|
||||
U64BF164 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
do
|
||||
{
|
||||
// old 64 bits
|
||||
old_v = cur_v.u64;
|
||||
|
||||
// Add elementwise in bf16
|
||||
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt the 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
|
||||
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
|
||||
{
|
||||
union U32FP84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp8x4_t* fp84_a;
|
||||
};
|
||||
|
||||
union U32FP84
|
||||
{
|
||||
uint32_t u32;
|
||||
fp8x4_t fp84;
|
||||
};
|
||||
|
||||
U32FP84_ADDR dword_addr;
|
||||
U32FP84 cur_v;
|
||||
U32FP84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.fp84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
|
||||
{
|
||||
union U32BF84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf8x4_t* bf84_a;
|
||||
};
|
||||
|
||||
union U32BF84
|
||||
{
|
||||
uint32_t u32;
|
||||
bf8x4_t bf84;
|
||||
};
|
||||
|
||||
U32BF84_ADDR dword_addr;
|
||||
U32BF84 cur_v;
|
||||
U32BF84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.bf84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
|
||||
{
|
||||
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
|
||||
union U64FP88_ADDR
|
||||
{
|
||||
uint64_t* u64_a; // pointer to 64-bit integer
|
||||
fp8x8_t* fp88_a; // pointer to fp8x8_t
|
||||
};
|
||||
|
||||
union U64FP88
|
||||
{
|
||||
uint64_t u64;
|
||||
fp8x8_t fp88;
|
||||
};
|
||||
|
||||
U64FP88_ADDR dword_addr;
|
||||
U64FP88 cur_v;
|
||||
U64FP88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
// Point to the destination as both fp8x8_t* and uint64_t*.
|
||||
dword_addr.fp88_a = p_dst;
|
||||
// Initial read of 64 bits from memory
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each fp8 element using your add_fp8x8_t(...) routine
|
||||
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for bf8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
{
|
||||
union U64BF88_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf8x8_t* bf88_a;
|
||||
};
|
||||
|
||||
union U64BF88
|
||||
{
|
||||
uint64_t u64;
|
||||
bf8x8_t bf88;
|
||||
};
|
||||
|
||||
U64BF88_ADDR dword_addr;
|
||||
U64BF88 cur_v;
|
||||
U64BF88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
dword_addr.bf88_a = p_dst;
|
||||
// Read the original 64 bits
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each bf8 element using your add_bf8x8_t(...) routine
|
||||
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// 64-bit CAS loop
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
"The granularity of the thread buffer is unsupported on the hardware!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x2_t>()[I1]);
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x4_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck_tile/host/reference/reference_batched_masking.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_fused_moe.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
@@ -34,4 +35,3 @@
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
|
||||
|
||||
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
|
||||
double compute_error = 0;
|
||||
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
|
||||
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
|
||||
ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = acc;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ struct CShuffleEpilogue
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
template <typename ODataType>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
constexpr index_t MaxVectorStoreSize = 16;
|
||||
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
GetVectorSizeC(),
|
||||
GetVectorSizeC<ODataType>(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
@@ -240,7 +240,7 @@ struct GemmKernel
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
{
|
||||
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
|
||||
return false;
|
||||
@@ -255,7 +255,7 @@ struct GemmKernel
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
{
|
||||
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
|
||||
return false;
|
||||
@@ -321,7 +321,7 @@ struct GemmKernel
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
@@ -519,7 +519,7 @@ struct GemmKernel
|
||||
{
|
||||
// Do not compile in case where we have unsupported
|
||||
// VectorSizeC & data type configuration.
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
|
||||
Reference in New Issue
Block a user