Add Memory pipeline for AQuant Block Scale GEMM (#2987)

* WIP: add memory pipeline boilerplate code that compiles and works for one block

* WIP: tail handling works for memory pipeline

* WIP: numerical errors appear to have gone away after adding block_sync_lds()

* fix: numerical error with memory pipeline by adding block_sync_lds() and new tail handler

* refactor: remove debug print statements and lints

* fix: remove redundant sync barriers

* chore: remove lint

* fix: remove unused code from tile handler and remove redundant block_sync_lds()

* fix: correct parent struct name for memory pipeline

* fix: remove static assert check from parent struct and add it to child struct because not all child structs need the static assert

* fix: defer block sync lds to just before prefill
This commit is contained in:
Aviral Goel
2025-10-08 20:22:30 -04:00
committed by GitHub
parent e29151b533
commit e99356dabc
6 changed files with 489 additions and 12 deletions

View File

@@ -21,8 +21,6 @@ struct BaseGemmPipelineAgBgCrMem
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
@@ -174,7 +172,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;