fix moe i4 example bug (#2139)

[ROCm/composable_kernel commit: 83394e40d2]
This commit is contained in:
lalala-sh
2025-04-29 00:49:31 +08:00
committed by GitHub
parent db016cf6da
commit ffc41f64fd

View File

@@ -233,7 +233,7 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 1, 1};
ck::index_t KBatch = 1;
@@ -266,7 +266,8 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]}));
Tensor<D1DataType> d1_e_n(
HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(