mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Remove leftovers
This commit is contained in:
@@ -520,9 +520,6 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
|
||||
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
@@ -591,9 +588,6 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
|
||||
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
@@ -663,9 +657,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
|
||||
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
@@ -731,9 +722,6 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
// printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
|
||||
// printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
|
||||
@@ -280,11 +280,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
row_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT =
|
||||
vector_type<ARawT,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
|
||||
|
||||
constexpr index_t num_chunks =
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 1 : 2);
|
||||
@@ -891,7 +888,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
using AFragT =
|
||||
vector_type<AType,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using BFragT =
|
||||
vector_type<BType,
|
||||
BLOCK_K * BLOCK_N / WAVE_SIZE /
|
||||
@@ -907,48 +904,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
auto fragAcc = AccumFragT{0};
|
||||
|
||||
// Load the inputs.
|
||||
// A = col major, BLOCK_M x BLOCK_K
|
||||
// fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
fragA = load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
// B = col major, BLOCK_K x BLOCK_N
|
||||
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
|
||||
|
||||
// printf("&&&&&&& %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u\n",
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<0>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<1>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<2>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<3>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<4>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<5>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<6>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<7>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<8>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<9>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<10>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<11>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<12>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<13>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<14>{}].data),
|
||||
// uint32_t(fragA.template AsType<AType>()[Number<15>{}].data));
|
||||
|
||||
// printf("$$$$$$ %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u\n",
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<0>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<1>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<2>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<3>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<4>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<5>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<6>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<7>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<8>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<9>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<10>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<11>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<12>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<13>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<14>{}].data),
|
||||
// uint32_t(fragB.template AsType<BType>()[Number<15>{}].data));
|
||||
|
||||
// Matrix multiply-accumulate using MFMA units
|
||||
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
||||
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
|
||||
|
||||
Reference in New Issue
Block a user