diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 57875e6cc3..d00c2ad554 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -520,9 +520,6 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> using arg_type = int32x8_t; - // printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - // printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, @@ -591,9 +588,6 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> using arg_type = int32x8_t; - // printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - // printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, @@ -663,9 +657,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; - // printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - // printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, @@ -731,9 +722,6 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; - // printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - // printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index ccf86761e7..78b43402de 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -280,11 +280,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) row_major(majorStepCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - using ARawT = typename scalar_type::type; - using AScalarFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; constexpr index_t num_chunks = (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); @@ -891,7 +888,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) using AFragT = vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)>::type; using BFragT = vector_type(a); fragA = load_A_row_major(a); - // B = col major, BLOCK_K x BLOCK_N fragB = load_B_col_major(b); - // printf("&&&&&&& %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u\n", - // uint32_t(fragA.template AsType()[Number<0>{}].data), - // uint32_t(fragA.template AsType()[Number<1>{}].data), - // uint32_t(fragA.template AsType()[Number<2>{}].data), - // uint32_t(fragA.template AsType()[Number<3>{}].data), - // uint32_t(fragA.template AsType()[Number<4>{}].data), - // uint32_t(fragA.template AsType()[Number<5>{}].data), - // uint32_t(fragA.template AsType()[Number<6>{}].data), - // uint32_t(fragA.template AsType()[Number<7>{}].data), - // uint32_t(fragA.template AsType()[Number<8>{}].data), - // uint32_t(fragA.template AsType()[Number<9>{}].data), - // uint32_t(fragA.template AsType()[Number<10>{}].data), - // uint32_t(fragA.template AsType()[Number<11>{}].data), - // uint32_t(fragA.template AsType()[Number<12>{}].data), - // uint32_t(fragA.template AsType()[Number<13>{}].data), - // uint32_t(fragA.template AsType()[Number<14>{}].data), - // uint32_t(fragA.template AsType()[Number<15>{}].data)); - - // printf("$$$$$$ %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u, %u\n", - // uint32_t(fragB.template AsType()[Number<0>{}].data), - // uint32_t(fragB.template AsType()[Number<1>{}].data), - // uint32_t(fragB.template AsType()[Number<2>{}].data), - // uint32_t(fragB.template AsType()[Number<3>{}].data), - // uint32_t(fragB.template AsType()[Number<4>{}].data), - // uint32_t(fragB.template AsType()[Number<5>{}].data), - // uint32_t(fragB.template AsType()[Number<6>{}].data), - // uint32_t(fragB.template AsType()[Number<7>{}].data), - // uint32_t(fragB.template AsType()[Number<8>{}].data), - // uint32_t(fragB.template AsType()[Number<9>{}].data), - // uint32_t(fragB.template AsType()[Number<10>{}].data), - // uint32_t(fragB.template AsType()[Number<11>{}].data), - // uint32_t(fragB.template AsType()[Number<12>{}].data), - // uint32_t(fragB.template AsType()[Number<13>{}].data), - // uint32_t(fragB.template AsType()[Number<14>{}].data), - // uint32_t(fragB.template AsType()[Number<15>{}].data)); - // Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N mfma_type_selector{}(fragA, fragB, fragAcc);