mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
add shuffle padded function
This commit is contained in:
@@ -98,7 +98,7 @@ struct MultiplyMultiply
|
||||
}
|
||||
};
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl, int Knew)
|
||||
{
|
||||
int KPack = 16;
|
||||
@@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, B0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, B0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -243,7 +243,8 @@ int main(int argc, char* argv[])
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
auto Knew = (K + 64 - 1) / 64 * 64;
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto Knew = device_op.GetPreShufflePadded(K);
|
||||
auto StrideBnew = Knew;
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
@@ -301,7 +302,6 @@ int main(int argc, char* argv[])
|
||||
constexpr auto I0 = ck::Number<0>{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
@@ -341,7 +341,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
b0_preshuffled.SetZero(); //must set to zero
|
||||
b0_preshuffled.SetZero(); // must set to zero
|
||||
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl, Knew);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
@@ -138,7 +138,8 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
virtual int GetPreShufflePadded(int K) = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -140,6 +140,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
int GetPreShufflePadded(int K) override { return (K + 64 - 1) / 64 * 64; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
|
||||
@@ -1161,8 +1161,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
|
||||
const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(
|
||||
problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user