mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)
[CK_TILE] Extend support of mix precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported types combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: consistent with other implementations and avoid taking into account of the packed size explicitly. In general, the raw type should not be used because CK Tile internally takes care of the PackedSize, so using the raw type adds unnecessary complexity to the implementation - handle microscaling by checking for `e8m0` type for BQuant (previous implementation was inconsistent) - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend existing pipeline to support different B types - add support to scale and cast before writing to LDS or after reading from LDS (this can be defined in the `Problem` by the user) - block gemm: - mx pipeline is now using block gemm BQuant - block gemm BQuant can now load from LDS and apply scale and then call block gemm universal operator. This adds new functionalities and remove code duplication - warp gemm: - add case to support 128bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: note that some tests for bf16/fp4 already existed but were removed during previous tests refactoring. I added them again and other relevant tests for new types combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
3af1a0aafc
commit
4c626aeaa6
@@ -20,8 +20,23 @@ struct GemmPipelineAgBgCrImplBase
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>, ADataType, BInDataType>;
|
||||
|
||||
template <typename T>
|
||||
using has_bcastpolicy_type = decltype(T::BCastPolicy);
|
||||
|
||||
static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
|
||||
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
|
||||
{
|
||||
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}();
|
||||
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite, ADataType, BInDataType>;
|
||||
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
@@ -226,6 +241,12 @@ struct GemmPipelineAgBgCrImplBase
|
||||
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
{
|
||||
// with pk_int4_t load transpose the LDS type is always BDataType
|
||||
using ADataTypeLDS =
|
||||
std::conditional_t<std::is_same_v<typename Problem::ADataType, pk_int4_t>,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::ADataType>;
|
||||
|
||||
auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
@@ -238,9 +259,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_lds_load_tile_distr = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename ALdsLoadTileDistr::DstrEncode,
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
typename InputTileDistributionTraits<typename ALdsLoadTileDistr::DstrEncode,
|
||||
ADataTypeLDS>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsLoadTileDistr{};
|
||||
}();
|
||||
@@ -313,10 +333,9 @@ struct GemmPipelineAgBgCrImplBase
|
||||
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
using BLdsDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
using BLdsDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
auto b_lds_load_tile_distr = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
|
||||
@@ -10,6 +10,12 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct CastPolicy
|
||||
{
|
||||
BeforeLDSWrite,
|
||||
AfterLDSRead,
|
||||
};
|
||||
|
||||
enum struct GemmPipelineScheduler
|
||||
{
|
||||
Default,
|
||||
|
||||
@@ -80,6 +80,21 @@ struct UniversalGemmBasePolicy
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using has_bcastpolicy_type = decltype(T::BCastPolicy);
|
||||
|
||||
template <typename Problem>
|
||||
static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
|
||||
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
|
||||
{
|
||||
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
@@ -305,11 +320,11 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -589,15 +604,14 @@ struct UniversalGemmBasePolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
using BDataType = std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
@@ -739,13 +753,13 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t KPerBlock = std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
|
||||
? Problem::BlockGemmShape::kK / 2
|
||||
: Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// If we cast before writing to LDS, the vectorsize is defined by the A type
|
||||
// since the assumption is that A type is going to be the B LDS type
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
constexpr index_t VecLoadSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
|
||||
? 4
|
||||
IsBCastPolicyBeforeLDSWrite
|
||||
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
|
||||
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
using BLayout = remove_cvref_t<
|
||||
@@ -855,10 +869,10 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
|
||||
@@ -900,7 +914,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_raw_t>,
|
||||
std::is_same_v<BDataType, pk_fp4_t> ||
|
||||
sizeof(BDataType) < sizeof(ADataType),
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user