mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#5849 (commit d9b89b2)
[CK_TILE] Revert "[CK_TILE] Enable MXFP6 for MX GEMM op (#5095)" (#5849)

This reverts commit 7e55766ddf7e9e20791b0e4e2d7b4026cf16b637.

## Motivation
<!-- Explain the purpose of this PR and the goals it aims to achieve. -->

## Technical Details
<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan
<!-- Explain any relevant testing done to verify this PR. -->

## Test Result
<!-- Briefly summarize test outcomes. -->

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c28d0033d7
commit
3b55a05e71
@@ -22,10 +22,7 @@ struct pk_fp6_t
|
||||
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
|
||||
element_type data_[vector_size]; // packed data
|
||||
using type = pk_fp6_t<packed_size>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp6_t() : data_{element_type{}} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value)
|
||||
CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0)
|
||||
{
|
||||
for(size_t i = 0; i < vector_size; ++i)
|
||||
{
|
||||
@@ -62,14 +59,13 @@ struct pk_fp6_t
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
uint32_t bits = static_cast<uint32_t>(pk.data_[arr_idx]) >> bit_offset;
|
||||
int32_t bits = pk.data_[arr_idx] >> bit_offset;
|
||||
if(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (static_cast<uint32_t>(pk.data_[arr_idx + 1]) & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<int32_t>(bits & 0x3F);
|
||||
return bits & 0x3F;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); }
|
||||
@@ -101,22 +97,6 @@ struct pk_fp6_t
|
||||
}
|
||||
return sign == 1 ? -1 * result : result;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static int32_t float_to_fp6_e2m3(float val)
|
||||
{
|
||||
int32_t best = 0;
|
||||
float best_err = 1e30f;
|
||||
for(int32_t i = 0; i < 64; i++)
|
||||
{
|
||||
float err = std::fabs(val - fp6_e2m3_to_float(i));
|
||||
if(err < best_err)
|
||||
{
|
||||
best = i;
|
||||
best_err = err;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
using pk_fp6x16_t = pk_fp6_t<16>;
|
||||
@@ -125,7 +105,5 @@ template <>
|
||||
struct numeric_traits<pk_fp6x16_t>
|
||||
{
|
||||
static constexpr int PackedSize = 16;
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 3;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -19,14 +19,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// buffer_load_dwordx3 to LDS uses a fixed 16-byte per-thread stride,
|
||||
// padding each 12-byte element to 16 bytes in LDS.
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t lds_padded_sizeof()
|
||||
{
|
||||
return (sizeof(T) == 12) ? 16 : sizeof(T);
|
||||
}
|
||||
|
||||
// T may be scalar or vector
|
||||
// X may be scalar or vector
|
||||
// T and X have same scalar type
|
||||
@@ -848,10 +840,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
{
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
constexpr index_t padded_stride = lds_padded_sizeof<T>();
|
||||
const char* base =
|
||||
reinterpret_cast<const char*>(p_data_) + (i + linear_offset) * padded_stride;
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(base);
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
|
||||
return bit_cast<X>(rtn);
|
||||
}
|
||||
#endif
|
||||
@@ -883,8 +872,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
bool /*is_valid_element*/,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
constexpr index_t padded_stride = lds_padded_sizeof<T>();
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * padded_stride, i_offset * padded_stride);
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
|
||||
@@ -631,24 +631,21 @@ struct tile_scatter_gather
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// buffer load with dwordx3 requires 128-bit alignment
|
||||
constexpr index_t lds_stride = lds_padded_sizeof<LdsDataType>();
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
lds_stride;
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
lds_stride -
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
lds_stride -
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
@@ -783,12 +780,9 @@ struct tile_scatter_gather
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// Calculate SMEM address using base pointer
|
||||
// Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride)
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR LdsDataType*>(
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR char*>(lds_base_ptr) +
|
||||
(lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize *
|
||||
lds_padded_sizeof<LdsDataType>());
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
|
||||
lds_coord.get_offset() / Traits::PackedSize +
|
||||
lds_ys_offset / Traits::PackedSize;
|
||||
|
||||
const auto dram_ys_offset = [&]() {
|
||||
if constexpr(static_move_ys)
|
||||
|
||||
@@ -501,23 +501,21 @@ struct tile_window_with_static_distribution
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
constexpr index_t lds_stride = lds_padded_sizeof<LdsDataType>();
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
lds_stride;
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
lds_stride -
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
lds_stride -
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
// Use VALU so the compiler can optimize redundant/repeated computations
|
||||
@@ -630,12 +628,8 @@ struct tile_window_with_static_distribution
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// Calculate SMEM address using base pointer
|
||||
// Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride)
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR LdsDataType*>(
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR char*>(lds_base_ptr) +
|
||||
(lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize *
|
||||
lds_padded_sizeof<LdsDataType>());
|
||||
lds_base_ptr + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize;
|
||||
|
||||
const auto dram_ys_offset = [&]() {
|
||||
if constexpr(static_move_ys)
|
||||
|
||||
@@ -61,7 +61,6 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
|
||||
tf32_t,
|
||||
pk_fp4_t,
|
||||
pk_fp4_raw_t,
|
||||
pk_fp6x16_t,
|
||||
pk_int4_t,
|
||||
I8,
|
||||
I32,
|
||||
@@ -136,7 +135,6 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
|
||||
tf32_t,
|
||||
pk_fp4_t,
|
||||
pk_fp4_raw_t,
|
||||
pk_fp6x16_t,
|
||||
pk_int4_t,
|
||||
I8,
|
||||
I32,
|
||||
|
||||
@@ -169,41 +169,6 @@ struct FillUniformDistribution<ck_tile::pk_int4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FillUniformDistribution<ck_tile::pk_fp6x16_t>
|
||||
{
|
||||
float a_{-2.f};
|
||||
float b_{2.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
while(first != last)
|
||||
{
|
||||
ck_tile::pk_fp6x16_t pk{};
|
||||
for(ck_tile::index_t i = 0; i < ck_tile::pk_fp6x16_t::packed_size; ++i)
|
||||
{
|
||||
pk.pack(ck_tile::pk_fp6x16_t::float_to_fp6_e2m3(dis(gen)), i);
|
||||
}
|
||||
*first = pk;
|
||||
++first;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
// clang-format off
|
||||
|
||||
@@ -146,10 +146,9 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
|
||||
constexpr index_t a_lds_block_space_size = lds_padded_sizeof<OverrideADataType>() *
|
||||
a_lds_block_desc.get_element_space_size() /
|
||||
APackedSize;
|
||||
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
|
||||
constexpr index_t a_lds_block_space_size =
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_least_multiple(a_lds_block_space_size, 16);
|
||||
|
||||
|
||||
@@ -837,10 +837,9 @@ struct UniversalGemmBasePolicy
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A)) *
|
||||
numeric_traits<A>::PackedSize;
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
|
||||
return ck_tile::min(KPack, VecElems);
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -850,10 +849,9 @@ struct UniversalGemmBasePolicy
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B)) *
|
||||
numeric_traits<B>::PackedSize;
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
|
||||
|
||||
return ck_tile::min(KPack, VecElems);
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -862,10 +860,8 @@ struct UniversalGemmBasePolicy
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a =
|
||||
integer_least_multiple(a_lds_block_desc.get_element_space_size() *
|
||||
lds_padded_sizeof<ADataType>() / APackedSize,
|
||||
16);
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
@@ -878,10 +874,8 @@ struct UniversalGemmBasePolicy
|
||||
typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b =
|
||||
integer_least_multiple(b_lds_block_desc.get_element_space_size() *
|
||||
lds_padded_sizeof<BDataType>() / BPackedSize,
|
||||
16);
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
|
||||
@@ -442,12 +442,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
MWarp / BlockSize,
|
||||
"BLdsTile size is wrong!");
|
||||
static_assert(Policy::template GetSmemSizeA<Problem>() ==
|
||||
MPerBlock *
|
||||
(KPerBlock * lds_padded_sizeof<ADataType>() / APackedSize),
|
||||
MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize),
|
||||
"SmemSizeA size is wrong!");
|
||||
static_assert(Policy::template GetSmemSizeB<Problem>() ==
|
||||
(KPerBlock * lds_padded_sizeof<BDataType>() / BPackedSize) *
|
||||
NPerBlock,
|
||||
(KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock,
|
||||
"SmemSizeB size is wrong!");
|
||||
|
||||
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
|
||||
|
||||
Reference in New Issue
Block a user