diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index 0f3b99824d..f0a11533b0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -222,7 +222,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1 __device__ void Run(const AGridDesc& a_grid_desc, @@ -236,7 +235,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1VGPR dequant - b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, b_block_origin_idx, b_thread_bufs(I0), b_thread_desc_, @@ -362,12 +361,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1VGPR dequant - b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf), - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_dequant_bufs(local_read_buf)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -432,7 +431,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1VGPR dequant - b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, b_block_origin_idx, b_thread_bufs(I1), b_thread_desc_, @@ -528,6 +527,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 5f15f24288..9070aadc93 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; + // const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1219,18 +1219,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle 0, KPack * (get_thread_local_1d_id() % warpSize))); - // B: VGRP->VGPR dequantization - auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< - BDataType, - ComputeTypeA, - decltype(b_block_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - tensor_operation::element_wise::PassThrough, - Sequence{}, I1, Number{}, Number{}>, - Sequence<0, 1, 2, 3>, - 3, - BK1Number>(b_element_op); - // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1260,9 +1248,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle b_grid_buf, b_block_buf, b_block_slice_copy_step, - - // B: VGRP->VGPR dequantization - b_thread_dequant_copy, c_thread_buf, num_k_block_main_loop); @@ -1522,7 +1507,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; + // const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1612,18 +1597,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle 0, KPack * (get_thread_local_1d_id() % warpSize))); - // B: VGRP->VGPR dequantization - auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< - BDataType, - ComputeTypeA, - decltype(b_block_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - tensor_operation::element_wise::PassThrough, - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BK1Number>(b_element_op); - // LDS allocation for A and B: be careful of alignment auto a_block_buf_ping = make_dynamic_buffer( static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1656,9 +1629,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle b_grid_buf, b_block_bufs, b_block_slice_copy_step, - - // B: VGRP->VGPR dequantization - b_thread_dequant_copy, c_thread_buf, num_k_block_main_loop); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 741b8539dc..6eebed7319 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1573,7 +1573,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic const SrcBuffer& src_buf, const DstDesc&, const DstSliceOriginIdx&, - DstBuffer& dst_buf) + DstBuffer& dst_buf) const { static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! Desc need to known at compile-time");