Merge commit '57e0f5df29abefd919c334c994628a994ba2868c' into develop

This commit is contained in:
assistant-librarian[bot]
2025-05-19 22:06:56 +00:00
parent 0b87df9c4a
commit 9d088bc569
15 changed files with 1602 additions and 588 deletions

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::bf6_convert_rne;
@@ -41,6 +42,11 @@ TEST(BF6, ConvertFP32Nearest)
ASSERT_NEAR(max_bf6,
type_convert<float>(bf6_convert_rne(std::numeric_limits<float>::infinity())),
0.0f);
// convert float +/-30 to bf6 and back, check if clipped to +/-max_bf6
ASSERT_NEAR(-max_bf6, type_convert<float>(bf6_convert_rne(-30.0f)), 0.0f);
ASSERT_NEAR(max_bf6, type_convert<float>(bf6_convert_rne(30.0f)), 0.0f);
// convert float value less than bf6 subnorm to bf6 and back, check if equal to 0.0
float less_than_subnorm = 0.03125f;
ASSERT_NEAR(0.0f, type_convert<float>(bf6_convert_rne(less_than_subnorm)), 0.0f);
@@ -266,21 +272,18 @@ TEST(BF6, TestAsType16x1)
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec);
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{test_vec};
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).unpack(i),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
@@ -329,23 +332,23 @@ TEST(BF6, TestAsType16x2)
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec[i]);
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{test_vec[i]};
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
@@ -369,20 +372,86 @@ TEST(BF6, TestAsType32x1)
vector_type<bf6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
ASSERT_EQ(right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{}.pack(test_vec);
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{test_vec};
});
// copy the vector
vector_type<bf6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).unpack(i),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
TEST(BF6, TestAllValues)
{
    // Reference table: the float value of every OCP E3M2 (bf6) encoding,
    // listed in encoding order (0b000000 .. 0b111111).
    constexpr std::array<float, 64> kE3M2Values = {
        // clang-format off
        0.0000000000, 0.0625000000, 0.1250000000, 0.1875000000,
        0.2500000000, 0.3125000000, 0.3750000000, 0.4375000000,
        0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000,
        1.0000000000, 1.2500000000, 1.5000000000, 1.7500000000,
        2.0000000000, 2.5000000000, 3.0000000000, 3.5000000000,
        4.0000000000, 5.0000000000, 6.0000000000, 7.0000000000,
        8.0000000000, 10.0000000000, 12.0000000000, 14.0000000000,
        16.0000000000, 20.0000000000, 24.0000000000, 28.0000000000,
        -0.0000000000, -0.0625000000, -0.1250000000, -0.1875000000,
        -0.2500000000, -0.3125000000, -0.3750000000, -0.4375000000,
        -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000,
        -1.0000000000, -1.2500000000, -1.5000000000, -1.7500000000,
        -2.0000000000, -2.5000000000, -3.0000000000, -3.5000000000,
        -4.0000000000, -5.0000000000, -6.0000000000, -7.0000000000,
        -8.0000000000, -10.0000000000, -12.0000000000, -14.0000000000,
        -16.0000000000, -20.0000000000, -24.0000000000, -28.0000000000
        // clang-format on
    };
    // All 64 possible 6-bit encodings, spelled out explicitly so the table
    // rows line up with the values above.
    constexpr uint8_t kE3M2Bits[] = {
        // clang-format off
        0b000000, 0b000001, 0b000010, 0b000011,
        0b000100, 0b000101, 0b000110, 0b000111,
        0b001000, 0b001001, 0b001010, 0b001011,
        0b001100, 0b001101, 0b001110, 0b001111,
        0b010000, 0b010001, 0b010010, 0b010011,
        0b010100, 0b010101, 0b010110, 0b010111,
        0b011000, 0b011001, 0b011010, 0b011011,
        0b011100, 0b011101, 0b011110, 0b011111,
        0b100000, 0b100001, 0b100010, 0b100011,
        0b100100, 0b100101, 0b100110, 0b100111,
        0b101000, 0b101001, 0b101010, 0b101011,
        0b101100, 0b101101, 0b101110, 0b101111,
        0b110000, 0b110001, 0b110010, 0b110011,
        0b110100, 0b110101, 0b110110, 0b110111,
        0b111000, 0b111001, 0b111010, 0b111011,
        0b111100, 0b111101, 0b111110, 0b111111
        // clang-format on
    };
    const bool verbose = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
    if(verbose)
        printf("BF6 Table\n");
    for(int i = 0; i < 64; ++i)
    {
        // Decoding: each bit pattern must map to exactly the reference value.
        const float decoded = type_convert<float>(bf6_t(kE3M2Bits[i]));
        ASSERT_EQ(decoded, kE3M2Values[i]);
        // Encoding: the reference value must round-trip back to the same bits.
        const bf6_t encoded = type_convert<bf6_t>(kE3M2Values[i]);
        ASSERT_EQ(encoded & 0x3F, kE3M2Bits[i] & 0x3F);
        if(verbose)
        {
            // Print the binary representation
            printf("Bits: 0b");
            for(int bit = 5; bit >= 0; --bit)
            {
                printf("%c", (kE3M2Bits[i] & (1 << bit)) ? '1' : '0');
            }
            printf(", 0x%02X, Value: %f\n", kE3M2Bits[i], kE3M2Values[i]);
        }
    }
}

