mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)
[CK_TILE] Extend support of mix precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported types combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: consistent with other implementations and avoid taking into account of the packed size explicitly. In general, the raw type should not be used because CK Tile internally takes care of the PackedSize, so using the raw type adds unnecessary complexity to the implementation - handle microscaling by checking for `e8m0` type for BQuant (previous implementation was inconsistent) - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend existing pipeline to support different B types - add support to scale and cast before writing to LDS or after reading from LDS (this can be defined in the `Problem` by the user) - block gemm: - mx pipeline is now using block gemm BQuant - block gemm BQuant can now load from LDS and apply scale and then call block gemm universal operator. This adds new functionalities and remove code duplication - warp gemm: - add case to support 128bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: note that some tests for bf16/fp4 already existed but were removed during previous tests refactoring. I added them again and other relevant tests for new types combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
3af1a0aafc
commit
4c626aeaa6
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -101,20 +102,33 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
// 2. bf8, bf8, fp32 -> f32
|
||||
// 3. i4, fp8, (fp8/fp32) -> f32
|
||||
// 4. i4, bf8, (fp8/fp32) -> f32
|
||||
static_assert((std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>) &&
|
||||
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
|
||||
(std::is_same_v<BQDataType, float> ||
|
||||
std::is_same_v<BQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>) &&
|
||||
std::is_same_v<CDataType, fp32_t>);
|
||||
// 5. bf16, (bf16/bf8/fp8/fp4), e8m0 -> f32
|
||||
// 6. fp16, (fp16/fp8/bf8/fp4), e8m0 -> f32
|
||||
static_assert(
|
||||
is_any_of<ADataType, fp8_t, bf8_t, bf16_t, fp16_t>::value &&
|
||||
is_any_of<BDataType, fp8_t, bf8_t, pk_int4_t, bf16_t, pk_fp4_t, fp16_t>::value &&
|
||||
is_any_of<BQDataType, float, fp8_t, bf8_t, e8m0_t>::value &&
|
||||
is_any_of<ComputeDataType, fp8_t, bf8_t, bf16_t, fp16_t>::value &&
|
||||
std::is_same_v<CDataType, fp32_t>);
|
||||
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
template <typename T>
|
||||
using has_bcastpolicy_type = decltype(T::BCastPolicy);
|
||||
|
||||
static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
|
||||
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
|
||||
{
|
||||
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -127,9 +141,12 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
// OverrideBDataType is only used when BCastPolicy is CastBeforeLDSWrite for microscale.
|
||||
// In that case we use ADataType
|
||||
using OverrideBDataType = std::conditional_t<
|
||||
std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
|
||||
(std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>) ||
|
||||
Traits::IsBCastPolicyBeforeLDSWrite,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
@@ -176,57 +193,17 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
// Use gemm universal block distribution encoding instead of duplicating it
|
||||
using BlockGemmBase = BlockUniversalGemmAsBsCr<Problem_, Policy_, UnaryOpSize_>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
return BlockGemmBase::MakeABlockDistributionEncode();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
return BlockGemmBase::MakeBBlockDistributionEncode();
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -235,20 +212,24 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
{
|
||||
};
|
||||
|
||||
using BlockGemmImplBase = typename BlockUniversalGemmAsBsCr<Problem_, Policy_, UnaryOpSize_>::
|
||||
template BlockGemmImpl<GemmPipelineScheduler::Intrawave, Traits>;
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits> : public BlockGemmImplBase
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BlockGemmImplBase::a_warp_tile_;
|
||||
using BlockGemmImplBase::b_warp_tile_;
|
||||
using BlockGemmImplBase::BLdsTileDistr;
|
||||
// If we apply scale while reading from LDS, then we can use the operator() from
|
||||
// BlockUniversalGemmAsBsCr
|
||||
using BlockGemmImplBase::operator();
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
// static distributed tensor with LDS type
|
||||
using BTypeTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
BTypeTile b_warp_tile_lds_;
|
||||
|
||||
// Load from LDS (assumption is that the scale will be applied in the block gemm)
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
@@ -265,6 +246,107 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// Load from LDS and scale (then the tile can directly be consumed in the block gemm)
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
typename BQRegBlockTile,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
const BQRegBlockTile& bq_block_tensor,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
// Load tile from LDS
|
||||
|
||||
// Do not use load_int4_tile here because it will have support to cast from fp4 to
|
||||
// compute type, while here we want to only load from LDS and then apply the scale
|
||||
// and cast later
|
||||
if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
|
||||
if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_lds_ = load_tile_transpose(b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_lds_, b_block_window);
|
||||
}
|
||||
|
||||
// Apply scale and cast
|
||||
using BDataTypeRaw =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_fp4_t>, pk_fp4_t::type, BDataType>;
|
||||
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size;
|
||||
constexpr index_t thread_buffer_size = nelements / UnaryOpSize_;
|
||||
const element_wise::DequantPack8 elementwise_op{};
|
||||
using SrcVectorRawType = ext_vector_t<BDataTypeRaw, UnaryOpSize_ / BPackedSize>;
|
||||
using DstVectorType = ext_vector_t<ComputeDataType, UnaryOpSize_>;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
// B scale register offset
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
|
||||
return ((nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::BQuantGroupSize::kN) *
|
||||
Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
return nIter * Traits::KQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
// Get B scale from thread buffer
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_f = float(scale_reg);
|
||||
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
// Thread buffers
|
||||
using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
|
||||
using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
|
||||
|
||||
BWarpThreadBuffer b_warp_thread_buffer;
|
||||
BLDSThreadBuffer b_lds_thread_buffer;
|
||||
|
||||
// Load thread buffer from tile (LDS type)
|
||||
b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// Apply scale to B thread buffer and cast
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(
|
||||
b_warp_thread_buffer.template get_as<DstVectorType>()(i),
|
||||
b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
|
||||
b_scale_f);
|
||||
});
|
||||
|
||||
// Store B thread buffer to tile (MMA type)
|
||||
b_warp_tile_.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
|
||||
b_warp_thread_buffer);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename BQBlockTensor,
|
||||
@@ -400,6 +482,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
// Read A and B from LDS
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
@@ -412,7 +495,24 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// Read A and B from LDS and apply scale to B
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
typename BQRegBlockTile,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
BQRegBlockTile bq_block_tile,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(
|
||||
a_block_window, b_block_window, bq_block_tile, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
// Apply scale after MMA
|
||||
template <typename CBlockTensor,
|
||||
typename BQBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
@@ -425,6 +525,16 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
block_gemm_impl_(c_block_tensor, bq_block_tensor, a_block_window, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
// Scale has already been applied to B, so this is using the gemm universal block implementation
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
|
||||
}
|
||||
|
||||
private:
|
||||
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user