Merge commit '8a0d659f92897e1ae99e4dc0ea4842a2c78170ab' into develop

This commit is contained in:
github-actions[bot]
2025-05-06 15:07:46 +00:00
parent bf90418b06
commit b96328e63f
8 changed files with 610 additions and 79 deletions

View File

@@ -6,6 +6,8 @@
#include "mx_mfma_op.hpp"
using ck::e8m0_bexp_t;
using ck::f4_t;
using ck::f4x2_pk_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
@@ -16,7 +18,7 @@ using ck::type_convert;
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
bool run_mfma_km_kn_nm_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
@@ -30,7 +32,8 @@ bool run_mfma_test(ck::index_t init)
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
const auto mfma_kernel = ck::
matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;
bool pass = true;
@@ -52,15 +55,72 @@ bool run_mfma_test(ck::index_t init)
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 4;
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
auto AB_init = 5;
auto pass = run_mfma_km_kn_nm_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
// FP8 x FP8 -> float check of the F32_32x32x64 MFMA instruction, driven through
// the km/kn/nm layout runner with A/B initialization method 5.
TEST(MFMA, FP8MFMA32x32x64)
{
    const ck::index_t init_method = 5;
    const bool pass =
        run_mfma_km_kn_nm_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(init_method);
    EXPECT_TRUE(pass);
}
/**
 * @brief Run the MFMA test with row-major A, column-major B and row-major C
 *
 * The block sizes are read from the descriptor of the requested MFMA instruction,
 * the matching ck::matmul kernel is instantiated for these layouts, and both are
 * handed to ck::mfma_test::TestMFMA.
 *
 * @tparam AType - element type of the A tensor
 * @tparam BType - element type of the B tensor
 * @tparam CType - element type of the C tensor
 * @tparam mfma  - MFMA instruction under test
 * @param init - selects initialization algorithm for A and B tensors
 * @return value produced by the TestMFMA functor, converted to bool
 */
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_mk_kn_mn_test(ck::index_t init)
{
    using ALayout = ck::tensor_layout::gemm::RowMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::RowMajor;

    using AccType    = float; // only MFMA_F32 instructions supported
    using CPUAccType = AccType;

    // Tile extents come straight from the instruction descriptor.
    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> instr;
    constexpr auto BLOCK_M = instr.m_per_blk;
    constexpr auto BLOCK_N = instr.n_per_blk;
    constexpr auto BLOCK_K = instr.num_input_blks * instr.k_per_blk;

    // Device kernel specialized for the tile extents and layouts above.
    const auto kernel = ck::
        matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, ALayout, BLayout, CLayout>;

    using Runner = ck::mfma_test::TestMFMA<decltype(kernel),
                                           AType,
                                           BType,
                                           CType,
                                           AccType,
                                           CPUAccType,
                                           ALayout,
                                           BLayout,
                                           CLayout,
                                           BLOCK_M,
                                           BLOCK_N,
                                           BLOCK_K>;

    return Runner{}(kernel, init);
}
TEST(MFMA, FP4MFMA16x16x128)
{
auto AB_init = 4;
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
auto pass = run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::F32_16x16x128>(
AB_init);
EXPECT_TRUE(pass);
}
// Packed-FP4 (f4x2_pk_t) inputs with half_t output on the F32_32x32x64 MFMA
// instruction, using the mk/kn/mn layout runner and initialization method 4.
TEST(MFMA, FP4MFMA32x32x64)
{
    const ck::index_t init_method = 4;
    const bool pass =
        run_mfma_mk_kn_mn_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::F32_32x32x64>(
            init_method);
    EXPECT_TRUE(pass);
}
@@ -70,7 +130,7 @@ TEST(MFMA, FP8MFMA32x32x64)
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_test(ck::index_t init)
bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
{
static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
@@ -88,8 +148,18 @@ bool run_mxmfma_test(ck::index_t init)
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
constexpr auto BLOCK_X = 32; // scaling vector size
const auto mx_mfma_kernel =
ck::matmul<AType, BType, ScaleType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_X>;
const auto mx_mfma_kernel = ck::matmul<AType,
BType,
ScaleType,
CType,
AccType,
BLOCK_M,
BLOCK_N,
BLOCK_K,
BLOCK_X,
ALayout,
BLayout,
CLayout>;
bool pass = true;
@@ -111,14 +181,34 @@ bool run_mxmfma_test(ck::index_t init)
TEST(MXMFMA, MXFP8MFMA16x16x128)
{
auto AB_init = 7;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
auto AB_init = 5;
auto pass =
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64)
{
auto AB_init = 7;
auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
auto AB_init = 5;
auto pass =
run_mxmfma_mk_kn_mn_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
// Packed-FP4 (f4x2_pk_t) inputs with float output on the SCALE_F32_16x16x128
// scaled MFMA instruction, using the mk/kn/mn runner and initialization method 4.
TEST(MXMFMA, MXFP4MFMA16x16x128)
{
    const ck::index_t init_method = 4;
    const bool pass =
        run_mxmfma_mk_kn_mn_test<f4x2_pk_t,
                                 f4x2_pk_t,
                                 float,
                                 ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(init_method);
    EXPECT_TRUE(pass);
}
// Packed-FP4 (f4x2_pk_t) inputs with half_t output on the SCALE_F32_32x32x64
// scaled MFMA instruction, using the mk/kn/mn runner and initialization method 4.
TEST(MXMFMA, MXFP4MFMA32x32x64)
{
    const ck::index_t init_method = 4;
    const bool pass =
        run_mxmfma_mk_kn_mn_test<f4x2_pk_t,
                                 f4x2_pk_t,
                                 half_t,
                                 ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(init_method);
    EXPECT_TRUE(pass);
}

View File

@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -111,7 +112,7 @@ template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_col_major(AType const* input_ptr)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Register Mapping for 16x128 for FP8: || Register Mapping for 32x64 for FP8:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
@@ -176,13 +177,19 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M);
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M);
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, vectorSize(AFragT{})>::type;
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT =
vector_type<ARawT,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
AScalarFragT fragA{};
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
#pragma unroll
for(int chunk = 0; chunk < 2; chunk++)
for(int chunk = 0; chunk < num_chunks; chunk++)
{
#pragma unroll
for(uint32_t i = 0; i < chunk_size; i++)
@@ -241,6 +248,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] |
// Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] |
// Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] |
// Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] |
// Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] |
// Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] |
// Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] |
// Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] |
// Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] |
// Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] |
// Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] |
// Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] |
// Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] |
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
@@ -265,23 +294,34 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K);
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K);
auto startOffset = row_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
row_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
union
{
AFragT frag;
AScalarFragT chunks[2];
AScalarFragT chunks[num_chunks];
} fragA{};
auto* fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset);
fragA.chunks[0] = *fragPtr;
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
fragA.chunks[1] = *fragPtr;
const AScalarFragT* fragPtr;
for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++)
{
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset +
chunk_idx * kMajorOffset);
fragA.chunks[chunk_idx] = *fragPtr;
}
return fragA.frag;
}
@@ -339,15 +379,35 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
// Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) |
// Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) |
// Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) |
// Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4:
// Size | BLOCK_M | | BLOCK_M | | BLOCK_M | | BLOCK_M | | || Size | BLOCK_M | | BLOCK_M | | |
// M | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || M | 0 ... 31 | | 0 ... 31 | | Vector |
// Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element|
// Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------|
// Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] |
// Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] |
// Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] |
// Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] |
// Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] |
// Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] |
// Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] |
// Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] |
// Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] |
// Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] |
// Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] |
// Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] |
// Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] |
// Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] |
// Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] |
// Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] |
// clang-format on
static constexpr uint32_t VW = vectorSize(AFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 1 element
// We need to know where they start
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M)); // Col
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
@@ -369,7 +429,7 @@ template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
__device__ BFragT load_B_col_major(BType const* input_ptr)
{
// clang-format off
// Register Mapping for 128x16: || Register Mapping for 64x32:
// Register Mapping for 128x16 for FP8: || Register Mapping for 64x32 for FP8:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
@@ -406,6 +466,28 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] |
// Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] |
// Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] |
// Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] |
// Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] |
// Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] |
// Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] |
// Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] |
// Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] |
// Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] |
// Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] |
// Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] |
// Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] |
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
@@ -430,23 +512,34 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(startCoord2D, BLOCK_K);
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K);
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K);
auto startOffset = col_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
col_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
using BRawT = typename scalar_type<BFragT>::type;
using BScalarFragT = vector_type<BRawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 1 : 2);
union
{
BFragT frag;
BScalarFragT chunks[2];
BScalarFragT chunks[num_chunks];
} fragB{};
auto* fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset);
fragB.chunks[0] = *fragPtr;
fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
fragB.chunks[1] = *fragPtr;
const BScalarFragT* fragPtr;
for(index_t chunk = 0; chunk < num_chunks; chunk++)
{
fragPtr =
reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + chunk * kMajorOffset);
fragB.chunks[chunk] = *fragPtr;
}
return fragB.frag;
}
@@ -506,15 +599,56 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
// Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) |
// Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) |
// Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] |
// Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] |
// Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] |
// Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] |
// Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] |
// Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] |
// Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] |
// Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] |
// Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] |
// Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] |
// Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] |
// Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] |
// Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] |
// Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] |
// Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] |
// Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] |
// Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4:
// Size | BLOCK_N | | BLOCK_N | | BLOCK_N | | BLOCK_N | | || Size | BLOCK_N | | BLOCK_N | | |
// N | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || N | 0 ... 31 | | 0 ... 31 | | Vector |
// Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element|
// Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------|
// Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(1,N) | K64K65 | x(2,N) | K96K97 | x(3,N) | v[0] || Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(1,N) | v[0] |
// Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(1,N) | K66K67 | x(2,N) | K98K99 | x(3,N) | v[1] || Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(1,N) | v[1] |
// Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(1,N) | K68K69 | x(2,N) | K100K101 | x(3,N) | v[2] || Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(1,N) | v[2] |
// Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(1,N) | K70K71 | x(2,N) | K102K103 | x(3,N) | v[3] || Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(1,N) | v[3] |
// Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(1,N) | K72K73 | x(2,N) | K104K105 | x(3,N) | v[4] || Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(1,N) | v[4] |
// Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(1,N) | K74K75 | x(2,N) | K106K107 | x(3,N) | v[5] || Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(1,N) | v[5] |
// Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(1,N) | K76K77 | x(2,N) | K108K109 | x(3,N) | v[6] || Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(1,N) | v[6] |
// Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(1,N) | K78K79 | x(2,N) | K110K111 | x(3,N) | v[7] || Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(1,N) | v[7] |
// Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(1,N) | K80K81 | x(2,N) | K112K113 | x(3,N) | v[8] || Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(1,N) | v[8] |
// Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(1,N) | K82K83 | x(2,N) | K114K115 | x(3,N) | v[9] || Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(1,N) | v[9] |
// Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(1,N) | K84K85 | x(2,N) | K116K117 | x(3,N) | v[10] || Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(1,N) | v[10] |
// Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(1,N) | K86K87 | x(2,N) | K118K119 | x(3,N) | v[11] || Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(1,N) | v[11] |
// Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(1,N) | K88K89 | x(2,N) | K120K121 | x(3,N) | v[12] || Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(1,N) | v[12] |
// Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(1,N) | K90K91 | x(2,N) | K122K123 | x(3,N) | v[13] || Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(1,N) | v[13] |
// Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(1,N) | K92K93 | x(2,N) | K124K125 | x(3,N) | v[14] || Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(1,N) | v[14] |
// Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(1,N) | K94K95 | x(2,N) | K126K127 | x(3,N) | v[15] || Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(1,N) | v[15] |
// clang-format on
static constexpr uint32_t VW = vectorSize(BFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 1 element
// We need to know where to start
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
threadIdx.x % BLOCK_N); // Col
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N), // Row
threadIdx.x % BLOCK_N); // Col
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
@@ -766,15 +900,24 @@ template <typename AType,
typename AccType,
int32_t BLOCK_M,
int32_t BLOCK_N,
int32_t BLOCK_K>
int32_t BLOCK_K,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void matmul(const AType* a, const BType* b, CType* c)
{
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -786,10 +929,23 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
auto fragAcc = AccumFragT{0};
// Load the inputs.
// A = col major, BLOCK_M x BLOCK_K
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
}
else
{
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
}
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
printf("This layout is not implemented\n");
}
else
{
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
}
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
@@ -801,8 +957,14 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
if constexpr(is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
}
else
{
store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
}
}
template <typename AType,
@@ -813,7 +975,10 @@ template <typename AType,
int32_t BLOCK_M,
int32_t BLOCK_N,
int32_t BLOCK_K,
int32_t BLOCK_X>
int32_t BLOCK_X,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
{
@@ -821,8 +986,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
@@ -838,13 +1009,27 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
auto fragXb = BScaleFragT{};
// Load the inputs.
// A = col major, BLOCK_M x BLOCK_K
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
fragA =
load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
}
else
{
printf("This layout is not implemented\n");
}
// B = col major, BLOCK_K x BLOCK_N
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
if constexpr(is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
printf("This layout is not implemented\n");
}
else
{
fragB =
load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
}
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
@@ -860,8 +1045,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
if constexpr(is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
}
else
{
store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{}(c, fragC);
}
}
/**
@@ -993,8 +1184,7 @@ struct TestMXMFMA
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(
GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
// NOTE: not all numbers are representable in FP8, BF8, etc.
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
@@ -1012,11 +1202,9 @@ struct TestMXMFMA
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
@@ -1026,6 +1214,14 @@ struct TestMXMFMA
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
break;
case 4:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1., 1.});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1., 1.});
b_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
@@ -1207,6 +1403,11 @@ struct TestMFMA
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
break;
case 4:
// FP4 values case
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5});
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});