This commit is contained in:
Feng Shijie
2025-08-11 11:24:34 +00:00
parent 200a11afc8
commit edb58d0680
7 changed files with 112 additions and 50 deletions

View File

@@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop>
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
@@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop>
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
@@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop>
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
@@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop>
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
@@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop>
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
@@ -624,7 +624,7 @@ struct buffer_store_if<16>
{
static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
@@ -681,7 +681,7 @@ struct buffer_store_if<4>
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
@@ -709,7 +709,7 @@ struct buffer_store_if<2>
{
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short;
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
@@ -737,7 +737,7 @@ struct buffer_store_if<1>
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
@@ -1448,7 +1448,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
((std::is_same<T, int8_t>::value || std::is_same<T, e8m0_t>::value) &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -90,7 +91,8 @@ struct vector_traits
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t>,
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
std::is_same_v<remove_cvref_t<T>, e8m0_t>,
uint8_t,
remove_cvref_t<T>>>;
@@ -101,10 +103,12 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
using scalar_type =
std::conditional_t<std::is_same_v<T, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<T, pk_fp4_t>, uint8_t, T>>;
using scalar_type = std::conditional_t<
std::is_same_v<T, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
uint8_t,
T>>;
static constexpr index_t vector_size = N;
};

View File

@@ -256,7 +256,7 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0};
remove_cvref_t<T> invalid_element_value_ = T{0.f};
static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
@@ -269,7 +269,7 @@ struct buffer_view<address_space_enum::global,
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
invalid_element_value_{0}
invalid_element_value_{0.f}
{
}
@@ -657,7 +657,7 @@ struct buffer_view<address_space_enum::global,
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
#elif (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);

View File

@@ -157,12 +157,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view =
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
@@ -297,7 +297,11 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_block_window);
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -326,7 +330,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I3);
const auto& scale_block_window = gemm_tile_windows.at(I4);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
b_flat_block_window,
scale_block_window,

View File

@@ -588,8 +588,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
constexpr int ScaleB_BlockK =
flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
constexpr int ScaleB_BlockK = 16 * 2 * 4;
// flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
auto scale_b_flat_dram_window = make_tile_window(
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
@@ -640,8 +640,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter, kIter * KFlatPerBlockPerIter});
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
@@ -690,6 +691,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
#if defined(__gfx942__)
lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
lane_scale);
return lane_scale;
#endif
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
@@ -705,12 +708,13 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
if constexpr(xdl_k_idx % 2 == 0)
{
return v2scale[0];
lane_scale = v2scale[0];
}
else
{
return v2scale[1];
lane_scale = v2scale[1];
}
return lane_scale;
};
auto deq_fn = [&](const auto& quant_weight_tensor,
@@ -721,15 +725,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
packed_scale = perm_scale(packed_scale, b_idx_k);
auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];
e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
auto use_scale = scale;
use_scale.data = perm_scale(scale.data, b_idx_k);
if constexpr(xdl_nIter % 2 != 0)
{
scale_ptr++;
}
if constexpr(xdl_nIter == 0)
if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
{
printf("laneid = %2u xdl-k=%2d use-scale = "
"%.2f\n",
threadIdx.x,
int(xdl_kIter),
float(use_scale));
}
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
@@ -737,7 +746,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
number<i>{},
pk_fp4_to_fp16x2(
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
*scale_ptr));
static_cast<float>(use_scale)));
});
};
@@ -748,6 +757,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B(2i+1)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -828,6 +851,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B(2i+2)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -910,6 +947,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B(loopK)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),