diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ab23c1a128..da1cd7e92c 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -520,8 +520,8 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> using arg_type = int32x8_t; - printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); + printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); + printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( @@ -591,8 +591,8 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> using arg_type = int32x8_t; - printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); + printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); + printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( @@ -663,8 +663,8 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; - printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); + printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); + printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( @@ -731,8 +731,8 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; - printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); - printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); + printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]); + printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]); reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 2adc47d1b5..4703142b2e 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -397,7 +397,8 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) static constexpr int32_t WAVE_SIZE = 64; // Here we want to load from cols of B in chunks of 16 elements each. - static constexpr uint32_t chunk_size = 16; + static constexpr uint32_t chunk_size = + 16 / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1); // each chunk is separated by an offset static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64