mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)
[CK_TILE] Extend support of mix precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported types combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: consistent with other implementations and avoid taking into account of the packed size explicitly. In general, the raw type should not be used because CK Tile internally takes care of the PackedSize, so using the raw type adds unnecessary complexity to the implementation - handle microscaling by checking for `e8m0` type for BQuant (previous implementation was inconsistent) - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend existing pipeline to support different B types - add support to scale and cast before writing to LDS or after reading from LDS (this can be defined in the `Problem` by the user) - block gemm: - mx pipeline is now using block gemm BQuant - block gemm BQuant can now load from LDS and apply scale and then call block gemm universal operator. This adds new functionalities and remove code duplication - warp gemm: - add case to support 128bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: note that some tests for bf16/fp4 already existed but were removed during previous tests refactoring. I added them again and other relevant tests for new types combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
3af1a0aafc
commit
4c626aeaa6
@@ -185,16 +185,35 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaItera
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
AttrNumAccessA>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
4,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
||||
|
||||
@@ -17,13 +17,47 @@ enum class WGAttrNumAccessEnum
|
||||
Invalid = -1
|
||||
};
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess>
|
||||
struct get_wgattr_num_access
|
||||
{
|
||||
private:
|
||||
static constexpr index_t getAccesses()
|
||||
{
|
||||
if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Single)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Double)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Quad)
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "unsupported AttrNumAccess");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto value = getAccesses();
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
|
||||
struct WarpGemmAttributeMfma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
|
||||
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
|
||||
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
|
||||
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
|
||||
|
||||
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
@@ -44,12 +78,13 @@ struct WarpGemmAttributeMfma
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
template <index_t kMNLane>
|
||||
template <index_t kMNLane, index_t AttrNumAccessV_>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
static_assert(kKPerThread % AttrNumAccessV_ == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
if constexpr(AttrNumAccessV_ == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
@@ -57,18 +92,48 @@ struct WarpGemmAttributeMfma
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
{
|
||||
// AttrNumAccess splits the kABKPerLane
|
||||
// We can split them but still have them contiguous (packed) or have them interleaved.
|
||||
// The reason to split the dimension but still have it packed is to match load transpose
|
||||
// encoding when A and B use different AttrNumAccess (they have different types in LDS)
|
||||
// Example
|
||||
// A: 16bit, B: 8bit
|
||||
// Load transpose B: lane0 -> K=0..7 (only 1 instruction)
|
||||
// Load transpose A: lane0 -> K=0..3 first instruction, K=4..7 second instruction
|
||||
// In this way the data in register are consistent between A and B
|
||||
if constexpr(UsePackNumAccess)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<Impl::kABKLane,
|
||||
AttrNumAccessV_,
|
||||
Impl::kABKPerLane / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<1, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV_,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane, AttrNumAccessAV>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane, AttrNumAccessBV>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -121,14 +186,19 @@ struct WarpGemmAttributeMfma
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
|
||||
struct WarpGemmAttributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
|
||||
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
|
||||
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
|
||||
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
|
||||
|
||||
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
@@ -151,14 +221,15 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
|
||||
"Multi-block on both M & N directions is not supported");
|
||||
|
||||
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock>
|
||||
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock, index_t AttrNumAccessV_>
|
||||
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(kMNBlock == 1 && kNMBlock == 1)
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
static_assert(kKPerThread % AttrNumAccessV_ == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
if constexpr(AttrNumAccessV_ == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
@@ -166,21 +237,40 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
{
|
||||
if constexpr(UsePackNumAccess)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<Impl::kABKLane,
|
||||
AttrNumAccessV_,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<1, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV_,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
static_assert(AttrNumAccessV_ == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each M/N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
@@ -193,7 +283,7 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(1 < kMNBlock && kNMBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
static_assert(AttrNumAccessV_ == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
@@ -245,10 +335,14 @@ struct WarpGemmAttributeMfmaIterateK
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
decltype(get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock>());
|
||||
using BWarpDstrEncoding =
|
||||
decltype(get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock>());
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane,
|
||||
Impl::kAMBlock,
|
||||
Impl::kBNBlock,
|
||||
AttrNumAccessAV>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane,
|
||||
Impl::kBNBlock,
|
||||
Impl::kAMBlock,
|
||||
AttrNumAccessBV>());
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
|
||||
@@ -24,9 +24,10 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = ESingle>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = ESingle,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
struct Dispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -78,6 +79,10 @@ template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble, ESingle>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad, ESingle>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<EDouble>; };
|
||||
@@ -166,9 +171,10 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
AType,
|
||||
BType,
|
||||
@@ -179,6 +185,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user