mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
fix the correctness issue
This commit is contained in:
@@ -253,6 +253,22 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
using AScaleLayout = Row;
|
||||
using BScaleLayout = Col;
|
||||
|
||||
// Divisor for the A-matrix term of the num_btype byte-count estimate in this function.
// The immediately-invoked lambda yields 2 when ADataType (with cv/ref qualifiers
// removed) is ck::pk_i4_t or ck::f4x2_pk_t, and 1 for every other type.
// NOTE(review): presumably those two types store two logical elements per object,
// so sizeof(ADataType) * M * K would otherwise double-count — confirm.
const auto APackedSize = []() {
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<ADataType>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<ADataType>, ck::f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
// Divisor for the B-matrix term of the num_btype byte-count estimate in this function.
// The immediately-invoked lambda yields 2 when BDataType (with cv/ref qualifiers
// removed) is ck::pk_i4_t or ck::f4x2_pk_t, and 1 for every other type.
// NOTE(review): presumably those two types store two logical elements per object,
// so sizeof(BDataType) * K * N would otherwise double-count — confirm.
const auto BPackedSize = []() {
|
||||
if constexpr(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ||
|
||||
ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
|
||||
|
||||
@@ -362,26 +378,12 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
int NPerXdl = 16; // Fixed 16
|
||||
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl);
|
||||
#endif
|
||||
printf("a:\n");
|
||||
for(ck::index_t i = 0; i < M; i++)
|
||||
{
|
||||
for(ck::index_t j = 0; j < K; j += 2)
|
||||
{
|
||||
printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k(i, j)));
|
||||
if(j % 32 == 31)
|
||||
{
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// printf("b:\n");
|
||||
// for(ck::index_t i = 0; i < N; i++)
|
||||
// printf("a:\n");
|
||||
// for(ck::index_t i = 0; i < M; i++)
|
||||
// {
|
||||
// for(ck::index_t j = 0; j < K; j += 2)
|
||||
// {
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n(j, i)));
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k(i, j)));
|
||||
// if(j % 32 == 31)
|
||||
// {
|
||||
// printf("\n");
|
||||
@@ -389,6 +391,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// printf("b:\n");
|
||||
// for(ck::index_t i = 0; i < N; i++)
|
||||
// {
|
||||
// for(ck::index_t j = 0; j < K; j += 2)
|
||||
// {
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_preshuffled(j, i)));
|
||||
// if(j % 128 == 126)
|
||||
// {
|
||||
// printf("\n");
|
||||
// }
|
||||
// }
|
||||
// // printf("\n");
|
||||
// }
|
||||
// printf("b_scale:\n");
|
||||
// for(ck::index_t i = 0; i < N; i++)
|
||||
// {
|
||||
@@ -547,9 +563,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// partial sums(K/ScaleBlockSize)]
|
||||
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K / APackedSize+
|
||||
sizeof(BDataType) * K * N / BPackedSize+
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
|
||||
sizeof(XDataType) * M * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * N * K / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
|
||||
@@ -49,24 +49,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
64, // BlockSize: Thread block size
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
8, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<8, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<8, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
@@ -75,7 +75,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
|
||||
@@ -382,29 +382,19 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
// Read buffer + Compute buffer
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor(make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
I1,
|
||||
Number<MXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<KPack>{}),
|
||||
make_tuple(Number<KPack * MXdlPack>{},
|
||||
Number<KRepeat * MRepeat * KPack>{},
|
||||
Number<MRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
I1));
|
||||
Number<KPack>{}));
|
||||
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor(make_tuple(Number<NRepeat / NXdlPack>{},
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat / NXdlPack>{},
|
||||
I1,
|
||||
Number<KRepeat>{},
|
||||
Number<NXdlPack>{},
|
||||
Number<KPack>{}),
|
||||
make_tuple(Number<KPack * NXdlPack>{},
|
||||
Number<KRepeat * NRepeat * KPack>{},
|
||||
Number<NRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
I1));
|
||||
Number<KRepeat>{},
|
||||
Number<KPack>{}));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ =
|
||||
|
||||
@@ -446,30 +446,82 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
#if defined(__gfx950__) && 0
|
||||
printf(
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n",
|
||||
"Tid: %02d, NRepeat0, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, NRepeat0, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, NRepeat1, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, NRepeat1, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8+7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+7>{})))
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16+8+7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+8+7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32+16+8+7>{})))
|
||||
);
|
||||
|
||||
#endif
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -996,11 +1048,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
});
|
||||
|
||||
#if defined(__gfx950__) && 0
|
||||
printf("Tid: %02d, ik, im, in = %d, %d, %d\n"
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n",
|
||||
printf("Tid: %02d, ik0, im0, in0, ikxdl, imxdl, inxdl = %d, %d, %d, %d, %d, %d\n"
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
ikxdl.value, imxdl.value, inxdl.value,
|
||||
k0.value, m0.value, n0.value, ikxdl.value, imxdl.value, inxdl.value,
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<1>{}))),
|
||||
@@ -1010,7 +1062,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+2>{}))),
|
||||
@@ -1018,13 +1069,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+7>{})))
|
||||
);
|
||||
printf("Tid: %02d, ik, im, in = %d, %d, %d\n"
|
||||
"Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, B %02x %02x %02x %02x %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
ikxdl.value, imxdl.value, inxdl.value,
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_vec.template AsType<ComputeTypeA>()(Number<8+7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<1>{}))),
|
||||
@@ -1034,7 +1079,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_vec.template AsType<ComputeTypeA>()(Number<8+2>{}))),
|
||||
@@ -1080,8 +1124,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
if constexpr(m0.value < (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
|
||||
});
|
||||
if constexpr(m0.value < (MRepeat - LocalPrefetchStages *MXdlPack) / MXdlPack)
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
@@ -1112,28 +1156,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
});
|
||||
});
|
||||
});
|
||||
printf(
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, A %02x %02x %02x %02x %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(a_thread_buf(Number<8+7>{})))
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1143,16 +1165,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
// Order: 1 0 3 2 4
|
||||
static constexpr auto ARegBuf = 2;
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor(make_tuple(Number<ARegBuf>{},
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<ARegBuf>{},
|
||||
I1,
|
||||
Number<MXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<KPack>{}),
|
||||
make_tuple(Number<KRepeat * MXdlPack* KPack>{},
|
||||
Number<ARegBuf * MXdlPack * KRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
Number<MXdlPack*KPack>{},
|
||||
I1));
|
||||
Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeTypeA,
|
||||
|
||||
@@ -422,13 +422,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
// Builds the 5-D tensor descriptor used for the pre-shuffled B operand.
// Lengths, outermost to innermost:
//   (N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)
// where NkSwizzleNumber = warpSize * KPack.
// The explicitly listed strides are each the product of all lengths to their right,
// with an innermost stride of 1, i.e. they describe a contiguous layout.
// NOTE(review): make_naive_tensor_descriptor_packed is expected to derive exactly those
// strides from the lengths alone, which would make the two return forms equivalent — confirm.
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
// Innermost extent: one KPack-wide chunk for each of the warpSize lanes.
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber,
|
||||
NXdlPack * K0 * NkSwizzleNumber,
|
||||
K0 * NkSwizzleNumber,
|
||||
NkSwizzleNumber,
|
||||
I1));
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
@@ -1903,7 +1898,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 3, 0, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
|
||||
Reference in New Issue
Block a user