add shuffle padded function

This commit is contained in:
qin letao
2025-02-24 04:00:43 +00:00
parent 0a32fb4765
commit 782712a78f
4 changed files with 10 additions and 8 deletions

View File

@@ -98,7 +98,7 @@ struct MultiplyMultiply
}
};
template<typename DataType>
template <typename DataType>
void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl, int Knew)
{
int KPack = 16;
@@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, B0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, B0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -243,7 +243,8 @@ int main(int argc, char* argv[])
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto Knew = (K + 64 - 1) / 64 * 64;
auto device_op = DeviceOpInstance{};
auto Knew = device_op.GetPreShufflePadded(K);
auto StrideBnew = Knew;
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
@@ -301,7 +302,6 @@ int main(int argc, char* argv[])
constexpr auto I0 = ck::Number<0>{};
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
@@ -341,7 +341,7 @@ int main(int argc, char* argv[])
}
#endif
b0_preshuffled.SetZero(); //must set to zero
b0_preshuffled.SetZero(); // must set to zero
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl, Knew);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());

View File

@@ -138,7 +138,8 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
virtual int GetPreShuffleParameters() = 0;
virtual int GetPreShufflePadded(int K) = 0;
};
} // namespace device

View File

@@ -140,6 +140,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using Argument = typename GridwiseGemm::Argument;
int GetPreShuffleParameters() override { return NPerXDL; }
int GetPreShufflePadded(int K) override { return (K + 64 - 1) / 64 * 64; }
// Invoker
struct Invoker : public BaseInvoker

View File

@@ -1161,8 +1161,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(
problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);