Fix K0 calculation method

This commit is contained in:
qin letao
2025-02-24 03:37:44 +00:00
parent 23b2fb0f3e
commit 0a32fb4765
2 changed files with 43 additions and 6 deletions

View File

@@ -98,13 +98,14 @@ struct MultiplyMultiply
}
};
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
template<typename DataType>
void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl, int Knew)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
int K0 = Knew / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
@@ -147,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, B0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -304,8 +305,44 @@ int main(int argc, char* argv[])
int NPerXdl = device_op.GetPreShuffleParameters();
b0_preshuffled.SetZero();
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
#if 0
{ //test shuffle result
// Debug helper: dumps a 2-D tensor-like object to stdout, one n-index per output line.
// `data` is accessed as data(k, n); N0 is the number of lines printed and K0 the
// number of values printed on each line.
auto ouput_matirx = [](auto data, ck::index_t N0, ck::index_t K0) {
    std::cout << std::endl;
    for(int row = 0; row < N0; ++row)
    {
        // The line label is simply the current n index.
        std::cout << row << " line: ";
        int col = 0;
        while(col < K0)
        {
            // Alternative kept from the original for raw-buffer inspection:
            // std::cout << ck::type_convert<float>(data.mData[row*K0 + col]) << " ";
            std::cout << data(col, row) << " ";
            ++col;
        }
        std::cout << std::endl;
    }
};
Tensor<int> b0_k_n2(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<int> b0_preshuffled2(
f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use layout only for size
// int nCount = 0;
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
b0_k_n2(k,n) = k;//nCount++;
}
}
ouput_matirx(b0_k_n2, N, K);
b0_preshuffled2.SetZero();
preShuffleBuffer(b0_k_n2.mData.data(), b0_preshuffled2.mData.data(), N, K, NPerXdl, Knew);
std::cout << "after shuffle" << std::endl;
ouput_matirx(b0_preshuffled2, N, Knew);
}
#endif
b0_preshuffled.SetZero(); //must set to zero
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl, Knew);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());

View File

@@ -1162,7 +1162,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);