mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
fix update
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
@@ -21,12 +22,18 @@
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F8;
|
||||
using A1DataType = XDataType;
|
||||
using B0DataType = F8;
|
||||
@@ -40,7 +47,7 @@ using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
|
||||
void preShuffleBuffer(const F8* src, F8* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
@@ -71,6 +78,8 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
@@ -92,7 +101,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ADataType, BDataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType, B0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -135,7 +144,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
StrideC = N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -147,8 +156,8 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -166,19 +175,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<A0DataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<A1DataType> a_m_k_scale(f_host_tensor_descriptor(
|
||||
M, (K + Scale_Block_K - 1) / Scale_Block_K, Scale_Stride_AM, A0Layout{}));
|
||||
M, (K + ScaleBlockSize - 1) / ScaleBlockSize, Scale_Stride_AM, A0Layout{}));
|
||||
Tensor<B0DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B0DataType> b_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b_k_n_scale(f_host_tensor_descriptor(
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K, N, Scale_Stride_BN, B0Layout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, ELayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, ELayout{}));
|
||||
(K + ScaleBlockSize - 1) / ScaleBlockSize, N, Scale_Stride_BN, B0Layout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -206,9 +215,9 @@ int main(int argc, char* argv[])
|
||||
DeviceMem a_scale_device_buf(sizeof(A1DataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(B0DataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_scale_device_buf(sizeof(B1DataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
|
||||
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
|
||||
|
||||
@@ -226,9 +235,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
auto cde_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
@@ -243,7 +250,7 @@ int main(int argc, char* argv[])
|
||||
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(e_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -252,7 +259,7 @@ int main(int argc, char* argv[])
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideC,
|
||||
KBatch,
|
||||
1, // KBatch
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
@@ -292,30 +299,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
Tensor<float> a_m_k({M, K});
|
||||
Tensor<float> b_k_n({K, N});
|
||||
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
a_m_k(m, k) = ck::type_convert<float>(a0_m_k(m, k)) *
|
||||
a1_m_k(m / Scale_Block_M, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
|
||||
b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
|
||||
BDataType,
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<A0DataType,
|
||||
B0DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
XDataType,
|
||||
@@ -338,10 +323,10 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
|
||||
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
@@ -288,7 +288,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_bufs[I0](Number<a_scale_offset>{}) =
|
||||
a_scale_thread_bufs(I0)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
@@ -318,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs[I0](Number<b_scale_offset>{}) =
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
@@ -358,7 +358,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_bufs[I1](Number<a_scale_offset>{}) =
|
||||
a_scale_thread_bufs(I1)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
@@ -388,7 +388,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs[I1](Number<b_scale_offset>{}) =
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
@@ -542,7 +542,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_bufs[mfma_reg_buf](Number<a_scale_offset>{}) =
|
||||
a_scale_thread_bufs(mfma_reg_buf)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
@@ -573,7 +573,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs[mfma_reg_buf](Number<b_scale_offset>{}) =
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
|
||||
@@ -79,6 +79,8 @@ struct DeviceGemmMX_BPreshuffle : public BaseOperator
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -216,6 +216,8 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
@@ -313,32 +315,32 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
@@ -350,21 +352,21 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
kernel_gemm_xdl_cshuffle_v3_b_preshuffle<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ template <typename ALayout,
|
||||
BDataType, // TODO: Hardcode them and remove from the list of template parameters
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
struct GridwiseGemmMX_xdl_cshuffle_v3_b_preshuffle
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -172,7 +172,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>;
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::k_per_blk);
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk);
|
||||
|
||||
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KLane =
|
||||
@@ -1227,7 +1227,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
// const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
@@ -1417,7 +1417,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
@@ -1906,7 +1905,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
|
||||
Reference in New Issue
Block a user