fix update

This commit is contained in:
mtgu0705
2025-05-09 11:04:39 +08:00
parent 11f386108e
commit f2a474e2e9
5 changed files with 68 additions and 81 deletions

View File

@@ -7,6 +7,7 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
@@ -21,12 +22,18 @@
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using A1DataType = XDataType;
using B0DataType = F8;
@@ -40,7 +47,7 @@ using A0Layout = Row;
using B0Layout = Col;
using CLayout = Row;
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
void preShuffleBuffer(const F8* src, F8* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
@@ -71,6 +78,8 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
@@ -92,7 +101,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ADataType, BDataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType, B0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -135,7 +144,7 @@ int main(int argc, char* argv[])
StrideA = K;
StrideB = K;
StrideE = N;
StrideC = N;
}
else
{
@@ -147,8 +156,8 @@ int main(int argc, char* argv[])
exit(0);
}
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -166,19 +175,19 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<A1DataType> a_m_k_scale(f_host_tensor_descriptor(
M, (K + Scale_Block_K - 1) / Scale_Block_K, Scale_Stride_AM, A0Layout{}));
M, (K + ScaleBlockSize - 1) / ScaleBlockSize, Scale_Stride_AM, A0Layout{}));
Tensor<B0DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B0DataType> b_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b_k_n_scale(f_host_tensor_descriptor(
(K + Scale_Block_K - 1) / Scale_Block_K, N, Scale_Stride_BN, B0Layout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, ELayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, ELayout{}));
(K + ScaleBlockSize - 1) / ScaleBlockSize, N, Scale_Stride_BN, B0Layout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
std::cout << "e_m_n: " << c_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
@@ -206,9 +215,9 @@ int main(int argc, char* argv[])
DeviceMem a_scale_device_buf(sizeof(A1DataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(B0DataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(B1DataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a0_m_k.mData.data());
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
@@ -226,9 +235,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
auto cde_element_op = CElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
@@ -243,7 +250,7 @@ int main(int argc, char* argv[])
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(e_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
@@ -252,7 +259,7 @@ int main(int argc, char* argv[])
StrideB,
Scale_Stride_BN,
StrideC,
KBatch,
1, // KBatch
a_element_op,
b_element_op,
cde_element_op);
@@ -292,30 +299,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<AccDataType> c_m_n({M, N});
Tensor<float> a_m_k({M, K});
Tensor<float> b_k_n({K, N});
for(int m = 0; m < M; m++)
{
for(int k = 0; k < K; k++)
{
a_m_k(m, k) = ck::type_convert<float>(a0_m_k(m, k)) *
a1_m_k(m / Scale_Block_M, k / Scale_Block_K);
}
}
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
BDataType,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<A0DataType,
B0DataType,
CDataType,
AccDataType,
XDataType,
@@ -338,10 +323,10 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
? 0
: 1;
}

View File

@@ -288,7 +288,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_bufs[I0](Number<a_scale_offset>{}) =
a_scale_thread_bufs(I0)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
@@ -318,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs[I0](Number<b_scale_offset>{}) =
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
@@ -358,7 +358,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_bufs[I1](Number<a_scale_offset>{}) =
a_scale_thread_bufs(I1)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
@@ -388,7 +388,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs[I1](Number<b_scale_offset>{}) =
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
@@ -542,7 +542,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_bufs[mfma_reg_buf](Number<a_scale_offset>{}) =
a_scale_thread_bufs(mfma_reg_buf)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
@@ -573,7 +573,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs[mfma_reg_buf](Number<b_scale_offset>{}) =
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,

View File

@@ -79,6 +79,8 @@ struct DeviceGemmMX_BPreshuffle : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device

View File

@@ -216,6 +216,8 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle
using Argument = typename GridwiseGemm::Argument;
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
{
@@ -313,32 +315,32 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
@@ -350,21 +352,21 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}

View File

@@ -141,7 +141,7 @@ template <typename ALayout,
BDataType, // TODO: Hardcode them and remove from the list of template parameters
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemmMX_xdl_cshuffle_v3
struct GridwiseGemmMX_xdl_cshuffle_v3_b_preshuffle
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -172,7 +172,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>;
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::k_per_blk);
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane =
@@ -1227,7 +1227,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
// const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
@@ -1417,7 +1417,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
@@ -1906,7 +1905,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =