mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix k0 calculate method
This commit is contained in:
@@ -98,13 +98,14 @@ struct MultiplyMultiply
|
||||
}
|
||||
};
|
||||
|
||||
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
|
||||
template<typename DataType>
|
||||
void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl, int Knew)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
int K0 = Knew / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
@@ -147,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, B0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -304,8 +305,44 @@ int main(int argc, char* argv[])
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
b0_preshuffled.SetZero();
|
||||
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
|
||||
#if 0
|
||||
{ //test shuffle result
|
||||
auto ouput_matirx=[](auto data, ck::index_t N0, ck::index_t K0){
|
||||
std::cout << std::endl;
|
||||
int ii = 0;
|
||||
for(int n = 0; n < N0; n++)
|
||||
{
|
||||
std::cout << ii++ << " line: ";
|
||||
for(int k = 0; k < K0; k++)
|
||||
{
|
||||
std::cout << data(k,n) << " ";
|
||||
// std::cout << ck::type_convert<float>(data.mData[n*K0 + k]) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<int> b0_k_n2(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<int> b0_preshuffled2(
|
||||
f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use laout only for size
|
||||
// int nCount = 0;
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
b0_k_n2(k,n) = k;//nCount++;
|
||||
}
|
||||
}
|
||||
ouput_matirx(b0_k_n2, N, K);
|
||||
b0_preshuffled2.SetZero();
|
||||
preShuffleBuffer(b0_k_n2.mData.data(), b0_preshuffled2.mData.data(), N, K, NPerXdl, Knew);
|
||||
std::cout << "after shuffle" << std::endl;
|
||||
ouput_matirx(b0_preshuffled2, N, Knew);
|
||||
}
|
||||
#endif
|
||||
|
||||
b0_preshuffled.SetZero(); //must set to zero
|
||||
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl, Knew);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
|
||||
@@ -1162,7 +1162,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
|
||||
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user