mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK TILE] Fix bugs in AQuant preshuffle (#2700)
* [CK TILE] Fix bugs in AQuant preshuffle
- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.
* [CK TILE] Rename Preshuffle to PreshuffleQuant
The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.
* [CK TILE Block Scale] Use GemmConfig to save tile properties
- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
coded tile properties in `gemm_calc_aquant()`
* [CK TILE Block Scale] Rename GemmConfig used in block scale
- Remove unused GemmConfig
- Rename GemmConfig used in block scale
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(Traits::Preshuffle)
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of the
|
||||
// view is composed of a row from a warp tile within an AQ block tile.
|
||||
@@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
|
||||
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
auto pull_from_lane =
|
||||
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else if constexpr(WarpGemm::kM == 32)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
((c_row >> 2) << 3) + (c_row & 0b11)) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
|
||||
}
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||
|
||||
// cross lane ops
|
||||
|
||||
Reference in New Issue
Block a user