View File

@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
#include "ck/utility/env.hpp"
using ck::e8m0_bexp_t;
using ck::f4_convert_rne;
@@ -38,6 +39,11 @@ TEST(FP4, ConvertFP32Nearest)
// convert maximal float to fp4 and back, check if clipped to 6.0
ASSERT_NEAR(
max_fp4, type_convert<float>(f4_convert_rne(std::numeric_limits<float>::max())), abs_tol);
// convert +/-7.0 to fp4 and back, check if clipped to +/-6.0
ASSERT_NEAR(-max_fp4, type_convert<float>(f4_convert_rne(-7.0f)), 0.0);
ASSERT_NEAR(max_fp4, type_convert<float>(f4_convert_rne(7.0f)), 0.0);
// positive norm float value to fp4 and back, check if holds
float pos_float = 1.0f;
ASSERT_NEAR(pos_float, type_convert<float>(f4_convert_rne(pos_float)), abs_tol);
@@ -468,3 +474,54 @@ TEST(FP4, TestAsType32)
test_vec.at(i + 1));
});
}
TEST(FP4, TestAllValues)
{
    // Reference table: the float value of every OCP E2M1 (fp4) encoding,
    // listed in encoding order (0b0000 .. 0b1111).
    constexpr std::array<float, 16> kE2M1Values = {
        // clang-format off
        0.0000000000, 0.5000000000,
        1.0000000000, 1.5000000000,
        2.0000000000, 3.0000000000,
        4.0000000000, 6.0000000000,
        -0.0000000000, -0.5000000000,
        -1.0000000000, -1.5000000000,
        -2.0000000000, -3.0000000000,
        -4.0000000000, -6.0000000000
        // clang-format on
    };
    // All 16 possible 4-bit encodings, spelled out explicitly so the table
    // rows line up with the values above.
    constexpr uint8_t kE2M1Bits[] = {
        // clang-format off
        0b0000, 0b0001,
        0b0010, 0b0011,
        0b0100, 0b0101,
        0b0110, 0b0111,
        0b1000, 0b1001,
        0b1010, 0b1011,
        0b1100, 0b1101,
        0b1110, 0b1111
        // clang-format on
    };
    const bool verbose = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
    if(verbose)
        printf("FP4 Table\n");
    for(int i = 0; i < 16; ++i)
    {
        // Decoding: each bit pattern must map to exactly the reference value.
        const float decoded = type_convert<float>(f4_t(kE2M1Bits[i]));
        ASSERT_EQ(decoded, kE2M1Values[i]);
        // Encoding: the reference value must round-trip back to the same bits.
        const f4_t encoded = type_convert<f4_t>(kE2M1Values[i]);
        ASSERT_EQ(encoded & 0xF, kE2M1Bits[i] & 0xF);
        if(verbose)
        {
            // Print the binary representation
            printf("Bits: 0b");
            for(int bit = 3; bit >= 0; --bit)
            {
                printf("%c", (kE2M1Bits[i] & (1 << bit)) ? '1' : '0');
            }
            printf(", 0x%02X, Value: %f\n", kE2M1Bits[i], kE2M1Values[i]);
        }
    }
}

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
@@ -34,6 +35,11 @@ TEST(FP6, ConvertFP32Nearest)
ASSERT_NEAR(0.0f, type_convert<float>(f6_convert_rne(0.0f)), 0.0f);
// convert maximal f6_t to float and check if equal to max_fp6
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(max_fp6)), 0.0f);
// convert maximal +/-8.0 to fp6 and check if equal to +/-max_fp6
ASSERT_NEAR(-max_fp6, type_convert<float>(f6_convert_rne(-8.0f)), 0.0f);
ASSERT_NEAR(max_fp6, type_convert<float>(f6_convert_rne(8.0f)), 0.0f);
// convert maximal float to fp6 and back, check if clipped to max_fp6
ASSERT_NEAR(
max_fp6, type_convert<float>(f6_convert_rne(std::numeric_limits<float>::max())), 0.0f);
@@ -265,20 +271,24 @@ TEST(FP6, TestAsType16x1)
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec);
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{test_vec};
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i),
static_cast<f6_t>(test_vec[static_cast<int>(i)]))
<< " i = " << i << "; left = "
<< type_convert<float>(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).unpack(i))
<< " -- right = "
<< type_convert<float>(static_cast<f6_t>(test_vec[static_cast<int>(i)])) << " ("
<< static_cast<int>(test_vec[static_cast<int>(i)]) << ")" << std::endl;
});
}
@@ -327,23 +337,23 @@ TEST(FP6, TestAsType16x2)
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec[i]);
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{test_vec[i]};
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}).unpack(idx_element),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
@@ -367,19 +377,77 @@ TEST(FP6, TestAsType32x1)
vector_type<f6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
ASSERT_EQ(right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).unpack(i), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{}.pack(test_vec);
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{test_vec};
});
// copy the vector
vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
ASSERT_EQ(left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).unpack(i),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
TEST(FP6, TestAllValues)
{
    // Reference table: the float value of every OCP E2M3 (fp6) encoding,
    // listed in encoding order (0b000000 .. 0b111111).
    constexpr std::array<float, 64> kE2M3Values = {
        // clang-format off
        0.0000000000, 0.1250000000, 0.2500000000, 0.3750000000, 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000,
        1.0000000000, 1.1250000000, 1.2500000000, 1.3750000000, 1.5000000000, 1.6250000000, 1.7500000000, 1.8750000000,
        2.0000000000, 2.2500000000, 2.5000000000, 2.7500000000, 3.0000000000, 3.2500000000, 3.5000000000, 3.7500000000,
        4.0000000000, 4.5000000000, 5.0000000000, 5.5000000000, 6.0000000000, 6.5000000000, 7.0000000000, 7.5000000000,
        -0.0000000000, -0.1250000000, -0.2500000000, -0.3750000000, -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000,
        -1.0000000000, -1.1250000000, -1.2500000000, -1.3750000000, -1.5000000000, -1.6250000000, -1.7500000000, -1.8750000000,
        -2.0000000000, -2.2500000000, -2.5000000000, -2.7500000000, -3.0000000000, -3.2500000000, -3.5000000000, -3.7500000000,
        -4.0000000000, -4.5000000000, -5.0000000000, -5.5000000000, -6.0000000000, -6.5000000000, -7.0000000000, -7.5000000000
        // clang-format on
    };
    // All 64 possible 6-bit encodings, spelled out explicitly so the table
    // rows line up with the values above.
    constexpr uint8_t kE2M3Bits[] = {
        // clang-format off
        0b000000, 0b000001, 0b000010, 0b000011,
        0b000100, 0b000101, 0b000110, 0b000111,
        0b001000, 0b001001, 0b001010, 0b001011,
        0b001100, 0b001101, 0b001110, 0b001111,
        0b010000, 0b010001, 0b010010, 0b010011,
        0b010100, 0b010101, 0b010110, 0b010111,
        0b011000, 0b011001, 0b011010, 0b011011,
        0b011100, 0b011101, 0b011110, 0b011111,
        0b100000, 0b100001, 0b100010, 0b100011,
        0b100100, 0b100101, 0b100110, 0b100111,
        0b101000, 0b101001, 0b101010, 0b101011,
        0b101100, 0b101101, 0b101110, 0b101111,
        0b110000, 0b110001, 0b110010, 0b110011,
        0b110100, 0b110101, 0b110110, 0b110111,
        0b111000, 0b111001, 0b111010, 0b111011,
        0b111100, 0b111101, 0b111110, 0b111111
        // clang-format on
    };
    const bool verbose = ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
    if(verbose)
        printf("FP6 Table\n");
    for(int i = 0; i < 64; ++i)
    {
        // Decoding: each bit pattern must map to exactly the reference value.
        const float decoded = type_convert<float>(f6_t(kE2M3Bits[i]));
        ASSERT_EQ(decoded, kE2M3Values[i]);
        // Encoding: the reference value must round-trip back to the same bits.
        const f6_t encoded = type_convert<f6_t>(kE2M3Values[i]);
        ASSERT_EQ(encoded & 0x3F, kE2M3Bits[i] & 0x3F);
        if(verbose)
        {
            // Print the binary representation
            printf("Bits: 0b");
            for(int bit = 5; bit >= 0; --bit)
            {
                printf("%c", (kE2M3Bits[i] & (1 << bit)) ? '1' : '0');
            }
            printf(", 0x%02X, Value: %f\n", kE2M3Bits[i], kE2M3Values[i]);
        }
    }
}

