mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Support transposed C tile in Aquant (#2679)
Enabling transposed C improves the performance of Aquant. With transposed C enabled, AQ elements no longer need to be exchanged among lanes, because each thread only holds data from one row.
This commit is contained in:
@@ -158,6 +158,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -359,63 +360,181 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of the
|
||||
// view is composed of a row from a warp tile within an AQ block tile.
|
||||
// Multiple warp tile rows that belong to the same block tile are laid
|
||||
// out as consecutive rows.
|
||||
//
|
||||
// When we need to multiply a C warp tile with an AQ warp tile, thread 0
|
||||
// in the warp will load AQ_warp_tile[0], thread 1 will load
|
||||
// AQ_warp_tile[1], and so on, up to thread 63, which will load
|
||||
// AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS in
|
||||
// this context, but we use cross-lane operations to access the data.
|
||||
// (Cross-lane operations are faster than using LDS.)
|
||||
//
|
||||
// Note that when the size of the AQ warp tile is smaller than the warp
|
||||
// size, you need to pad the rows in the view to ensure that each thread
|
||||
// can read one element.
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
constexpr uint32_t kTileRowsOfCPerThread = 4;
|
||||
if constexpr(Traits::TransposeC) // transposed C
|
||||
{
|
||||
static_assert(false,
|
||||
"It is not supported yet to enable both Preshuffle.");
|
||||
// TODO:
|
||||
// A new tile distribution is needed for the Preshuffle and
|
||||
// Transpose combination. For instance, with mnk at 16x16x32, lanes
|
||||
// 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ.
|
||||
}
|
||||
else
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of
|
||||
// the view is composed of a row from a warp tile within an AQ block
|
||||
// tile. Multiple warp tile rows that belong to the same block tile
|
||||
// are laid out as consecutive rows.
|
||||
//
|
||||
// When we need to multiply a C warp tile with an AQ warp tile,
|
||||
// thread 0 in the warp will load AQ_warp_tile[0], thread 1 will
|
||||
// load AQ_warp_tile[1], and so on, up to thread 63, which will load
|
||||
// AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS
|
||||
// in this context, but we use cross-lane operations to access the
|
||||
// data. (Cross-lane operations are faster than using LDS.)
|
||||
//
|
||||
// Note that when the size of the AQ warp tile is smaller than the
|
||||
// warp size, you need to pad the rows in the view to ensure that
|
||||
// each thread can read one element.
|
||||
constexpr auto tbuf_offset = number<
|
||||
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
constexpr uint32_t kTileRowsOfCPerThread = 4;
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
// For a warp tile of [16x16x32], take thread 0 as an example.
|
||||
// Its VGPR[0] stores the value from C_tile[0,0], VGPR[1] stores
|
||||
// C_tile[1,0], VGPR[2] stores C_tile[2,0], and VGPR[3] stores
|
||||
// C_tile[3,0]. This means VGPR[0] should be multiplied by
|
||||
// AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0], VGPR[2] by
|
||||
// AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
// For a warp tile of [16x16x32], take thread 0 as an
|
||||
// example. Its VGPR[0] stores the value from C_tile[0,0],
|
||||
// VGPR[1] stores C_tile[1,0], VGPR[2] stores C_tile[2,0],
|
||||
// and VGPR[3] stores C_tile[3,0]. This means VGPR[0] should
|
||||
// be multiplied by AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0],
|
||||
// VGPR[2] by AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
|
||||
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
|
||||
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else if constexpr(WarpGemm::kM == 32)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
((c_row >> 2) << 3) + (c_row & 0b11)) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
|
||||
}
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
|
||||
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else if constexpr(WarpGemm::kM == 32)
|
||||
{
|
||||
pull_from_lane =
|
||||
(__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
((c_row >> 2) << 3) + (c_row & 0b11)) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
|
||||
}
|
||||
auto& scale_reg =
|
||||
aq_block_tensor.get_thread_buffer()[mIter];
|
||||
|
||||
// cross lane ops
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword =
|
||||
ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
pull_from_lane << 2,
|
||||
__builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f =
|
||||
Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(Traits::TransposeC) // transposed C
|
||||
{
|
||||
constexpr index_t reg_offset = mIter * Traits::AQPerBlock + kQScale;
|
||||
constexpr auto tbuf_offset = number<
|
||||
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
// Need to multiply aquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on aquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales
|
||||
// corresponding to the 16 elements held by lane 0 are held by lanes
|
||||
// 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
|
||||
// respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
// MIters per warp
|
||||
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
|
||||
|
||||
// Reg block offset based on mIter
|
||||
constexpr index_t reg_block_offset =
|
||||
((mIter / mIters_per_warp) * Traits::AQPerBlock);
|
||||
|
||||
constexpr index_t lane_base_offset =
|
||||
(mIter % mIters_per_warp) * WarpGemm::kM;
|
||||
|
||||
// Scale tensor offset along K
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
|
||||
constexpr auto tbuf_offset = number<
|
||||
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
// x CNLane
|
||||
constexpr uint32_t row_base =
|
||||
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
|
||||
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
|
||||
|
||||
constexpr uint32_t reg_offset_for_row_data =
|
||||
c_row / WarpGemm::kCMLane;
|
||||
|
||||
// Lane index to source scale from
|
||||
uint32_t src_lane_idx =
|
||||
lane_base_offset + row_base +
|
||||
(__lane_id() / WarpGemm::kN * kTileRows);
|
||||
|
||||
// Directly index into thread buffer corresponding to
|
||||
// desired row coefficient
|
||||
auto& scale_reg =
|
||||
aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
@@ -427,97 +546,19 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
// Pull scale data across lanes
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
pull_from_lane << 2,
|
||||
__builtin_bit_cast(int, scale_reg_dword));
|
||||
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
|
||||
kA_cvt_scale * kB_cvt_scale);
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset +
|
||||
reg_offset_for_row_data] +=
|
||||
(c_warp_tensor
|
||||
.get_thread_buffer()[reg_offset_for_row_data] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Need to multiply aquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on aquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales corresponding
|
||||
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
|
||||
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
// MIters per warp
|
||||
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
|
||||
|
||||
// Reg block offset based on mIter
|
||||
constexpr index_t reg_block_offset =
|
||||
((mIter / mIters_per_warp) * Traits::AQPerBlock);
|
||||
|
||||
constexpr index_t lane_base_offset =
|
||||
(mIter % mIters_per_warp) * WarpGemm::kM;
|
||||
|
||||
// Scale tensor offset along K
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
// x CNLane
|
||||
constexpr uint32_t row_base =
|
||||
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
|
||||
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
|
||||
|
||||
constexpr uint32_t reg_offset_for_row_data =
|
||||
c_row / WarpGemm::kCMLane;
|
||||
|
||||
// Lane index to source scale from
|
||||
uint32_t src_lane_idx = lane_base_offset + row_base +
|
||||
(__lane_id() / WarpGemm::kN * kTileRows);
|
||||
|
||||
// Directly index into thread buffer corresponding to
|
||||
// desired row coefficient
|
||||
auto& scale_reg =
|
||||
aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
// Pull scale data across lanes
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
c_block_tensor
|
||||
.get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
|
||||
(c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user