mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
init b preshuffle dequant in VGPR.
This commit is contained in:
@@ -222,6 +222,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BThreadTransfer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
@@ -235,6 +236,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
BThreadTransfer& b_thread_dequant_copy,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
@@ -242,12 +244,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
@@ -279,6 +286,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I0));
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
@@ -316,9 +330,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_dequant_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -348,6 +362,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(mfma_reg_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(mfma_reg_buf));
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -382,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
@@ -411,6 +432,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I1));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -425,7 +453,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
@@ -458,7 +486,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
|
||||
@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
// const BElementwiseOperation b_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
@@ -1205,8 +1205,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
|
||||
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
// BDataType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
@@ -1220,18 +1219,24 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// B: VGPR->VGPR dequantization
|
||||
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
BK1Number>(b_element_op);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
// auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
// reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
|
||||
// sizeof(ADataType) /
|
||||
// APackedSize),
|
||||
// b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
|
||||
|
||||
@@ -1255,6 +1260,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
// B: VGPR->VGPR dequantization
|
||||
b_thread_dequant_copy,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
@@ -1514,7 +1522,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
// const BElementwiseOperation b_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// divide block work by [M, N]
|
||||
@@ -1604,6 +1612,18 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// B: VGPR->VGPR dequantization
|
||||
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
BK1Number>(b_element_op);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -1636,6 +1656,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
// B: VGPR->VGPR dequantization
|
||||
b_thread_dequant_copy,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
|
||||
@@ -287,6 +287,7 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
// loop over tensor and copy
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
#if 0
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
@@ -352,12 +353,13 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
});
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;
|
||||
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
|
||||
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
const bool is_src_valid =
|
||||
@@ -365,24 +367,24 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
|
||||
// copy data from src_buf into src_vector
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, is_src_valid);
|
||||
|
||||
// copy data from src_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
if constexpr(InvalidElementAsNaN)
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
dst_buf(Number<dst_offset / PackedSize>{}) =
|
||||
is_src_valid
|
||||
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
|
||||
: NumericLimits<DstData>::QuietNaN();
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
dst_buf(Number<dst_offset / PackedSize>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
}
|
||||
});
|
||||
@@ -1544,6 +1546,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
|
||||
const ElementwiseOperation& element_op)
|
||||
: element_op_{element_op}
|
||||
@@ -1598,26 +1607,70 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type src_tmp_vector;
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
DstData v;
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
|
||||
});
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = v;
|
||||
constexpr index_t pack_size = 8;
|
||||
|
||||
static_assert(DstScalarPerVector % pack_size == 0, "");
|
||||
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
|
||||
static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack8{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
DstData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = v;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
ElementwiseOperation element_op_;
|
||||
|
||||
Reference in New Issue
Block a user