View File

@@ -5,9 +5,12 @@
#include "mx_mfma_op.hpp"
using ck::bf6_t;
using ck::bf8_t;
using ck::e8m0_bexp_t;
using ck::f4_t;
using ck::f4x2_pk_t;
using ck::f6_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
@@ -17,13 +20,15 @@ using ck::type_convert;
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_km_kn_nm_test(ck::index_t init)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename AType,
typename BType,
typename CType,
ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
@@ -53,74 +58,153 @@ bool run_mfma_km_kn_nm_test(ck::index_t init)
return pass;
}
const ck::index_t common_init = -4; // set to "< 0" for test-specific initializations
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 5;
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
f8_t,
f8_t,
half_t,
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
TEST(MFMA, BF8MFMA16x16x128)
{
auto AB_init = 5;
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
half_t,
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
/**
* @brief Run the test for the given MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_mk_kn_mn_test(ck::index_t init)
TEST(MFMA, FP4MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mfma_kernel = ck::
matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;
bool pass = true;
pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mfma_kernel, init);
return pass;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f4_t, f4_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP4MFMA16x16x128)
TEST(MFMA, FP6MFMA16x16x128)
{
auto AB_init = 4;
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f6_t, f6_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, BF6MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
float,
ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, BF8MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
float,
ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP4MFMA32x32x64)
{
auto AB_init = 4;
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f4_t, f4_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass =
run_mfma_test<ALayout, BLayout, CLayout, f6_t, f6_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, BF6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
half_t,
ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
@@ -129,15 +213,18 @@ TEST(MFMA, FP4MFMA32x32x64)
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename AType,
typename BType,
typename CType,
ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_test(ck::index_t init)
{
static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
"Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AccType = float; // only MFMA_F32 instructions supported
using ScaleType = ck::e8m0_bexp_t; // biased exponent type
@@ -181,34 +268,170 @@ bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
TEST(MXMFMA, MXFP8MFMA16x16x128)
{
auto AB_init = 5;
auto pass =
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f8_t,
f8_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64)
{
auto AB_init = 5;
auto pass =
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f8_t,
f8_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF8MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF8MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf8_t,
bf8_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP6MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f6_t,
f6_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f6_t,
f6_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF6MFMA16x16x128)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXBF6MFMA32x32x64)
{
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
bf6_t,
bf6_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP4MFMA16x16x128)
{
auto AB_init = 4;
auto pass =
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f4_t,
f4_t,
float,
ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP4MFMA32x32x64)
{
auto AB_init = 4;
auto pass =
run_mxmfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(
AB_init);
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
auto AB_init = (common_init < 0) ? 5 : common_init;
auto pass = run_mxmfma_test<ALayout,
BLayout,
CLayout,
f4_t,
f4_t,
half_t,
ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}

View File

@@ -151,6 +151,8 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// clang-format on
static_assert(!is_packed_type_v<AType>, "Packed type is not supported");
static constexpr int32_t WAVE_SIZE = 64;
// Here we want to load from rows of A in chunks of 16 elements each.
@@ -270,12 +272,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] |
// Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
// FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
// FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
// FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
constexpr index_t num_chunks = is_packed_type_v<AType> ? 1 : 2;
// Here we want to load from rows of A in chunks of 16 elements each.
static constexpr uint32_t chunk_size = 16;
constexpr uint32_t chunk_size = is_packed_type_v<AType> ? 32 : 16;
// each chunk is separated by offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
@@ -283,43 +301,35 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D =
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
// FP8/6/4 Row {0-31} | {0-15}
// FP8 Col {0, 16} | {0, 16, 32, 48}
// FP6/4 Col {0, 32} | {0, 32, 64, 96}
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, (threadIdx.x / BLOCK_M) * chunk_size);
// auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
row_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
using ARawT = typename scalar_type<AFragT>::type;
using AScalarChunkT = vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
union
{
AFragT frag;
AScalarFragT chunks[num_chunks];
AScalarChunkT chunks[num_chunks];
} fragA{};
const AScalarFragT* fragPtr;
const AScalarChunkT* fragPtr;
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K) / packed_size_v<AType>;
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K) / packed_size_v<AType>;
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
{
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragPtr = reinterpret_cast<AScalarChunkT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragA.chunks[chunk_idx] = *fragPtr;
}
@@ -488,12 +498,27 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] |
// Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
// FP8 chunk_size = 16, num_chunks = 2, packed_size = 1
// FP4 chunk_size = 32, num_chunks = 1, packed_size = 2
// FP6 chunk_size = 32, num_chunks = 1, packed_size = 32
constexpr index_t num_chunks = is_packed_type_v<BType> ? 1 : 2;
// Here we want to load from cols of B in chunks of 16 elements each.
static constexpr uint32_t chunk_size = 16;
constexpr uint32_t chunk_size = is_packed_type_v<BType> ? 32 : 16;
// each chunk is separated by an offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64
@@ -501,44 +526,36 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D =
std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48}
threadIdx.x % BLOCK_N); // Col {0-31} | {0-15}
// FP8/6/4 Col {0-31} | {0-15}
// FP8 Row {0, 16} | {0, 16, 32, 48}
// FP6/4 Row {0, 32} | {0, 32, 64, 96}
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, threadIdx.x % BLOCK_N);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
// auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
col_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
using BRawT = typename scalar_type<BFragT>::type;
using BScalarFragT = vector_type<BRawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 1 : 2);
using BRawT = typename scalar_type<BFragT>::type;
using BScalarChunkT = vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
union
{
BFragT frag;
BScalarFragT chunks[num_chunks];
BScalarChunkT chunks[num_chunks];
} fragB{};
const BScalarFragT* fragPtr;
const BScalarChunkT* fragPtr;
for(index_t chunk = 0; chunk < num_chunks; chunk++)
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(startCoord2D, BLOCK_K) / packed_size_v<BType>;
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K) / packed_size_v<BType>;
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
{
fragPtr =
reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + chunk * kMajorOffset);
fragB.chunks[chunk] = *fragPtr;
fragPtr = reinterpret_cast<BScalarChunkT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragB.chunks[chunk_idx] = *fragPtr;
}
return fragB.frag;
@@ -904,20 +921,22 @@ template <typename AType,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void matmul(const AType* a, const BType* b, CType* c)
__global__ void matmul(const typename packed_type<AType>::type* a,
const typename packed_type<BType>::type* b,
CType* c)
{
using PackedAType = typename packed_type<AType>::type;
constexpr auto packed_size_a = packed_type<AType>::packed_size;
using PackedBType = typename packed_type<BType>::type;
constexpr auto packed_size_b = packed_type<BType>::packed_size;
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -931,11 +950,11 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
// Load the inputs.
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
fragA = load_A_row_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
}
else
{
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
fragA = load_A_col_major<PackedAType, AFragT, BLOCK_M, BLOCK_K>(a);
}
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
@@ -944,7 +963,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
}
else
{
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
fragB = load_B_col_major<PackedBType, BFragT, BLOCK_K, BLOCK_N>(b);
}
// Matrix multiply-accumulate using MFMA units
@@ -979,21 +998,24 @@ template <typename AType,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
__global__ void matmul(const packed_type_t<AType>* a,
const ScaleType* xa,
const packed_type_t<BType>* b,
const ScaleType* xb,
CType* c)
{
using PackedAType = packed_type_t<AType>;
constexpr auto packed_size_a = packed_size_v<AType>;
using PackedBType = packed_type_t<BType>;
constexpr auto packed_size_b = packed_size_v<BType>;
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -1011,9 +1033,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
// Load the inputs.
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
fragA =
load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
fragA = load_mx_A_row_major<PackedAType,
AFragT,
ScaleType,
AScaleFragT,
BLOCK_M,
BLOCK_K,
BLOCK_X>(a, xa, fragXa);
}
else
{
@@ -1026,9 +1052,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
}
else
{
fragB =
load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
fragB = load_mx_B_col_major<PackedBType,
BFragT,
ScaleType,
BScaleFragT,
BLOCK_K,
BLOCK_N,
BLOCK_X>(b, xb, fragXb);
}
// Scaled Matrix multiply-accumulate using MFMA units
@@ -1151,6 +1181,11 @@ template <typename DeviceMFMA,
index_t BLOCK_X>
struct TestMXMFMA
{
using PackedAType = typename packed_type<ADataType>::type;
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
using PackedBType = typename packed_type<BDataType>::type;
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
@@ -1167,11 +1202,11 @@ struct TestMXMFMA
}
};
Tensor<ADataType> a_m_k(
Tensor<PackedAType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<ScaleType> a_scales(
f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
Tensor<BDataType> b_n_k(
Tensor<PackedBType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<ScaleType> b_scales(
f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
@@ -1183,51 +1218,44 @@ struct TestMXMFMA
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/6
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.5f}});
// NOTE: not all numbers are representable in FP8, BF8, etc.
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-2.0, 2.0});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-2.0, 2.0});
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(0, 1, time(nullptr)));
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(0, 1));
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
break;
case 4:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1., 1.});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1., 1.});
b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(0, 1, time(nullptr) / 2));
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7}); // Z[-6,6]
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
@@ -1272,9 +1300,9 @@ struct TestMXMFMA
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<PackedAType>& a = std::get<0>(host_tensors);
const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
const Tensor<BDataType>& b = std::get<2>(host_tensors);
const Tensor<PackedBType>& b = std::get<2>(host_tensors);
const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
Tensor<CDataType>& c_host = std::get<4>(host_tensors);
Tensor<CDataType>& c_device = std::get<5>(host_tensors);
@@ -1356,6 +1384,12 @@ template <typename DeviceMFMA,
index_t BLOCK_K>
struct TestMFMA
{
using PackedAType = typename packed_type<ADataType>::type;
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
using PackedBType = typename packed_type<BDataType>::type;
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
@@ -1372,9 +1406,9 @@ struct TestMFMA
}
};
Tensor<ADataType> a_m_k(
Tensor<PackedAType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_n_k(
Tensor<PackedBType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
@@ -1384,34 +1418,30 @@ struct TestMFMA
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{0.625f});
// NOTE: not all numbers are representable in FP8, BF8, etc.
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<PackedBType, 1>{});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<PackedAType>{1.0f});
b_n_k.GenerateTensorValue(GeneratorTensor_1<PackedBType>{1.0f});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
// expect small round off errors that lead to FP8MFMA32x32x64 failures
a_m_k.GenerateTensorValue(GeneratorTensor_3<PackedAType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_3<PackedBType>{-5, 5});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
break;
case 4:
// FP4 values case
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5});
// expect small round off errors that lead to FP8MFMA32x32x64 failures
a_m_k.GenerateTensorValue(GeneratorTensor_4<PackedAType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<PackedBType>(1, 3));
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
// all initial values are representable in FP8/6, BF8/6 FP4 is missing 5
a_m_k.GenerateTensorValue(GeneratorTensor_2<PackedAType>{-6, 7}); // Z[-6,6]
b_n_k.GenerateTensorValue(GeneratorTensor_2<PackedBType>{-6, 7});
break;
}
@@ -1453,10 +1483,10 @@ struct TestMFMA
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
const Tensor<PackedAType>& a = std::get<0>(host_tensors);
const Tensor<PackedBType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -1464,8 +1494,8 @@ struct TestMFMA
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<PackedAType,
PackedBType,
CDataType,
CPUAccDataType,
PassThrough,