mirror of https://github.com/ROCm/composable_kernel.git
commit: update

@@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop>
         index_t v_offset,
         index_t /*s_offset*/,
         index_t i_offset /*max 0xFFF*/,
         index_t flag = 0,
         bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 16);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
         static_assert(sizeof(mbuf_t) == sizeof(T));
         if constexpr(pre_nop)
             asm volatile("s_nop 4\n"

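For context on the pattern in these `buffer_load_if` hunks: the load is made conditional per lane by narrowing the EXEC mask with `v_cmpx_le_u32 exec, 1, flag` around the buffer load and restoring the saved mask afterwards, so lanes whose `flag` is zero keep their previous destination contents. A minimal host-side sketch of those semantics (the function and names are illustrative, not part of ck_tile):

#include <cstdint>
#include <vector>

// Illustrative emulation of the exec-masked conditional load: each "lane"
// loads from src only when its flag is nonzero, otherwise its destination
// value is left untouched. Assumes all three vectors have equal size.
template <typename T>
void emulate_buffer_load_if(std::vector<T>& dst,
                            const std::vector<T>& src,
                            const std::vector<uint32_t>& lane_flag)
{
    for(size_t lane = 0; lane < dst.size(); ++lane)
    {
        // v_cmpx_le_u32 exec, 1, flag: a lane stays active iff 1 <= flag.
        if(lane_flag[lane] >= 1)
            dst[lane] = src[lane]; // buffer_load_dwordx4 analogue
        // inactive lanes keep their old dst value; exec is restored after.
    }
}
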
@@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop>
         index_t v_offset,
         index_t /*s_offset*/,
         index_t i_offset /*max 0xFFF*/,
         index_t flag = 0,
         bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 8);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
         if constexpr(pre_nop)
             asm volatile("s_nop 4\n"
                          "v_cmpx_le_u32 exec, 1, %4\n"

@@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop>
         index_t v_offset,
         index_t /*s_offset*/,
         index_t i_offset /*max 0xFFF*/,
         index_t flag = 0,
         bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
         if constexpr(pre_nop)
             asm volatile("s_nop 4\n"
                          "v_cmpx_le_u32 exec, 1, %4\n"

@@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop>
         index_t v_offset,
         index_t /*s_offset*/,
         index_t i_offset /*max 0xFFF*/,
         index_t flag = 0,
         bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
         if constexpr(pre_nop)
             asm volatile("s_nop 4\n"
                          "v_cmpx_le_u32 exec, 1, %4\n"

@@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop>
         index_t v_offset,
         index_t /*s_offset*/,
         index_t i_offset /*max 0xFFF*/,
         index_t flag = 0,
         bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
         if constexpr(pre_nop)
             asm volatile("s_nop 4\n"
                          "v_cmpx_le_u32 exec, 1, %4\n"

@@ -624,7 +624,7 @@ struct buffer_store_if<16>
     {
         static_assert(sizeof(T) == 16);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = fp32x4_t;
         asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                      "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
                      "s_mov_b64 exec %5"

@@ -681,7 +681,7 @@ struct buffer_store_if<4>
     {
         static_assert(sizeof(T) == 4);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = float;
         asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                      "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
                      "s_mov_b64 exec %5"

@@ -709,7 +709,7 @@ struct buffer_store_if<2>
     {
         static_assert(sizeof(T) == 2);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = short;
         asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                      "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
                      "s_mov_b64 exec %5"

@@ -737,7 +737,7 @@ struct buffer_store_if<1>
     {
         static_assert(sizeof(T) == 4);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = float;
         asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                      "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
                      "s_mov_b64 exec %5"

@@ -1448,7 +1448,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
          (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        ((std::is_same<T, int8_t>::value || std::is_same<T, e8m0_t>::value) &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (std::is_same<T, pk_int4_t>::value &&
          (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
         (std::is_same<T, pk_fp4_t>::value &&

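The `static_assert` this hunk extends whitelists (element type, vector size) pairs, and the change adds `e8m0_t` alongside `int8_t`. A compact restatement of that whitelist as a standalone predicate; the helper and the stub types below are assumptions for illustration, not ck_tile code:

#include <cstdint>
#include <type_traits>

// Hypothetical helper, not part of ck_tile: true if T is any of Ts.
template <typename T, typename... Ts>
inline constexpr bool is_any_of_v = (std::is_same_v<T, Ts> || ...);

// Stand-ins for the library's packed/scale types, for illustration only.
struct pk_int4_t {}; struct pk_fp4_t {}; struct e8m0_t {};

// Sketch of the (type, N) whitelist the static_assert encodes: byte-sized
// scalars (now including e8m0_t) take N in {1,2,4,8,16}; packed 4-bit
// types also take N == 32 because two elements share a byte.
template <typename T, int N>
inline constexpr bool supports_buffer_load =
    (is_any_of_v<T, int8_t, e8m0_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
    (is_any_of_v<T, pk_int4_t, pk_fp4_t> &&
     (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32));

static_assert(supports_buffer_load<e8m0_t, 16>);
static_assert(!supports_buffer_load<int8_t, 32>);
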
@@ -11,6 +11,7 @@
 #include "ck_tile/core/numeric/half.hpp"
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/pk_int4.hpp"
+#include "ck_tile/core/numeric/e8m0.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"

 namespace ck_tile {

@@ -90,7 +91,8 @@ struct vector_traits
     using scalar_type =
         std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
                            int8_t,
-                           std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t>,
+                           std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
+                                                  std::is_same_v<remove_cvref_t<T>, e8m0_t>,
                                               uint8_t,
                                               remove_cvref_t<T>>>;

@@ -101,10 +103,12 @@ struct vector_traits
 template <typename T, index_t N>
 struct vector_traits<T __attribute__((ext_vector_type(N)))>
 {
-    using scalar_type =
-        std::conditional_t<std::is_same_v<T, pk_int4_t>,
-                           int8_t,
-                           std::conditional_t<std::is_same_v<T, pk_fp4_t>, uint8_t, T>>;
+    using scalar_type = std::conditional_t<
+        std::is_same_v<T, pk_int4_t>,
+        int8_t,
+        std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
+                           uint8_t,
+                           T>>;

     static constexpr index_t vector_size = N;
 };

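Both `vector_traits` hunks give `e8m0_t` the same treatment as `pk_fp4_t`: its storage-level `scalar_type` becomes `uint8_t`, while `pk_int4_t` maps to `int8_t`. A self-contained sketch of the same `std::conditional_t` chain, with stub types standing in for the real ck_tile ones:

#include <cstdint>
#include <type_traits>

// Stand-ins for the ck_tile packed/scale types, for illustration only.
struct pk_int4_t {}; struct pk_fp4_t {}; struct e8m0_t {};

template <typename T>
struct vector_traits_sketch
{
    // pk_int4_t is stored as int8_t; pk_fp4_t and (now) e8m0_t as uint8_t;
    // everything else keeps its own type.
    using scalar_type =
        std::conditional_t<std::is_same_v<std::remove_cv_t<T>, pk_int4_t>,
                           int8_t,
                           std::conditional_t<std::is_same_v<std::remove_cv_t<T>, pk_fp4_t> ||
                                                  std::is_same_v<std::remove_cv_t<T>, e8m0_t>,
                                              uint8_t,
                                              std::remove_cv_t<T>>>;
};

static_assert(std::is_same_v<vector_traits_sketch<pk_int4_t>::scalar_type, int8_t>);
static_assert(std::is_same_v<vector_traits_sketch<e8m0_t>::scalar_type, uint8_t>); // new mapping
static_assert(std::is_same_v<vector_traits_sketch<float>::scalar_type, float>);
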
@@ -256,7 +256,7 @@ struct buffer_view<address_space_enum::global,
     T* p_data_ = nullptr;
     BufferSizeType buffer_size_;
     int32x4_t cached_buf_res_;
-    remove_cvref_t<T> invalid_element_value_ = T{0};
+    remove_cvref_t<T> invalid_element_value_ = T{0.f};

     static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;

@@ -269,7 +269,7 @@ struct buffer_view<address_space_enum::global,
         : p_data_{p_data},
           buffer_size_{buffer_size / PackedSize},
           cached_buf_res_{0},
-          invalid_element_value_{0}
+          invalid_element_value_{0.f}
     {
     }

@@ -657,7 +657,7 @@ struct buffer_view<address_space_enum::global,
 #elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
     bool constexpr use_amd_buffer_addressing =
         std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
-#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+#elif (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
     bool constexpr use_amd_buffer_addressing =
         std::is_same_v<remove_cvref_t<scalar_t>, float> ||
         (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);

@@ -157,12 +157,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
         (kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
     index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);

-    const auto scale_b_flat_view =
-        make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
-                                                           make_tuple(FlatScaleN, FlatScaleK),
-                                                           make_tuple(FlatScaleK, 1),
-                                                           number<8>{},
-                                                           number<1>{});
+    const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
+        reinterpret_cast<const e8m0_t*>(scale_n.ptr),
+        make_tuple(FlatScaleN, FlatScaleK),
+        make_tuple(FlatScaleK, 1),
+        number<8>{},
+        number<1>{});

     return make_tuple(
         a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);

@@ -297,7 +297,11 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
                          number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
         {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});

-    return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
+    return make_tuple(a_block_window,
+                      b_flat_block_window,
+                      ds_block_window,
+                      e_block_window,
+                      scale_block_window);
 }

 template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>

@@ -326,7 +330,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
         const auto& a_block_window = gemm_tile_windows.at(I0);
         const auto& b_flat_block_window = gemm_tile_windows.at(I1);
         const auto& d_block_window = gemm_tile_windows.at(I2);
-        const auto& scale_block_window = gemm_tile_windows.at(I3);
+        const auto& scale_block_window = gemm_tile_windows.at(I4);
         const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
                                                                         b_flat_block_window,
                                                                         scale_block_window,

@@ -588,8 +588,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
         b_flat_dram_block_window_tmp.get_window_origin(),
         b_flat_distribution);

-    constexpr int ScaleB_BlockK =
-        flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
+    constexpr int ScaleB_BlockK = 16 * 2 * 4;
+    // flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;

     auto scale_b_flat_dram_window = make_tile_window(
         scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views

@@ -640,8 +640,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1

     scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;

-    move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
-                     {dequant_n_iter, kIter * KFlatPerBlockPerIter});
+    move_tile_window(
+        scale_b_flat_dram_windows(dequant_n_iter)(kIter),
+        {dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});

     scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
         load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));

@@ -690,6 +691,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1

     auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
+#if defined(__gfx942__)
+        lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
+                                                  lane_scale);
+        return lane_scale;
+#endif
         auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);

@@ -705,12 +708,13 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
         v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
         if constexpr(xdl_k_idx % 2 == 0)
         {
-            return v2scale[0];
+            lane_scale = v2scale[0];
         }
         else
         {
-            return v2scale[1];
+            lane_scale = v2scale[1];
         }
+        return lane_scale;
     };

     auto deq_fn = [&](const auto& quant_weight_tensor,

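On the gfx942 path added above, `__builtin_amdgcn_ds_bpermute` is a wave-wide gather: each lane receives the value held by the lane selected by `addr / 4` (the address is in bytes). With the address `((lane % 16) + 16 * xdl_k_idx) * 4`, every lane fetches its scale from the 16-lane group that owns the scales for the selected K sub-tile. A host-side model of the gather, assuming a 64-lane wave and in-wave wrapping:

#include <array>
#include <cstdint>

constexpr int WaveSize = 64;

// CPU model of ds_bpermute: lane i receives wave[(addr_i / 4) % WaveSize].
std::array<uint32_t, WaveSize> ds_bpermute_model(const std::array<uint32_t, WaveSize>& wave,
                                                 const std::array<uint32_t, WaveSize>& addr)
{
    std::array<uint32_t, WaveSize> out{};
    for(int lane = 0; lane < WaveSize; ++lane)
        out[lane] = wave[(addr[lane] / 4) % WaveSize];
    return out;
}

// The kernel's address expression makes each group of 16 lanes read from the
// 16-lane block indexed by xdl_k_idx, broadcasting that block's scales.
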
@@ -721,15 +725,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1

         auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};

-        uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
-        packed_scale = perm_scale(packed_scale, b_idx_k);
+        auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];

-        e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
+        auto use_scale = scale;
+        use_scale.data = perm_scale(scale.data, b_idx_k);

-        if constexpr(xdl_nIter % 2 != 0)
-        {
-            scale_ptr++;
-        }
+        if constexpr(xdl_nIter == 0)
+            if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
+            {
+                printf("laneid = %2u xdl-k=%2d use-scale = "
+                       "%.2f\n",
+                       threadIdx.x,
+                       int(xdl_kIter),
+                       float(use_scale));
+            }

         constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
         static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {

@@ -737,7 +746,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
                 number<i>{},
                 pk_fp4_to_fp16x2(
                     quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
-                    *scale_ptr));
+                    static_cast<float>(use_scale)));
         });
     };

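The dequant call now takes `static_cast<float>(use_scale)` rather than dereferencing an `e8m0_t*`. E8M0 (as in the OCP MX formats) is an 8-bit exponent-only encoding with bias 127 and no sign or mantissa bits, so converting it to float amounts to building a power of two. A hedged sketch of that decode; the real `e8m0_t` conversion operator may differ in edge-case handling:

#include <cmath>
#include <cstdint>

// Decode an E8M0 biased exponent to float: value = 2^(e - 127).
// 0xFF is reserved for NaN in the OCP MX specification.
inline float e8m0_to_float(uint8_t e)
{
    if(e == 0xFF)
        return NAN;
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}
// Example: e8m0_to_float(127) == 1.0f, e8m0_to_float(128) == 2.0f.
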
@@ -748,6 +757,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
         // prefetch B(2i+1)
         static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                if constexpr(nIter % NRepeatPerScaleLoad == 0)
+                {
+                    auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
+
+                    scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
+
+                    move_tile_window(
+                        scale_b_flat_dram_windows(dequant_n_iter)(kIter),
+                        {dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                    scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
+                        load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
+                }
+
                 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

                 move_tile_window(b_flat_dram_windows(nIter)(kIter),

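This hunk and the two matching ones below wire the B-scale loads into the pipeline's existing ping-pong prefetch: while compute consumes the `ping` scale tile, the next tile is loaded into `pong`, and the roles swap each main-loop iteration. A schematic of that double-buffering pattern in plain C++ (names illustrative, not pipeline code):

// Schematic ping-pong prefetch: overlap "load next tile" with "compute
// current tile" by alternating two buffers, as the pipeline does for the
// B tiles and (after this change) the B scale tiles.
template <typename Tile, typename LoadFn, typename ComputeFn>
void pingpong_loop(int num_tiles, LoadFn load, ComputeFn compute)
{
    Tile ping = load(0); // prologue: prefetch tile 0
    Tile pong{};
    for(int i = 0; i < num_tiles; ++i)
    {
        Tile& cur  = (i % 2 == 0) ? ping : pong;
        Tile& next = (i % 2 == 0) ? pong : ping;
        if(i + 1 < num_tiles)
            next = load(i + 1); // prefetch overlaps with the compute below
        compute(cur, i);
    }
}
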
@@ -828,6 +851,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
         // prefetch B(2i+2)
         static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                if constexpr(nIter % NRepeatPerScaleLoad == 0)
+                {
+                    auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
+
+                    scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
+
+                    move_tile_window(
+                        scale_b_flat_dram_windows(dequant_n_iter)(kIter),
+                        {dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                    scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
+                        load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
+                }
+
                 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

                 move_tile_window(b_flat_dram_windows(nIter)(kIter),

@@ -910,6 +947,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
         // prefetch B(loopK)
         static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                if constexpr(nIter % NRepeatPerScaleLoad == 0)
+                {
+                    auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
+
+                    scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
+
+                    move_tile_window(
+                        scale_b_flat_dram_windows(dequant_n_iter)(kIter),
+                        {dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                    scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
+                        load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
+                }
+
                 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

                 move_tile_window(b_flat_dram_windows(nIter)(kIter),