diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index 6f592aad74..0e758bbbee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -222,6 +222,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1 __device__ void Run(const AGridDesc& a_grid_desc, @@ -235,6 +236,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); StaticallyIndexedArray{}> b_thread_bufs; constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + StaticallyIndexedArray{}> b_thread_dequant_bufs; + // Global prefetch A1 B1 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); b_blockwise_copy.Run(b_grid_desc, @@ -279,6 +286,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1VGPR dequant + b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); // Initialize C c_thread_buf.Clear(); @@ -316,9 +330,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; }); using mfma_input_type = @@ -348,6 +362,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1VGPR dequant + b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(mfma_reg_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(mfma_reg_buf)); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -382,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; }); @@ -411,6 +432,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1VGPR dequant + b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); __builtin_amdgcn_sched_barrier(0); @@ -425,7 +453,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; }); @@ -458,7 +486,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 5d776c6724..1e4ceebeeb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; - // const BElementwiseOperation b_element_op{}; + const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1205,8 +1205,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< BDataType, - // BDataType, - ADataType, + BDataType, decltype(b_grid_desc_bpreshuffled), decltype(b_block_desc_bk0_n_bk1), Sequence{}, I1, Number{}, Number{}>, @@ -1220,18 +1219,24 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle 0, KPack * (get_thread_local_1d_id() % warpSize))); + // B: VGRP->VGPR dequantization + auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeTypeA, + decltype(b_block_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<0, 1, 2, 3>, + 3, + BK1Number>(b_element_op); + // LDS allocation for A and B: be careful of alignment // Cast after lds auto a_block_buf = make_dynamic_buffer( static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - // auto b_block_buf = make_dynamic_buffer( - // reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - // sizeof(ADataType) / - // APackedSize), - // b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); @@ -1255,6 +1260,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle b_grid_buf, b_block_buf, b_block_slice_copy_step, + + // B: VGRP->VGPR dequantization + b_thread_dequant_copy, c_thread_buf, num_k_block_main_loop); @@ -1514,7 +1522,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; - // const BElementwiseOperation b_element_op{}; + const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1604,6 +1612,18 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle 0, KPack * (get_thread_local_1d_id() % warpSize))); + // B: VGRP->VGPR dequantization + auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeTypeA, + decltype(b_block_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<0, 1, 2, 3>, + 3, + BK1Number>(b_element_op); + // LDS allocation for A and B: be careful of alignment auto a_block_buf_ping = make_dynamic_buffer( static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1636,6 +1656,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle b_grid_buf, b_block_bufs, b_block_slice_copy_step, + + // B: VGRP->VGPR dequantization + b_thread_dequant_copy, c_thread_buf, num_k_block_main_loop); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index ef562ba744..741b8539dc 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -287,6 +287,7 @@ struct ThreadwiseTensorSliceTransfer_v2 // loop over tensor and copy constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); +#if 0 if constexpr(is_same, pk_i4_t>::value) { static_for<0, num_access, 1>{}([&](auto idx_1d) { @@ -352,12 +353,13 @@ struct ThreadwiseTensorSliceTransfer_v2 }); } else +#endif { static_for<0, num_access, 1>{}([&](auto idx_1d) { - typename vector_type_maker::type src_vector; + typename vector_type_maker::type src_vector; using src_vector_t = - typename vector_type_maker::type::type; + typename vector_type_maker::type::type; constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); const bool is_src_valid = @@ -365,24 +367,24 @@ struct ThreadwiseTensorSliceTransfer_v2 // copy data from src_buf into src_vector src_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + src_buf.template Get(src_coord_.GetOffset() / PackedSize, is_src_valid); // copy data from src_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { constexpr index_t dst_offset = dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector); if constexpr(InvalidElementAsNaN) { - dst_buf(Number{}) = + dst_buf(Number{}) = is_src_valid ? type_convert(src_vector.template AsType()[i]) : NumericLimits::QuietNaN(); } else { - dst_buf(Number{}) = + dst_buf(Number{}) = type_convert(src_vector.template AsType()[i]); } }); @@ -1544,6 +1546,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic using Index = MultiIndex; + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic( const ElementwiseOperation& element_op) : element_op_{element_op} @@ -1598,26 +1607,70 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, num_access, 1>{}([&](auto idx_1d) { - constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + if constexpr(is_same, pk_i4_t>::value) + { + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type src_tmp_vector; - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - DstData v; + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); - // apply element-wise operation - element_op_(v, src_buf[Number{}]); + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; - // apply type convert - dst_buf(Number{}) = v; + constexpr index_t pack_size = 8; + + static_assert(DstScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + + static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); }); - }); + } + else + { + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + DstData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + // apply type convert + dst_buf(Number{}) = v; + }); + }); + } } ElementwiseOperation element_op_;