diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 4bb38a0c16..b2e615b9d8 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -187,11 +187,11 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - using ARawT = typename scalar_type::type; - using AScalarFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using ARawT = typename scalar_type::type; + using AScalarFragT = typename vector_type< + ARawT, + BLOCK_M * BLOCK_K / WAVE_SIZE / + (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)>::type; AScalarFragT fragA{}; @@ -319,8 +319,9 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // Flatten to 1D row_major offsets. auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; - using ARawT = typename scalar_type::type; - using AScalarChunkT = vector_type::vector_size / num_chunks>::type; + using ARawT = typename scalar_type::type; + using AScalarChunkT = + typename vector_type::vector_size / num_chunks>::type; union { @@ -544,8 +545,9 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col - using BRawT = typename scalar_type::type; - using BScalarChunkT = vector_type::vector_size / num_chunks>::type; + using BRawT = typename scalar_type::type; + using BScalarChunkT = + typename vector_type::vector_size / num_chunks>::type; union { @@ -780,7 +782,7 @@ struct store_C_col_major // we can vector store 4 contiguous elements at a time. using CRawT = typename scalar_type::type; - using CScalarFragT = vector_type::type; + using CScalarFragT = typename vector_type::type; union { CFragT frag; @@ -940,12 +942,14 @@ __global__ void matmul(const packed_type_t* a, const packed_type_t assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + typename vector_type::type; + using BFragT = + typename vector_type::type; - using CFragT = vector_type::type; + using CFragT = typename vector_type::type; using AccumFragT = vector_type; - using RawAccumFragT = vector_type::type; + using RawAccumFragT = typename vector_type::type; // Create frags auto fragA = AFragT{}; @@ -1019,14 +1023,16 @@ __global__ void matmul(const packed_type_t* a, assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + typename vector_type::type; + using BFragT = + typename vector_type::type; - using CFragT = vector_type::type; + using CFragT = typename vector_type::type; using AccumFragT = vector_type; - using RawAccumFragT = vector_type::type; - using AScaleFragT = vector_type::type; - using BScaleFragT = vector_type::type; + using RawAccumFragT = typename vector_type::type; + using AScaleFragT = typename vector_type::type; + using BScaleFragT = typename vector_type::type; // Create frags auto fragA = AFragT{};