mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
clean
This commit is contained in:
@@ -48,7 +48,7 @@ using DeviceGemmV2Instance =
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
2, 32, 32, 1,
|
||||
1, 1, S<1, 16, 1, 8>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
|
||||
@@ -88,10 +88,6 @@ inline __host__ __device__ constexpr double get_atol()
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
#endif
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
@@ -169,25 +165,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
|
||||
c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
|
||||
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
|
||||
#else
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
#endif
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
@@ -200,15 +183,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -238,17 +215,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
//ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
|
||||
|
||||
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
//pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
// c_m_n_host_result,
|
||||
@@ -256,18 +224,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
// get_rtol<CDataType>(),
|
||||
// get_atol<CDataType>());
|
||||
|
||||
//for(int i = 0; i < M; i++)
|
||||
//{
|
||||
// for(int j = 0; j < N; j++)
|
||||
// {
|
||||
// std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
//}
|
||||
#endif
|
||||
for(int i = 0; i < M; i++)
|
||||
{
|
||||
for(int j = 0; j < N; j++)
|
||||
{
|
||||
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
|
||||
@@ -25,7 +25,7 @@ struct PassThroughPack2
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
@@ -945,10 +945,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
|
||||
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align);
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize, max_lds_align);
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
@@ -957,8 +957,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned * sizeof(ADataType) +
|
||||
b_block_space_size_aligned * sizeof(BDataType)),
|
||||
return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
|
||||
b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -1316,16 +1316,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align);
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
reinterpret_cast<BDataType*>(static_cast<ADataType*>(p_shared) +
|
||||
a_block_space_size_aligned),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize);
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
|
||||
@@ -1711,23 +1711,23 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align);
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
|
||||
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BDataType*>(static_cast<char*>(p_shared_0) +
|
||||
a_block_space_size_aligned * sizeof(ADataType)),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize);
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
|
||||
static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
|
||||
a_block_space_size_aligned * sizeof(ADataType)),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize);
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
|
||||
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
|
||||
|
||||
Reference in New Issue
Block a user