fix the correctness issue

This commit is contained in:
aska-0096
2025-05-26 09:29:36 +00:00
parent 4a3205f94a
commit d1d56e89ef
5 changed files with 133 additions and 113 deletions

View File

@@ -253,6 +253,22 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
using AScaleLayout = Row;
using BScaleLayout = Col;
const auto APackedSize = []() {
if constexpr(ck::is_same_v<ck::remove_cvref_t<ADataType>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<ADataType>, ck::f4x2_pk_t>)
return 2;
else
return 1;
}();
const auto BPackedSize = []() {
if constexpr(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::f4x2_pk_t>)
return 2;
else
return 1;
}();
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
@@ -362,26 +378,12 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
int NPerXdl = 16; // Fixed 16
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl);
#endif
printf("a:\n");
for(ck::index_t i = 0; i < M; i++)
{
for(ck::index_t j = 0; j < K; j += 2)
{
printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k(i, j)));
if(j % 32 == 31)
{
printf("\n");
}
}
printf("\n");
}
// printf("b:\n");
// for(ck::index_t i = 0; i < N; i++)
// printf("a:\n");
// for(ck::index_t i = 0; i < M; i++)
// {
// for(ck::index_t j = 0; j < K; j += 2)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n(j, i)));
// printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k(i, j)));
// if(j % 32 == 31)
// {
// printf("\n");
@@ -389,6 +391,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// }
// printf("\n");
// }
// printf("b:\n");
// for(ck::index_t i = 0; i < N; i++)
// {
// for(ck::index_t j = 0; j < K; j += 2)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_preshuffled(j, i)));
// if(j % 128 == 126)
// {
// printf("\n");
// }
// }
// // printf("\n");
// }
// printf("b_scale:\n");
// for(ck::index_t i = 0; i < N; i++)
// {
@@ -547,9 +563,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
std::size_t num_btype = sizeof(ADataType) * M * K / APackedSize+
sizeof(BDataType) * K * N / BPackedSize+
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
sizeof(XDataType) * M * K / ScaleBlockSize +
sizeof(XDataType) * N * K / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

View File

@@ -49,24 +49,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
64, // BlockSize: Thread block size
64, // MPerBlock
64, // NPerBlock
256, // BlockSize: Thread block size
128, // MPerBlock
256, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
8, // MXdlPerWave
4, // NXdlPerWave
S<8, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
@@ -75,7 +75,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer

View File

@@ -382,29 +382,19 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<MRepeat / MXdlPack>{},
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
I1,
Number<MXdlPack>{},
Number<KRepeat>{},
Number<KPack>{}),
make_tuple(Number<KPack * MXdlPack>{},
Number<KRepeat * MRepeat * KPack>{},
Number<MRepeat * KPack>{},
Number<KPack>{},
I1));
Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<NRepeat / NXdlPack>{},
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat / NXdlPack>{},
I1,
Number<KRepeat>{},
Number<NXdlPack>{},
Number<KPack>{}),
make_tuple(Number<KPack * NXdlPack>{},
Number<KRepeat * NRepeat * KPack>{},
Number<NRepeat * KPack>{},
Number<KPack>{},
I1));
Number<KRepeat>{},
Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ =

View File

@@ -446,30 +446,82 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Global prefetch 2
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
#if defined(__gfx950__) && 0
printf(
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n",
"Tid: %02d, NRepeat0, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, NRepeat0, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, NRepeat1, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, NRepeat1, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<0>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<1>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<2>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<3>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<4>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<5>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<6>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+0>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+1>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+2>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+3>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+4>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+5>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+6>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+7>{})))
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+7>{})))
);
#endif
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
@@ -996,11 +1048,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
#if defined(__gfx950__) && 0
printf("Tid: %02d, ik, im, in = %d, %d, %d\n"
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n",
printf("Tid: %02d, ik0, im0, in0, ikxdl, imxdl, inxdl = %d, %d, %d, %d, %d, %d\n"
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
ikxdl.value, imxdl.value, inxdl.value,
k0.value, m0.value, n0.value, ikxdl.value, imxdl.value, inxdl.value,
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<0>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<1>{}))),
@@ -1010,7 +1062,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<5>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<6>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+0>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+1>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+2>{}))),
@@ -1018,13 +1069,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+4>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+5>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+6>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+7>{})))
);
printf("Tid: %02d, ik, im, in = %d, %d, %d\n"
"Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
ikxdl.value, imxdl.value, inxdl.value,
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<1>{}))),
@@ -1034,7 +1079,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<8+0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<8+1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<8+2>{}))),
@@ -1080,8 +1124,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages *MXdlPack) / MXdlPack)
{
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
@@ -1112,28 +1156,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
printf(
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<0>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<1>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<2>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<3>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<4>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<5>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<6>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+0>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+1>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+2>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+3>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+4>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+5>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+6>{}))),
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+7>{})))
);
}
});
}
@@ -1143,16 +1165,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Order: 1 0 3 2 4
static constexpr auto ARegBuf = 2;
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<ARegBuf>{},
make_naive_tensor_descriptor_packed(make_tuple(Number<ARegBuf>{},
I1,
Number<MXdlPack>{},
Number<KRepeat>{},
Number<KPack>{}),
make_tuple(Number<KRepeat * MXdlPack* KPack>{},
Number<ARegBuf * MXdlPack * KRepeat * KPack>{},
Number<KPack>{},
Number<MXdlPack*KPack>{},
I1));
Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,

View File

@@ -422,13 +422,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber),
make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber,
NXdlPack * K0 * NkSwizzleNumber,
K0 * NkSwizzleNumber,
NkSwizzleNumber,
I1));
return make_naive_tensor_descriptor_packed(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
@@ -1903,7 +1898,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<1, 2, 3, 0, 4>,
Sequence<0, 1, 2, 3, 4>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,