fix after merge ginolu/add_wgmfma_dispatcher

This commit is contained in:
mtgu0705
2025-09-09 04:37:42 -05:00
parent f119c30317
commit b0d71b8d19
9 changed files with 1037 additions and 339 deletions

View File

@@ -107,36 +107,36 @@ CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
template <typename X>
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, const fp16x2_t& x)
{
union U32FP162_ADDR
{
uint32_t* u32_a;
fp16x2_t* fp162_a;
};
// template <>
// CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, const fp16x2_t& x)
// {
// union U32FP162_ADDR
// {
// uint32_t* u32_a;
// fp16x2_t* fp162_a;
// };
union U32FP162
{
uint32_t u32;
fp16x2_t fp162;
};
// union U32FP162
// {
// uint32_t u32;
// fp16x2_t fp162;
// };
U32FP162_ADDR dword_addr;
U32FP162 cur_v;
U32FP162 new_;
uint32_t old_v, new_v;
dword_addr.fp162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
// U32FP162_ADDR dword_addr;
// U32FP162 cur_v;
// U32FP162 new_;
// uint32_t old_v, new_v;
// dword_addr.fp162_a = p_dst;
// cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
// do
// {
// old_v = cur_v.u32;
// new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
// new_v = new_.u32;
// cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
// } while(cur_v.u32 != old_v);
// }
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)

View File

@@ -31,10 +31,9 @@ struct e8m0_bexp_t
raw_type data;
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) : data(0)
{
data = numeric_utils<float>::get_exponent(scale);
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }