mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
fix xdl transpose.
This commit is contained in:
@@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
32, 128, 128,
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
32, 32,
|
||||
1, 1,
|
||||
@@ -306,6 +306,34 @@ int main(int argc, char* argv[])
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{}); // 1
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{}); // 1
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{}); // 1
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{}); // 1
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
@@ -383,7 +411,7 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// printf the input tensor
|
||||
// printf a tensor
|
||||
printf("a0_t_k_k: \n");
|
||||
@@ -394,7 +422,7 @@ int main(int argc, char* argv[])
|
||||
printf("topk: %d: ", tk);
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k)));
|
||||
printf("%.1f ", ck::type_convert<float>(a0_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -409,26 +437,26 @@ int main(int argc, char* argv[])
|
||||
printf("topk: %d: ", tk);
|
||||
for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
|
||||
printf("%.1f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// printf b tensor
|
||||
// printf("b0_e_n_k: \n");
|
||||
// for (int e=0; e < experts; ++e)
|
||||
// {
|
||||
// for (int k=0; k < K; ++k)
|
||||
// {
|
||||
// printf("expert: %d: ", e);
|
||||
// for (int n=0; n < N; ++n)
|
||||
// {
|
||||
// printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// }
|
||||
printf("b0_e_n_k: \n");
|
||||
for (int e=0; e < experts; ++e)
|
||||
{
|
||||
for (int k=0; k < K; ++k)
|
||||
{
|
||||
printf("expert: %d: ", e);
|
||||
for (int n=0; n < N; ++n)
|
||||
{
|
||||
printf("%.1f ", ck::type_convert<float>(b0_e_n_k(e, k, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// printf b scale tensor
|
||||
printf("b1_e_n_k: \n");
|
||||
@@ -439,7 +467,7 @@ int main(int argc, char* argv[])
|
||||
printf("expert: %d: ", e);
|
||||
for(int n = 0; n < (N + Scale_Block_N - 1) / Scale_Block_N; ++n)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
|
||||
printf("%.1f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
@@ -426,6 +426,12 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
});
|
||||
});
|
||||
|
||||
// printf("blockIdx.y = %d, blockIdx.x = %d, threadIdx.x = %d, c_scale_thread_buf = <%f>\n",
|
||||
// blockIdx.y,
|
||||
// blockIdx.x,
|
||||
// threadIdx.x,
|
||||
// c_scale_thread_buf[Number<0>{}]);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
@@ -552,6 +558,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
|
||||
@@ -568,6 +568,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ __device__ Problem(index_t NumTokens_,
|
||||
@@ -1360,7 +1362,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
// constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
auto a_thread_offset =
|
||||
get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
|
||||
@@ -1380,7 +1382,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
|
||||
const index_t fused_token =
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl + a_thread_offset];
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
@@ -1462,29 +1464,67 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
// // print C
|
||||
// printf("tid: %d, blkid: %d, "
|
||||
// "c_thread_buf = <%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f,
|
||||
// %1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n", get_thread_local_1d_id(), block_m_id,
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<0>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<1>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<2>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<3>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<4>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<5>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<6>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<7>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<8>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<9>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<10>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<11>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<12>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<13>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<14>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<15>{}]);
|
||||
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// transposed XDL
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
|
||||
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -1493,24 +1533,24 @@ struct GridwiseMoeGemmBlockScale
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
M2)), // M2 = MPerXdl
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
N2, // N2 * N3 * N4 = NPerXdl
|
||||
N3,
|
||||
N4))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -1520,56 +1560,56 @@ struct GridwiseMoeGemmBlockScale
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
N2,
|
||||
I1,
|
||||
N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
n_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I3],
|
||||
n_thread_data_on_block_idx[I4]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
@@ -1665,16 +1705,16 @@ struct GridwiseMoeGemmBlockScale
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
N2,
|
||||
1,
|
||||
N4>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -1721,10 +1761,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
|
||||
Reference in New Issue
Block a user