mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add tests
This commit is contained in:
@@ -6,6 +6,8 @@
|
||||
#include "mx_mfma_op.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::f4_t;
|
||||
using ck::f4x2_pk_t;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
@@ -122,3 +124,19 @@ TEST(MXMFMA, MXFP8MFMA32x32x64)
|
||||
auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto pass =
|
||||
run_mxmfma_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto pass =
|
||||
run_mxmfma_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
@@ -319,8 +320,8 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
|
||||
// Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) |
|
||||
// Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) |
|
||||
// clang-format on
|
||||
static constexpr uint32_t VW = vectorSize(AFragT{});
|
||||
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
||||
const uint32_t VW = vectorSize(AFragT{});
|
||||
// static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 1 element
|
||||
@@ -487,8 +488,8 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
|
||||
// Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) |
|
||||
|
||||
// clang-format on
|
||||
static constexpr uint32_t VW = vectorSize(BFragT{});
|
||||
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
||||
const uint32_t VW = vectorSize(BFragT{});
|
||||
// static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 1 element
|
||||
@@ -800,8 +801,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
|
||||
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AFragT =
|
||||
vector_type<AType,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using BFragT =
|
||||
vector_type<BType,
|
||||
BLOCK_K * BLOCK_N / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
|
||||
Reference in New Issue
Block a user