mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Mx fp6 flatmm (#3601)
* add fp6 data-type and support sync/async dwordx3 load/store * clang-format * pre-commit * 1st commit * default mnk pass ut * fix a distrubution * fix * fix bdram distr * update * pass ut * improve perf * update * clean code * resolve copilot comment * reslove comment * clang-format --------- Co-authored-by: ZheWang <zhewan@amd.com>
This commit is contained in:
@@ -54,6 +54,7 @@
|
||||
#include "ck_tile/core/numeric/null_type.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp6.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
@@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<int8_t, N>;
|
||||
@@ -1457,6 +1457,15 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 12)
|
||||
{
|
||||
auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
|
||||
@@ -1134,6 +1134,25 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32");
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3(dwordx3_union vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset)
|
||||
{
|
||||
int32x3_t v_reg;
|
||||
v_reg[0] = vdata.as_i32[0];
|
||||
v_reg[1] = vdata.as_i32[1];
|
||||
v_reg[2] = vdata.as_i32[2];
|
||||
llvm_amdgcn_raw_buffer_store_i32x3_(v_reg, rsrc, voffset, soffset, 0);
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -1290,7 +1309,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<int8_t, N>;
|
||||
@@ -1330,6 +1349,18 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 12)
|
||||
{
|
||||
auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
dwordx3_union ret;
|
||||
ret.as_i32[0] = tmp[0];
|
||||
ret.as_i32[1] = tmp[1];
|
||||
ret.as_i32[2] = tmp[2];
|
||||
return bit_cast<rtn_type>(ret);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
@@ -1411,15 +1442,19 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
|
||||
(std::is_same<T, uint8_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
|
||||
(std::is_same<T, e8m0_bexp_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_fp4_raw_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, pk_fp6x16_t>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
@@ -1750,7 +1785,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(N == 1)
|
||||
@@ -1786,6 +1821,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 12)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x3(bit_cast<dwordx3_union>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
|
||||
@@ -1859,10 +1901,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
|
||||
(std::is_same<T, uint16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
(std::is_same<T, uint8_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
std::is_same<T, pk_fp6x16_t>::value && (N == 1),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(std::is_same<T, float>::value) // fp32
|
||||
|
||||
109
include/ck_tile/core/numeric/pk_fp6.hpp
Normal file
109
include/ck_tile/core/numeric/pk_fp6.hpp
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <index_t pk_size>
|
||||
struct pk_fp6_t
|
||||
{
|
||||
static constexpr index_t num_bits_elem = 6;
|
||||
using element_type = int32_t; // element storage fundamental type
|
||||
static constexpr index_t packed_size = pk_size;
|
||||
static constexpr index_t num_bits_vec_elem =
|
||||
sizeof(element_type) * 8; // 32-bit uint for storage
|
||||
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
|
||||
"Packed elements must fit exactly into the element storage.");
|
||||
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
|
||||
element_type data_[vector_size]; // packed data
|
||||
using type = pk_fp6_t<packed_size>;
|
||||
CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0)
|
||||
{
|
||||
for(size_t i = 0; i < vector_size; ++i)
|
||||
{
|
||||
data_[i] = value;
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE void pack(const int32_t x, const index_t i)
|
||||
{
|
||||
int32_t bits = static_cast<int32_t>(x) & 0x3F;
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_index = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
int32_t old_value = data_[arr_index];
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
data_[arr_index] = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
int32_t next_value = data_[arr_index + 1];
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
data_[arr_index + 1] = next_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static int32_t unpack(const T& pk, const index_t i)
|
||||
{
|
||||
const int bit_pos = i * num_bits_elem;
|
||||
const int arr_idx = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
int32_t bits = pk.data_[arr_idx] >> bit_offset;
|
||||
if(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return bits & 0x3F;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); }
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t operator[](index_t i) const { return data_[i]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static float fp6_e2m3_to_float(int32_t fp6_bits)
|
||||
{
|
||||
fp6_bits = fp6_bits & 0x3F;
|
||||
|
||||
uint32_t sign = (fp6_bits >> 5) & 0x1; // bit 5
|
||||
uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3
|
||||
uint32_t mantissa = fp6_bits & 0x7; // bits 2-0
|
||||
|
||||
float result;
|
||||
if(exponent == 0 && mantissa == 0)
|
||||
{
|
||||
result = 0.f;
|
||||
}
|
||||
else if(exponent != 0)
|
||||
{
|
||||
result = std::exp2f(static_cast<int>(exponent) - 1);
|
||||
float mantissa_value = 1.0f + mantissa / 8.0f;
|
||||
result *= mantissa_value;
|
||||
}
|
||||
else
|
||||
{
|
||||
result = mantissa / 8.0f;
|
||||
}
|
||||
return sign == 1 ? -1 * result : result;
|
||||
}
|
||||
};
|
||||
|
||||
using pk_fp6x16_t = pk_fp6_t<16>;
|
||||
using pk_fp6x32_t = pk_fp6_t<32>;
|
||||
template <>
|
||||
struct numeric_traits<pk_fp6x16_t>
|
||||
{
|
||||
static constexpr int PackedSize = 16;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -72,6 +72,7 @@ CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp6.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -160,6 +160,40 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
struct int32x3_tt
|
||||
{
|
||||
int32_t data[3];
|
||||
};
|
||||
|
||||
struct int32x6_tt
|
||||
{
|
||||
int32_t data[6];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct impl::ext_vector<int8_t, 12>
|
||||
{
|
||||
static constexpr index_t N = 12;
|
||||
using value_type = int32x3_tt;
|
||||
using type = int32x3_tt;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct impl::ext_vector<pk_fp6x16_t, 1>
|
||||
{
|
||||
static constexpr index_t N = 1;
|
||||
using value_type = int32x3_tt;
|
||||
using type = int32x3_tt;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct impl::ext_vector<pk_fp6x16_t, 2>
|
||||
{
|
||||
static constexpr index_t N = 2;
|
||||
using value_type = int32x6_tt;
|
||||
using type = int32x6_tt;
|
||||
};
|
||||
|
||||
// u32
|
||||
// using uint32_t = ...
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
@@ -303,7 +303,6 @@ struct buffer_view<address_space_enum::global,
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
@@ -825,11 +824,23 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
// using buf_t = ushort __attribute__((ext_vector_type(8)));
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
|
||||
return bit_cast<X>(rtn);
|
||||
constexpr index_t load_elts = scalar_per_t_vector * scalar_per_x_vector;
|
||||
if constexpr(load_elts == 12 && sizeof(typename X::value_type) == 1)
|
||||
{
|
||||
auto rtn = reinterpret_cast<const int32_t*>(p_data_) + (i + linear_offset) / 4;
|
||||
struct
|
||||
{
|
||||
int32_t x, y, z;
|
||||
} tmp = {rtn[0], rtn[1], rtn[2]};
|
||||
return bit_cast<X>(tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
|
||||
return bit_cast<X>(rtn);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -968,6 +979,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 12>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
@@ -1033,6 +1045,11 @@ struct buffer_view<address_space_enum::lds,
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 12>>)
|
||||
{
|
||||
*c_style_pointer_cast<dwordx3_union*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const dwordx3_union*>(&x);
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
@@ -1075,6 +1092,12 @@ struct buffer_view<address_space_enum::lds,
|
||||
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x4_t*>(&x);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false,
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -720,4 +720,57 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
return err_count == 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between pk_fp6x16_t ranges
|
||||
*
|
||||
* Compares two ranges of pk_fp6x16_t without tolerance.
|
||||
* This specialization handles ck_tile::pk_fp6x16_t type.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, pk_fp6x16_t>),
|
||||
bool>
|
||||
CK_TILE_HOST check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double = 0,
|
||||
double = 0)
|
||||
{
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
|
||||
int err_count = 0;
|
||||
float max_err = 0.0f;
|
||||
auto update_err = [&](float o, float r, std::size_t index) {
|
||||
if(std::fabs(o - r) > 1e-8)
|
||||
{
|
||||
std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o
|
||||
<< " != " << r << std::endl;
|
||||
++err_count;
|
||||
max_err = max_err < std::fabs(o - r) ? o : max_err;
|
||||
}
|
||||
};
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const pk_fp6x16_t o = *std::next(std::begin(out), i);
|
||||
const pk_fp6x16_t r = *std::next(std::begin(ref), i);
|
||||
for(std::size_t j = 0; j < numeric_traits<pk_fp6x16_t>::PackedSize; j++)
|
||||
{
|
||||
update_err(o.unpack(j), r.unpack(j), i * numeric_traits<pk_fp6x16_t>::PackedSize + j);
|
||||
}
|
||||
}
|
||||
if(err_count > 0)
|
||||
{
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return err_count == 0;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -625,6 +625,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
{
|
||||
if(k % pk_fp6x16_t::packed_size != 0)
|
||||
continue;
|
||||
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
|
||||
for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++)
|
||||
{
|
||||
a_m_k_scaled(m, k + k_) =
|
||||
pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
@@ -653,6 +664,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
{
|
||||
if(k % pk_fp6x16_t::packed_size != 0)
|
||||
continue;
|
||||
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++)
|
||||
{
|
||||
b_k_n_scaled(k + k_, n) =
|
||||
pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
|
||||
@@ -22,6 +22,7 @@ template <> struct DataTypeTraits<bf8_t> { static constexpr const char * name =
|
||||
template <> struct DataTypeTraits<int8_t> { static constexpr const char * name = "int8"; };
|
||||
template <> struct DataTypeTraits<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
|
||||
template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
|
||||
template <> struct DataTypeTraits<pk_fp6x16_t> { static constexpr const char * name = "pk_fp6x16"; };
|
||||
template <> struct DataTypeTraits<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };
|
||||
|
||||
template <memory_operation_enum MemOp> struct memOpToStr;
|
||||
|
||||
@@ -118,8 +118,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
static constexpr index_t KFlatBytesPerBlockPerIter =
|
||||
flatKPerWarp * sizeof(BDataType) / BPackedSize;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
@@ -132,8 +133,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = 16 /*dwordx4*/ * APackedSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType);
|
||||
static constexpr index_t AK1 = std::is_same_v<ADataType, pk_fp6x16_t>
|
||||
? 16
|
||||
: 16 /*dwordx4*/ * APackedSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = std::is_same_v<BDataType, pk_fp6x16_t>
|
||||
? 16
|
||||
: 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType);
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
@@ -537,24 +542,26 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
auto a_store_lds_window_ping = make_tile_window( //
|
||||
a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / APackedSize>{}),
|
||||
make_tuple(number<kMPerBlock>{},
|
||||
number<kKPerBlock / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0});
|
||||
auto a_store_lds_window_pong = make_tile_window( //
|
||||
a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / APackedSize>{}),
|
||||
make_tuple(number<kMPerBlock>{},
|
||||
number<kKPerBlock / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0});
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
auto a_warp_window_ping = make_tile_window(
|
||||
a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong = make_tile_window(
|
||||
a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
|
||||
// B flat DRAM window for load
|
||||
|
||||
@@ -621,7 +628,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock * sizeof(ADataType) / APackedSize});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -663,7 +670,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(HasHotLoop || TailNum == TailNumber::Even)
|
||||
{
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, sizeof(ADataType) * kKPerBlock / APackedSize});
|
||||
}
|
||||
// initialize C
|
||||
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp>
|
||||
@@ -683,7 +690,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -750,7 +758,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
@@ -760,7 +768,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock * sizeof(ADataType) / APackedSize});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -772,7 +780,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -839,7 +848,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
// barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
|
||||
@@ -849,7 +858,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, sizeof(ADataType) * kKPerBlock / APackedSize});
|
||||
// move B window to next flat K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -860,7 +869,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
};
|
||||
@@ -874,7 +884,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
iCounter--;
|
||||
} while(iCounter > 0);
|
||||
}
|
||||
|
||||
// TAIL
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
@@ -933,7 +942,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
@@ -947,7 +956,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
@@ -977,12 +987,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(n_iter == NIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) =
|
||||
load_tile_with_offset(a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
@@ -1014,12 +1024,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(n_iter == NIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) =
|
||||
load_tile_with_offset(a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
|
||||
@@ -17,6 +17,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
static constexpr index_t DWORDx4 = 16;
|
||||
static constexpr index_t DWORDx3 = 12;
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
@@ -77,15 +78,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution()
|
||||
{
|
||||
constexpr index_t K2 = DWORDx4; // 16 bytes
|
||||
constexpr index_t K1 = kDramLoadPackBytes / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
|
||||
constexpr index_t K2 = std::is_same_v<ADataType, pk_fp6x16_t> ? DWORDx3 : DWORDx4;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8
|
||||
constexpr index_t K0 =
|
||||
KPerBlock / APackedSize * sizeof(ADataType) / (K1 * K2); // KPerBlock/256/packsize
|
||||
|
||||
constexpr index_t M2 = WaveSize / K1; // 8
|
||||
constexpr index_t M1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
|
||||
static_assert(K0 * K1 * K2 == KPerBlock / APackedSize * sizeof(ADataType),
|
||||
"K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
@@ -107,9 +109,9 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
|
||||
const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
|
||||
constexpr index_t K2 = DWORDx4; // 16 bytes
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
|
||||
const index_t K0 = cols / (K1 * K2 * APackedSize);
|
||||
constexpr index_t K2 = std::is_same_v<ADataType, pk_fp6x16_t> ? DWORDx3 : DWORDx4;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8
|
||||
const index_t K0 = cols / (K1 * K2 / sizeof(ADataType) * APackedSize);
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
|
||||
@@ -138,19 +140,23 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view = make_tensor_view<address_space_enum::global>(byte_ptr, desc);
|
||||
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
constexpr index_t test1 = APackedSize / sizeof(ADataType);
|
||||
return make_tile_window(byte_tensor_view,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / APackedSize},
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / test1>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / test1},
|
||||
MakeMX_ABytesDramTileDistribution());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor()
|
||||
{
|
||||
constexpr index_t K2 = AK1 / APackedSize; // 16
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * AK1); // KPerBlock/256
|
||||
static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
|
||||
constexpr index_t K2 = std::is_same_v<ADataType, pk_fp6x16_t> ? DWORDx3 : AK1 / APackedSize;
|
||||
constexpr index_t K2_Pad = 16;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
|
||||
constexpr index_t K0 = std::is_same_v<ADataType, pk_fp6x16_t>
|
||||
? KPerBlock / (K1 * K2 / sizeof(ADataType) * APackedSize)
|
||||
: KPerBlock / (K1 * AK1); // KPerBlock/256
|
||||
static_assert(K0 * K1 * K2 / sizeof(ADataType) * APackedSize == KPerBlock,
|
||||
"K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
|
||||
@@ -169,12 +175,12 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
number<M3>{},
|
||||
number<K1>{},
|
||||
number<K2>{}),
|
||||
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2) + (M1 - 1) * Pad)>{},
|
||||
number<M1*(M2 * M3 * K1 * K2) + (M1 - 1) * Pad>{},
|
||||
number<M2 * M3 * K1 * K2 + Pad>{},
|
||||
number<M3 * K1 * K2>{},
|
||||
number<K1 * K2>{},
|
||||
number<K2>{},
|
||||
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2_Pad) + (M1 - 1) * Pad)>{},
|
||||
number<M1*(M2 * M3 * K1 * K2_Pad) + (M1 - 1) * Pad>{},
|
||||
number<M2 * M3 * K1 * K2_Pad + Pad>{},
|
||||
number<M3 * K1 * K2_Pad>{},
|
||||
number<K1 * K2_Pad>{},
|
||||
number<K2_Pad>{},
|
||||
number<1>{}),
|
||||
number<K2>{},
|
||||
number<1>{});
|
||||
@@ -216,7 +222,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
if constexpr(K_Thread == AK1)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<NWarps>,
|
||||
@@ -225,7 +231,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
else
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarps>,
|
||||
@@ -235,6 +241,19 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{});
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
// K_Lane=4, K_Thread=32
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<NWarps>,
|
||||
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
|
||||
sequence<K_Lane, KPerXdl / (K_Lane * APackedSize), DWORDx3>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<1, 2>>{});
|
||||
else
|
||||
static_assert(false, "unsupported datatype");
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
|
||||
@@ -245,17 +264,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
if constexpr(BK1 == K_Thread)
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWarps, NXdlPack>, // 4 2
|
||||
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 32
|
||||
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 16
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
else
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
@@ -265,6 +284,21 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 0, 1>, sequence<2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 3>>{});
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWarps, NXdlPack>, // 4 2
|
||||
sequence<K0,
|
||||
K1,
|
||||
K_Thread * sizeof(BDataType) / (DWORDx3 * BPackedSize),
|
||||
DWORDx3>>, // 64 1 2 12
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2, 2>,
|
||||
sequence<2, 3>>{});
|
||||
else
|
||||
static_assert(false, "unsupported datatype");
|
||||
}
|
||||
|
||||
template <typename WindowTmp>
|
||||
@@ -280,21 +314,27 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile;
|
||||
auto&& byte_tensor_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(flat_n,
|
||||
flat_k / flat_k_per_block,
|
||||
number<flat_k_per_block / BPackedSize * sizeof(BDataType)>{})),
|
||||
make_tuple(make_pass_through_transform(flat_n),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
|
||||
flat_k / flat_k_per_block,
|
||||
number<flat_k_per_block / BPackedSize * sizeof(BDataType)>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
auto origin_n = origin_tmp[0];
|
||||
auto origin_k = static_cast<int>(origin_tmp[1] * sizeof(BDataType) / BPackedSize);
|
||||
return make_tile_window(
|
||||
byte_tensor_view,
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / BPackedSize},
|
||||
make_tuple(number<flatNPerWarp>{},
|
||||
number<flatKPerWarp * sizeof(BDataType) / BPackedSize>{}),
|
||||
{origin_n, origin_k},
|
||||
MakeMX_BFlatBytesDramTileDistribution());
|
||||
}
|
||||
|
||||
@@ -372,7 +412,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size();
|
||||
if constexpr(!std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
{
|
||||
return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeMX_ALdsBytesBlockDescriptor().get_element_space_size();
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); }
|
||||
|
||||
@@ -1614,7 +1614,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
|
||||
return make_tuple(number<0>{}, int32x8_t{});
|
||||
else if constexpr(std::is_same_v<decltype(dtype), bf8_t>)
|
||||
return make_tuple(number<1>{}, int32x8_t{});
|
||||
// else if e2m3 => make_tuple(number<2>{}, int32x6_t{})
|
||||
else if constexpr(std::is_same_v<decltype(dtype), pk_fp6x16_t>)
|
||||
return make_tuple(number<2>{}, pk_fp6x32_t{});
|
||||
// else if e3m2 => make_tuple(number<3>{}, int32x6_t{})
|
||||
else if constexpr(std::is_same_v<decltype(dtype), pk_fp4_t>)
|
||||
return make_tuple(number<4>{}, int32x4_t{});
|
||||
|
||||
Reference in New Issue
Block a user