mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Clean up pipeline
This commit is contained in:
@@ -69,12 +69,6 @@ CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t type_convert<bf16_t, bf8_t>(bf8_t x)
|
||||
{
|
||||
return float_to_bf16(bf8_to_float(x));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
|
||||
@@ -239,7 +239,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
BQRegBlockTile& bq_block_tensor,
|
||||
const BQRegBlockTile& bq_block_tensor,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
|
||||
@@ -27,9 +27,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
using BDqDataType = std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
|
||||
remove_cvref_t<typename Problem::ADataType>,
|
||||
BDataType>;
|
||||
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
|
||||
using BLDSType = std::conditional_t<IsCastBeforeLDS, BDqDataType, BDataType>;
|
||||
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
@@ -43,17 +46,16 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BDqPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDqDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BLDSPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BLDSType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -90,8 +92,6 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -175,6 +175,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
static constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
static constexpr bool is_b_row_major =
|
||||
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
@@ -217,7 +222,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16
|
||||
B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
@@ -233,7 +238,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16 ? 8 : 4;
|
||||
B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
@@ -316,6 +321,139 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TileType, typename CastTileType, typename ScaleTileType>
|
||||
CK_TILE_DEVICE static void
|
||||
ScaleTile(TileType& block_tile, CastTileType& block_tile_cast, ScaleTileType& scale_tile)
|
||||
{
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
{
|
||||
constexpr auto b_block = TileType::get_distributed_spans();
|
||||
constexpr auto idx1_js = tile_distributed_index<0>{};
|
||||
|
||||
// Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4
|
||||
// on gfx950
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<BDqDataType, half_t>)
|
||||
{
|
||||
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDqDataType, bf16_t>)
|
||||
{
|
||||
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
|
||||
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
|
||||
auto scale = scale_tile(i_j_idx_scale);
|
||||
auto b_scale_uint = uint32_t(scale.data) << 23;
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
if constexpr(idx1.impl_.at(0) % BPackedSize == 0)
|
||||
{
|
||||
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0)>{};
|
||||
constexpr auto idx1_hi =
|
||||
tile_distributed_index<idx1.impl_.at(0) + 1>{};
|
||||
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
|
||||
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
|
||||
auto b_pack = block_tile(i_j_idx);
|
||||
auto cvt =
|
||||
pk_mxfp4_to_compute_v2(b_pack, bit_cast<float>(b_scale_uint));
|
||||
block_tile_cast(i_j_idx_lo) = cvt.x;
|
||||
block_tile_cast(i_j_idx_hi) = cvt.y;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto b_pack = block_tile(i_j_idx);
|
||||
block_tile_cast(i_j_idx) = type_convert<BDqDataType>(
|
||||
type_convert<float>(b_pack) * bit_cast<float>(b_scale_uint));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WindowType, typename TileType, typename ElementwiseFunc>
|
||||
CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window,
|
||||
const TileType& block_tile,
|
||||
const ElementwiseFunc& element_func) const
|
||||
{
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, block_tile);
|
||||
Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(lds_window, block_tile, element_func);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WindowType,
|
||||
typename TileType,
|
||||
typename TileTypeCast,
|
||||
typename ElementwiseFunc>
|
||||
CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window,
|
||||
const TileType& block_tile,
|
||||
const TileTypeCast& block_tile_cast,
|
||||
const ElementwiseFunc& element_func) const
|
||||
{
|
||||
// Fill LDS and apply the scale if IsCastBeforeLDS
|
||||
auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) {
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
{
|
||||
return b_block_tile_cast;
|
||||
}
|
||||
else
|
||||
{
|
||||
return b_block_tile_orig;
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BLDSType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast));
|
||||
Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
lds_window, get_b_block_tile(block_tile, block_tile_cast), element_func);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemmType,
|
||||
typename AWindowType,
|
||||
typename BWindowType,
|
||||
typename QTileType>
|
||||
CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm,
|
||||
const AWindowType& a_lds_window,
|
||||
const BWindowType& b_lds_window,
|
||||
const QTileType& q_block_tile) const
|
||||
{
|
||||
// Load from LDS
|
||||
// It can apply the scale and cast if we scale after reading from LDS
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
{
|
||||
block_gemm.LocalPrefetch(a_lds_window, b_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm.LocalPrefetch(a_lds_window, b_lds_window, q_block_tile);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -331,6 +469,8 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Pipeline checks
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -340,11 +480,8 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
"A/B/BQ Dram block window should have the same data type as appropriate "
|
||||
"([A|B|BQ]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_bq_col_major =
|
||||
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
@@ -393,6 +530,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){};
|
||||
|
||||
// This defines the scaled and casted block tile for B matrix.
|
||||
// Effectively, it is used only if we scale and cast before writing to LDS.
|
||||
auto bdq_block_tile = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeBRegTileDistribution<Problem>());
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
@@ -405,7 +547,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_fp4_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
@@ -419,137 +561,44 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){};
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
// BDataType
|
||||
auto b_block_tile = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeBRegTileDistribution<Problem>());
|
||||
// prefetch stages
|
||||
|
||||
// Vmem -> Vgpr 0
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Vmem -> Vgpr 0 (Q matrix)
|
||||
// Scale and cast tile before writing to LDS (if IsCastBeforeLDS)
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
|
||||
|
||||
constexpr auto idx1_js = tile_distributed_index<0>{};
|
||||
constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans();
|
||||
|
||||
// Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 on
|
||||
// gfx950
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<BDqDataType, half_t>)
|
||||
{
|
||||
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDqDataType, bf16_t>)
|
||||
{
|
||||
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
|
||||
auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) {
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
{
|
||||
return b_block_tile_cast;
|
||||
}
|
||||
else
|
||||
{
|
||||
return b_block_tile_orig;
|
||||
}
|
||||
};
|
||||
|
||||
auto apply_scale_func = [&]() {
|
||||
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
|
||||
auto scale = bq_block_tile(i_j_idx_scale);
|
||||
auto b_scale_uint = uint32_t(scale.data) << 23;
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
if constexpr(idx1.impl_.at(0) % BPackedSize == 0)
|
||||
{
|
||||
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0)>{};
|
||||
constexpr auto idx1_hi =
|
||||
tile_distributed_index<idx1.impl_.at(0) + 1>{};
|
||||
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
|
||||
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
|
||||
auto b_pack = b_fp4_block_tile(i_j_idx);
|
||||
|
||||
auto cvt =
|
||||
pk_mxfp4_to_compute_v2(b_pack, bit_cast<float>(b_scale_uint));
|
||||
b_block_tile(i_j_idx_lo) = cvt.x;
|
||||
b_block_tile(i_j_idx_hi) = cvt.y;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto b_pack = b_fp4_block_tile(i_j_idx);
|
||||
b_block_tile(i_j_idx) = type_convert<BDqDataType>(
|
||||
type_convert<float>(b_pack) * bit_cast<float>(b_scale_uint));
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
apply_scale_func();
|
||||
|
||||
// initialize C
|
||||
// initialize C tile to zero
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile_);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func);
|
||||
}
|
||||
// Vgpr -> LDS 0
|
||||
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
|
||||
|
||||
// Vmem -> Vgpr 1
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// If we scale and cast before writing to LDS,
|
||||
// we need to read another tile of Q matrix from Vmem, then scale and cast tile
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
{
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
|
||||
apply_scale_func();
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_sync_lds();
|
||||
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS -> Vgpr 0
|
||||
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
@@ -560,58 +609,34 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile_);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func);
|
||||
}
|
||||
// Vgpr -> LDS
|
||||
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
|
||||
|
||||
// Vmem -> Vgpr
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Vmem -> Vgpr (Q matrix)
|
||||
// Scale and cast tile before writing to LDS (if IsCastBeforeLDS)
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
|
||||
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
apply_scale_func();
|
||||
|
||||
// Consume tile
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
else
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
|
||||
// LDS -> Vgpr
|
||||
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
// b_block_stride +=1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile);
|
||||
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
{
|
||||
@@ -621,50 +646,31 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
}
|
||||
else
|
||||
{
|
||||
// If we scale and cast after reading from LDS,
|
||||
// we didn't read the second tile of Q matrix from Vmem during prefetch stages,
|
||||
// so we need to read the last tile here.
|
||||
// This is not a problem because we have all block_gemm instructions to hide the
|
||||
// latency.
|
||||
if constexpr(!IsCastBeforeLDS)
|
||||
{
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
}
|
||||
|
||||
// Consume second to last tile
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile_);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func);
|
||||
}
|
||||
// Vgpr -> LDS last tile
|
||||
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(IsCastBeforeLDS)
|
||||
{
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
|
||||
}
|
||||
|
||||
// LDS -> Vgpr last tile
|
||||
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
|
||||
|
||||
// Consume last tile
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -690,13 +696,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
using BElementwise = std::conditional_t<IsCastBeforeLDS, BDqDataType, BDataType>;
|
||||
ck_tile::ignore = n;
|
||||
ck_tile::ignore = n;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BElementwise& b) { return b; },
|
||||
[](const BLDSType& b) { return b; },
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
|
||||
Reference in New Issue
Block a user