diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index d46f234f98..b68488f5ce 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -226,7 +226,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) static constexpr int32_t WAVE_SIZE = 64; // Here we want to load from rows of A in chunks of 16 elements each. - static constexpr uint32_t chunk_size = 16; + static constexpr uint32_t chunk_size = + 16 / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1); // each chunk is separated by offset static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;