mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
move b thread dequant copy to blockwise.
This commit is contained in:
@@ -222,7 +222,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BThreadTransfer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
@@ -236,7 +235,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
BThreadTransfer& b_thread_dequant_copy,
|
||||
// BThreadTransfer& b_thread_dequant_copy,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
@@ -287,7 +286,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
@@ -362,12 +361,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(local_read_buf));
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(local_read_buf));
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -432,7 +431,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1),
|
||||
b_thread_desc_,
|
||||
@@ -528,6 +527,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
KPack>;
|
||||
|
||||
const PassThrough b_element_op{};
|
||||
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
// const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
@@ -1219,18 +1219,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// B: VGRP->VGPR dequantization
|
||||
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
BK1Number>(b_element_op);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
|
||||
// Cast after lds
|
||||
@@ -1260,9 +1248,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
// B: VGRP->VGPR dequantization
|
||||
b_thread_dequant_copy,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
@@ -1522,7 +1507,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
// const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
@@ -1612,18 +1597,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// B: VGRP->VGPR dequantization
|
||||
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BK1Number>(b_element_op);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -1656,9 +1629,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
// B: VGRP->VGPR dequantization
|
||||
b_thread_dequant_copy,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
|
||||
@@ -1573,7 +1573,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
DstBuffer& dst_buf) const
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
Reference in New Issue
Block a user