From 77ad000e8a1e5f943fc8b0bc1ff2c408fb78bb33 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 15 Oct 2024 10:08:44 -0700 Subject: [PATCH] clean --- example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp | 2 +- example/01_gemm/run_gemm_example_v2.inc | 54 ++++--------------- .../element/unary_element_wise_operation.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 24 ++++----- 4 files changed, 24 insertions(+), 58 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp index 617af31d29..29b5b725b4 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp @@ -48,7 +48,7 @@ using DeviceGemmV2Instance = S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 32, 32, 0, + 2, 32, 32, 1, 1, 1, S<1, 16, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 1b042fe0ee..c7bcb07731 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -88,10 +88,6 @@ inline __host__ __device__ constexpr double get_atol() template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { -#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) - static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); -#endif - using namespace ck::literals; auto M = problem_size.M; @@ -169,25 +165,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; -#ifdef BUILD_INT4_EXAMPLE - DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * - c_m_n_device_result.mDesc.GetElementSpaceSize()); - - const Tensor a_m_k_converted(a_m_k); - const Tensor b_k_n_converted(b_k_n); - - a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); -#else DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); -#endif DeviceMem workspace; auto a_element_op = AElementOp{}; @@ -200,15 +183,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) float ave_time = 0; auto argument = gemm.MakeArgument( -#ifdef BUILD_INT4_EXAMPLE - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), -#else static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), -#endif M, N, K, @@ -238,17 +215,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) //ref_invoker.Run(ref_argument); - ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); -#ifdef BUILD_INT4_EXAMPLE - Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); - - c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); - - c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); - - return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); -#else - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); //pass &= ck::utils::check_err(c_m_n_device_result, // c_m_n_host_result, @@ -256,18 +224,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) // get_rtol(), // get_atol()); - //for(int i = 0; i < M; i++) - //{ - // for(int j = 0; j < N; j++) - // { - // std::cout << ck::type_convert(c_m_n_device_result(i, j)) << ","; - // } - // std::cout << std::endl; - //} -#endif + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + std::cout << ck::type_convert(c_m_n_device_result(i, j)) << ","; + } + std::cout << std::endl; + } } - if(config.time_kernel) { ave_time = diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 83d2193eef..618813b781 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -25,7 +25,7 @@ struct PassThroughPack2 __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const { -#if 0 +#if 1 uint8_t x_u8 = ck::bit_cast(x); uint8_t x_l = (x_u8 & 0x0f) >> 0; uint8_t x_h = (x_u8 & 0xf0) >> 4; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 1037ec410f..3c051d9755 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -945,10 +945,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align); + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize, max_lds_align); + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = @@ -957,8 +957,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), c_block_size * sizeof(CShuffleDataType)); } @@ -1316,16 +1316,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align); + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned), - b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize); + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); @@ -1711,23 +1711,23 @@ struct GridwiseGemm_xdl_cshuffle_v3 // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align); + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_ping = make_dynamic_buffer( static_cast(static_cast(p_shared_0) + a_block_space_size_aligned * sizeof(ADataType)), - b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize); + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_pong = make_dynamic_buffer( bit_cast(bit_cast(p_shared_1) + a_block_space_size_aligned * sizeof(ADataType)), - b_block_desc_bk0_n_bk1.GetElementSpaceSize() / BPackedSize); + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);