Remove leftovers

This commit is contained in:
Rostyslav Geyyer
2025-04-30 18:59:06 +00:00
parent 44f47ac7c5
commit 4ec936befd
2 changed files with 3 additions and 57 deletions

View File

@@ -520,9 +520,6 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
using arg_type = int32x8_t;
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
@@ -591,9 +588,6 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
using arg_type = int32x8_t;
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
@@ -663,9 +657,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
@@ -731,9 +722,6 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},

View File

@@ -280,11 +280,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
row_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT =
vector_type<ARawT,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
constexpr index_t num_chunks =
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
@@ -891,7 +888,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
using AFragT =
vector_type<AType,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using BFragT =
vector_type<BType,
BLOCK_K * BLOCK_N / WAVE_SIZE /
@@ -907,48 +904,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
auto fragAcc = AccumFragT{0};
// Load the inputs.
// A = col major, BLOCK_M x BLOCK_K
// fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
// printf("&&&&&&& %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u\n",
// uint32_t(fragA.template AsType<AType>()[Number<0>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<1>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<2>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<3>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<4>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<5>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<6>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<7>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<8>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<9>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<10>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<11>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<12>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<13>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<14>{}].data),
// uint32_t(fragA.template AsType<AType>()[Number<15>{}].data));
// printf("$$$$$$ %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u\n",
// uint32_t(fragB.template AsType<BType>()[Number<0>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<1>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<2>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<3>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<4>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<5>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<6>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<7>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<8>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<9>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<10>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<11>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<12>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<13>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<14>{}].data),
// uint32_t(fragB.template AsType<BType>()[Number<15>{}].data));
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);