mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add FP4 MX MFMA tests (#2151)
* Add conversion tests * Fix ctor * Fix nan logic * Fix conversion logic * Permute packed f4_t values * Fix conversion to float, repack vector elements * Fix device tests * Permute elements in a vector * Add a repro test * Add a conversion for a repro test * Update test vectors * Update conversion * Fix the test * Update test vector generator * Fix vector sr conversion * Permute conversion args * Update conversion * Test * Fix packing * Simplify conversion function * Pack conversion in a loop * Pack conversion in a loop * Pack another conversion in a loop * Pack one more conversion in a loop * Pack the last conversion in a loop * Clean up * Add ops * Add tests * Add missing utils * Update reference mx gemm * Add f4x2 init mode * Update host tensor utils * Update chunk size for f4x2 * Add non scaled ops * Add a type utility * Update non scaled reference kernel * Add non scaled tests * Debug mfma arguments * Add more debug info * Update chunk size * Update data layout * Add more debugging * Fix B stride * Fix reference gemm * Fix build * One more reference fix * Add more debug info * Disable some tests * Enable tests * Add fp4 dimensions * Update reference kernels * Temp edits * Remove leftovers * Fix conflicts * Clean up * More clean up * Revert "More clean up" This reverts commitd8d35a0846. * Add layouts to tests --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> [ROCm/composable_kernel commit:8a0d659f92]
This commit is contained in:
@@ -51,7 +51,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
|
||||
{
|
||||
os << ck::type_convert<float>(v);
|
||||
}
|
||||
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
|
||||
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t> ||
|
||||
std::is_same_v<RangeType, ck::f4x2_pk_t>)
|
||||
{
|
||||
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
|
||||
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
|
||||
@@ -359,7 +360,8 @@ struct Tensor
|
||||
|
||||
std::size_t GetElementSpaceSize() const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return (mDesc.GetElementSpaceSize() + 1) / 2;
|
||||
}
|
||||
@@ -514,7 +516,8 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
|
||||
}
|
||||
@@ -527,7 +530,8 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
@@ -540,7 +544,8 @@ struct Tensor
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
|
||||
}
|
||||
@@ -552,7 +557,8 @@ struct Tensor
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
|
||||
}
|
||||
@@ -564,7 +570,8 @@ struct Tensor
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
{
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -81,6 +81,18 @@ struct GeneratorTensor_1<ck::f4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::f4x2_pk_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{value, value})};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
@@ -209,6 +221,21 @@ struct GeneratorTensor_2<ck::f4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f4x2_pk_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
float tmp0 = (std::rand() % (max_value - min_value)) + min_value;
|
||||
float tmp1 = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{tmp0, tmp1})};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -296,6 +323,25 @@ struct GeneratorTensor_3<ck::f4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f4x2_pk_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
float tmp0 = float(std::rand()) / float(RAND_MAX);
|
||||
float tmp1 = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp0 = min_value + tmp0 * (max_value - min_value);
|
||||
float fp32_tmp1 = min_value + tmp1 * (max_value - min_value);
|
||||
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{fp32_tmp0, fp32_tmp1})};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
|
||||
@@ -508,6 +508,34 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -589,6 +617,40 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f4x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -686,6 +748,39 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f4x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
@@ -748,6 +843,33 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
0,
|
||||
0, // OPSEL
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -470,6 +470,13 @@ struct scalar_type<e8m0_bexp_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<f4x2_pk_t>
|
||||
{
|
||||
using type = f4x2_pk_t::type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bool>
|
||||
{
|
||||
|
||||
@@ -79,6 +79,16 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
i4 = i4 - 8;
|
||||
v_a = type_convert<ComputeTypeA>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
|
||||
else
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
@@ -95,6 +105,16 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
i4 = i4 - 8;
|
||||
v_b = type_convert<ComputeTypeB>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
|
||||
else
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
|
||||
@@ -89,9 +89,28 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
{
|
||||
for(size_t k = 0; k < K; k++)
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
|
||||
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
else
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
|
||||
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,9 +118,28 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
{
|
||||
for(size_t k = 0; k < K; k++)
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
|
||||
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
else
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
|
||||
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "mx_mfma_op.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::f4_t;
|
||||
using ck::f4x2_pk_t;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
@@ -16,7 +18,7 @@ using ck::type_convert;
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mfma_test(ck::index_t init)
|
||||
bool run_mfma_km_kn_nm_test(ck::index_t init)
|
||||
{
|
||||
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -30,7 +32,8 @@ bool run_mfma_test(ck::index_t init)
|
||||
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
|
||||
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
|
||||
|
||||
const auto mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
|
||||
const auto mfma_kernel = ck::
|
||||
matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
@@ -52,15 +55,72 @@ bool run_mfma_test(ck::index_t init)
|
||||
|
||||
TEST(MFMA, FP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MFMA, FP8MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Run the test for the given MFMA instruction
|
||||
*
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mfma_mk_kn_mn_test(ck::index_t init)
|
||||
{
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AccType = float; // only MFMA_F32 instructions supported
|
||||
using CPUAccType = AccType;
|
||||
|
||||
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
|
||||
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
|
||||
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
|
||||
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
|
||||
|
||||
const auto mfma_kernel = ck::
|
||||
matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
|
||||
AType,
|
||||
BType,
|
||||
CType,
|
||||
AccType,
|
||||
CPUAccType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K>{}(mfma_kernel, init);
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
TEST(MFMA, FP4MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
|
||||
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MFMA, FP4MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
@@ -70,7 +130,7 @@ TEST(MFMA, FP8MFMA32x32x64)
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mxmfma_test(ck::index_t init)
|
||||
bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
|
||||
{
|
||||
static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
|
||||
mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
|
||||
@@ -88,8 +148,18 @@ bool run_mxmfma_test(ck::index_t init)
|
||||
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
|
||||
constexpr auto BLOCK_X = 32; // scaling vector size
|
||||
|
||||
const auto mx_mfma_kernel =
|
||||
ck::matmul<AType, BType, ScaleType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_X>;
|
||||
const auto mx_mfma_kernel = ck::matmul<AType,
|
||||
BType,
|
||||
ScaleType,
|
||||
CType,
|
||||
AccType,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
BLOCK_X,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
@@ -111,14 +181,34 @@ bool run_mxmfma_test(ck::index_t init)
|
||||
|
||||
TEST(MXMFMA, MXFP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
auto AB_init = 5;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP8MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
auto AB_init = 5;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 4;
|
||||
auto pass =
|
||||
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
|
||||
AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
@@ -111,7 +112,7 @@ template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
|
||||
__device__ AFragT load_A_col_major(AType const* input_ptr)
|
||||
{
|
||||
// clang-format off
|
||||
// Register Mapping for 16x128: || Register Mapping for 32x64:
|
||||
// Register Mapping for 16x128 for FP8: || Register Mapping for 32x64 for FP8:
|
||||
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
|
||||
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
@@ -176,13 +177,19 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
|
||||
auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M);
|
||||
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M);
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT = vector_type<ARawT, vectorSize(AFragT{})>::type;
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT =
|
||||
vector_type<ARawT,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
|
||||
AScalarFragT fragA{};
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
|
||||
#pragma unroll
|
||||
for(int chunk = 0; chunk < 2; chunk++)
|
||||
for(int chunk = 0; chunk < num_chunks; chunk++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(uint32_t i = 0; i < chunk_size; i++)
|
||||
@@ -241,6 +248,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
|
||||
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
|
||||
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
||||
|
||||
// Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4:
|
||||
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
|
||||
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
||||
// Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] |
|
||||
// Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] |
|
||||
// Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] |
|
||||
// Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] |
|
||||
// Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] |
|
||||
// Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] |
|
||||
// Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] |
|
||||
// Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] |
|
||||
// Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] |
|
||||
// Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] |
|
||||
// Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] |
|
||||
// Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] |
|
||||
// Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] |
|
||||
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
|
||||
// clang-format on
|
||||
|
||||
static constexpr int32_t WAVE_SIZE = 64;
|
||||
@@ -265,23 +294,34 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
||||
|
||||
// BLOCK_K is a stride in A matrix
|
||||
auto startOffset = row_major(startCoord2D, BLOCK_K);
|
||||
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
|
||||
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K);
|
||||
auto startOffset = row_major(
|
||||
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
|
||||
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
auto kMajorOffset =
|
||||
row_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
|
||||
union
|
||||
{
|
||||
AFragT frag;
|
||||
AScalarFragT chunks[2];
|
||||
AScalarFragT chunks[num_chunks];
|
||||
} fragA{};
|
||||
|
||||
auto* fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset);
|
||||
fragA.chunks[0] = *fragPtr;
|
||||
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
|
||||
fragA.chunks[1] = *fragPtr;
|
||||
const AScalarFragT* fragPtr;
|
||||
|
||||
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
|
||||
{
|
||||
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset +
|
||||
chunk_idx * kMajorOffset);
|
||||
fragA.chunks[chunk_idx] = *fragPtr;
|
||||
}
|
||||
|
||||
return fragA.frag;
|
||||
}
|
||||
@@ -339,15 +379,35 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
|
||||
// Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) |
|
||||
// Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) |
|
||||
// Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) |
|
||||
|
||||
// Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4:
|
||||
// Size | BLOCK_M | | BLOCK_M | | BLOCK_M | | BLOCK_M | | || Size | BLOCK_M | | BLOCK_M | | |
|
||||
// M | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || M | 0 ... 31 | | 0 ... 31 | | Vector |
|
||||
// Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element|
|
||||
// Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------|
|
||||
// Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] |
|
||||
// Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] |
|
||||
// Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] |
|
||||
// Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] |
|
||||
// Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] |
|
||||
// Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] |
|
||||
// Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] |
|
||||
// Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] |
|
||||
// Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] |
|
||||
// Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] |
|
||||
// Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] |
|
||||
// Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] |
|
||||
// Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] |
|
||||
// Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] |
|
||||
// clang-format on
|
||||
static constexpr uint32_t VW = vectorSize(AFragT{});
|
||||
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 1 element
|
||||
// We need to know where they start
|
||||
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
|
||||
(threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col
|
||||
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
|
||||
(threadIdx.x / BLOCK_M)); // Col
|
||||
|
||||
// Flatten to 1D row_major offsets.
|
||||
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
||||
@@ -369,7 +429,7 @@ template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
|
||||
__device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
{
|
||||
// clang-format off
|
||||
// Register Mapping for 128x16: || Register Mapping for 64x32:
|
||||
// Register Mapping for 128x16 for FP8: || Register Mapping for 64x32 for FP8:
|
||||
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
|
||||
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
@@ -406,6 +466,28 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
|
||||
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
|
||||
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
||||
|
||||
// Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4:
|
||||
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
|
||||
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
||||
// Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] |
|
||||
// Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] |
|
||||
// Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] |
|
||||
// Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] |
|
||||
// Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] |
|
||||
// Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] |
|
||||
// Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] |
|
||||
// Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] |
|
||||
// Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] |
|
||||
// Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] |
|
||||
// Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] |
|
||||
// Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] |
|
||||
// Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] |
|
||||
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
|
||||
// clang-format on
|
||||
|
||||
static constexpr int32_t WAVE_SIZE = 64;
|
||||
@@ -430,23 +512,34 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
|
||||
|
||||
// BLOCK_K is a stride in B matrix
|
||||
auto startOffset = col_major(startCoord2D, BLOCK_K);
|
||||
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K);
|
||||
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K);
|
||||
auto startOffset = col_major(
|
||||
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
|
||||
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
auto kMajorOffset =
|
||||
col_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
|
||||
using BRawT = typename scalar_type<BFragT>::type;
|
||||
using BScalarFragT = vector_type<BRawT, chunk_size>::type;
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
|
||||
union
|
||||
{
|
||||
BFragT frag;
|
||||
BScalarFragT chunks[2];
|
||||
BScalarFragT chunks[num_chunks];
|
||||
} fragB{};
|
||||
|
||||
auto* fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset);
|
||||
fragB.chunks[0] = *fragPtr;
|
||||
fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
|
||||
fragB.chunks[1] = *fragPtr;
|
||||
const BScalarFragT* fragPtr;
|
||||
|
||||
for(index_t chunk = 0; chunk < num_chunks; chunk++)
|
||||
{
|
||||
fragPtr =
|
||||
reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + chunk * kMajorOffset);
|
||||
fragB.chunks[chunk] = *fragPtr;
|
||||
}
|
||||
|
||||
return fragB.frag;
|
||||
}
|
||||
@@ -506,15 +599,56 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
|
||||
// Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) |
|
||||
// Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) |
|
||||
|
||||
// Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4:
|
||||
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
|
||||
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
||||
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
||||
// Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] |
|
||||
// Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] |
|
||||
// Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] |
|
||||
// Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] |
|
||||
// Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] |
|
||||
// Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] |
|
||||
// Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] |
|
||||
// Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] |
|
||||
// Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] |
|
||||
// Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] |
|
||||
// Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] |
|
||||
// Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] |
|
||||
// Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] |
|
||||
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
|
||||
|
||||
// Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4:
|
||||
// Size | BLOCK_N | | BLOCK_N | | BLOCK_N | | BLOCK_N | | || Size | BLOCK_N | | BLOCK_N | | |
|
||||
// N | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || N | 0 ... 31 | | 0 ... 31 | | Vector |
|
||||
// Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element|
|
||||
// Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------|
|
||||
// Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] |
|
||||
// Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] |
|
||||
// Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] |
|
||||
// Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] |
|
||||
// Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] |
|
||||
// Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] |
|
||||
// Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] |
|
||||
// Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] |
|
||||
// Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] |
|
||||
// Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] |
|
||||
// Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] |
|
||||
// Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] |
|
||||
// Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] |
|
||||
// Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] |
|
||||
// Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] |
|
||||
// Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] |
|
||||
// clang-format on
|
||||
static constexpr uint32_t VW = vectorSize(BFragT{});
|
||||
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 1 element
|
||||
// We need to know where to start
|
||||
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
|
||||
threadIdx.x % BLOCK_N); // Col
|
||||
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N), // Row
|
||||
threadIdx.x % BLOCK_N); // Col
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
@@ -766,15 +900,24 @@ template <typename AType,
|
||||
typename AccType,
|
||||
int32_t BLOCK_M,
|
||||
int32_t BLOCK_N,
|
||||
int32_t BLOCK_K>
|
||||
int32_t BLOCK_K,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
__global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
{
|
||||
constexpr int WAVE_SIZE = 64;
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
|
||||
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AFragT =
|
||||
vector_type<AType,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using BFragT =
|
||||
vector_type<BType,
|
||||
BLOCK_K * BLOCK_N / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
@@ -786,10 +929,23 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
auto fragAcc = AccumFragT{0};
|
||||
|
||||
// Load the inputs.
|
||||
// A = col major, BLOCK_M x BLOCK_K
|
||||
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
// B = col major, BLOCK_K x BLOCK_N
|
||||
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
}
|
||||
else
|
||||
{
|
||||
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
printf("This layout is not implemented\n");
|
||||
}
|
||||
else
|
||||
{
|
||||
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
|
||||
}
|
||||
|
||||
// Matrix multiply-accumulate using MFMA units
|
||||
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
||||
@@ -801,8 +957,14 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
|
||||
}
|
||||
|
||||
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
||||
storeC(c, fragC);
|
||||
if constexpr(is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AType,
|
||||
@@ -813,7 +975,10 @@ template <typename AType,
|
||||
int32_t BLOCK_M,
|
||||
int32_t BLOCK_N,
|
||||
int32_t BLOCK_K,
|
||||
int32_t BLOCK_X>
|
||||
int32_t BLOCK_X,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
__global__ void
|
||||
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
|
||||
{
|
||||
@@ -821,8 +986,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
|
||||
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AFragT =
|
||||
vector_type<AType,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using BFragT =
|
||||
vector_type<BType,
|
||||
BLOCK_K * BLOCK_N / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
@@ -838,13 +1009,27 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
auto fragXb = BScaleFragT{};
|
||||
|
||||
// Load the inputs.
|
||||
// A = col major, BLOCK_M x BLOCK_K
|
||||
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
|
||||
a, xa, fragXa);
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
fragA =
|
||||
load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
|
||||
a, xa, fragXa);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("This layout is not implemented\n");
|
||||
}
|
||||
|
||||
// B = col major, BLOCK_K x BLOCK_N
|
||||
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
|
||||
b, xb, fragXb);
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
printf("This layout is not implemented\n");
|
||||
}
|
||||
else
|
||||
{
|
||||
fragB =
|
||||
load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
|
||||
b, xb, fragXb);
|
||||
}
|
||||
|
||||
// Scaled Matrix multiply-accumulate using MFMA units
|
||||
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
||||
@@ -860,8 +1045,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
|
||||
}
|
||||
|
||||
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
||||
storeC(c, fragC);
|
||||
if constexpr(is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -993,8 +1184,7 @@ struct TestMXMFMA
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/6
|
||||
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
@@ -1012,11 +1202,9 @@ struct TestMXMFMA
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
|
||||
break;
|
||||
|
||||
case 3:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
|
||||
@@ -1026,6 +1214,14 @@ struct TestMXMFMA
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
break;
|
||||
case 4:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1., 1.});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1., 1.});
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
break;
|
||||
default:
|
||||
// all initial values are representable in FP8, BF8
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
|
||||
@@ -1207,6 +1403,11 @@ struct TestMFMA
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
|
||||
break;
|
||||
case 4:
|
||||
// FP4 values case
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5});
|
||||
break;
|
||||
default:
|
||||
// all initial values are representable in FP8, BF8
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
|
||||
|
||||
Reference in New Issue
Block a user