mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
fix test_mx_mfma errors (#2614)
[ROCm/composable_kernel commit: fb96b49666]
This commit is contained in:
@@ -187,11 +187,11 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
|
||||
auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M);
|
||||
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M);
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT =
|
||||
vector_type<ARawT,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT = typename vector_type<
|
||||
ARawT,
|
||||
BLOCK_M * BLOCK_K / WAVE_SIZE /
|
||||
(ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1)>::type;
|
||||
|
||||
AScalarFragT fragA{};
|
||||
|
||||
@@ -319,8 +319,9 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
// Flatten to 1D row_major offsets.
|
||||
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
||||
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarChunkT = vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarChunkT =
|
||||
typename vector_type<ARawT, scalar_type<AFragT>::vector_size / num_chunks>::type;
|
||||
|
||||
union
|
||||
{
|
||||
@@ -544,8 +545,9 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
|
||||
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
|
||||
|
||||
using BRawT = typename scalar_type<BFragT>::type;
|
||||
using BScalarChunkT = vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
|
||||
using BRawT = typename scalar_type<BFragT>::type;
|
||||
using BScalarChunkT =
|
||||
typename vector_type<BRawT, scalar_type<BFragT>::vector_size / num_chunks>::type;
|
||||
|
||||
union
|
||||
{
|
||||
@@ -780,7 +782,7 @@ struct store_C_col_major<CType, CFragT, 32, 32>
|
||||
|
||||
// we can vector store 4 contiguous elements at a time.
|
||||
using CRawT = typename scalar_type<CFragT>::type;
|
||||
using CScalarFragT = vector_type<CRawT, VW>::type;
|
||||
using CScalarFragT = typename vector_type<CRawT, VW>::type;
|
||||
union
|
||||
{
|
||||
CFragT frag;
|
||||
@@ -940,12 +942,14 @@ __global__ void matmul(const packed_type_t<AType>* a, const packed_type_t<BType>
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
|
||||
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
|
||||
using AFragT =
|
||||
typename vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
|
||||
using BFragT =
|
||||
typename vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
|
||||
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using CFragT = typename vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using RawAccumFragT = typename vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
|
||||
// Create frags
|
||||
auto fragA = AFragT{};
|
||||
@@ -1019,14 +1023,16 @@ __global__ void matmul(const packed_type_t<AType>* a,
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT = vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
|
||||
using BFragT = vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
|
||||
using AFragT =
|
||||
typename vector_type<PackedAType, BLOCK_M * BLOCK_K / WAVE_SIZE / packed_size_a>::type;
|
||||
using BFragT =
|
||||
typename vector_type<PackedBType, BLOCK_K * BLOCK_N / WAVE_SIZE / packed_size_b>::type;
|
||||
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using CFragT = typename vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AScaleFragT = vector_type<ScaleType, 1>::type;
|
||||
using BScaleFragT = vector_type<ScaleType, 1>::type;
|
||||
using RawAccumFragT = typename vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AScaleFragT = typename vector_type<ScaleType, 1>::type;
|
||||
using BScaleFragT = typename vector_type<ScaleType, 1>::type;
|
||||
|
||||
// Create frags
|
||||
auto fragA = AFragT{};
|
||||
|
||||
Reference in New Issue
Block a user