fix test_mx_mfma errors (#2614)

[ROCm/composable_kernel commit: fb96b49666]
This commit is contained in:
Illia Silin
2025-08-04 11:43:47 -07:00
committed by GitHub
parent 0bc4627b73
commit 7e796b7861

View File

@@ -187,11 +187,11 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M);
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M);
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT =
vector_type<ARawT,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = typename vector_type<
ARawT,
BLOCK_M * BLOCK_K / WAVE_SIZE /
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
AScalarFragT fragA{};
@@ -319,8 +319,9 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
using ARawT = typename scalar_type<AFragT>::type;
using AScalarChunkT = vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
using ARawT = typename scalar_type<AFragT>::type;
using AScalarChunkT =
typename vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
union
{
@@ -544,8 +545,9 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
using BRawT = typename scalar_type<BFragT>::type;
using BScalarChunkT = vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
using BRawT = typename scalar_type<BFragT>::type;
using BScalarChunkT =
typename vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
union
{
@@ -780,7 +782,7 @@ struct store_C_col_major<CType, CFragT, 32, 32>
// we can vector store 4 contiguous elements at a time.
using CRawT = typename scalar_type<CFragT>::type;
using CScalarFragT = vector_type<CRawT, VW>::type;
using CScalarFragT = typename vector_type<CRawT, VW>::type;
union
{
CFragT frag;
@@ -940,12 +942,14 @@ __global__ void matmul(const packed_type_t<AType>* a, const packed_type_t<BType>
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using AFragT =
typename vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT =
typename vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using CFragT = typename vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using RawAccumFragT = typename vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
// Create frags
auto fragA = AFragT{};
@@ -1019,14 +1023,16 @@ __global__ void matmul(const packed_type_t<AType>* a,
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using AFragT =
typename vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
using BFragT =
typename vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using CFragT = typename vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AScaleFragT = vector_type<ScaleType, 1>::type;
using BScaleFragT = vector_type<ScaleType, 1>::type;
using RawAccumFragT = typename vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AScaleFragT = typename vector_type<ScaleType, 1>::type;
using BScaleFragT = typename vector_type<ScaleType, 1>::type;
// Create frags
auto fragA = AFragT{};