Clean up pipeline

This commit is contained in:
Enrico Degregori
2026-01-28 17:15:05 +00:00
parent 33f4f876cf
commit afd3d2bd10
3 changed files with 208 additions and 209 deletions

View File

@@ -69,12 +69,6 @@ CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2)
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
#undef CK_TILE_TYPE_CONVERT
template <>
CK_TILE_HOST_DEVICE constexpr bf16_t type_convert<bf16_t, bf8_t>(bf8_t x)
{
return float_to_bf16(bf8_to_float(x));
}
} // namespace ck_tile
#include "ck_tile/core/numeric/pk_fp4.hpp"

View File

@@ -239,7 +239,7 @@ struct BQuantBlockUniversalGemmAsBsCr
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
BQRegBlockTile& bq_block_tensor,
const BQRegBlockTile& bq_block_tensor,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{

View File

@@ -27,9 +27,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
remove_cvref_t<typename Problem::ADataType>,
BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
using BLDSType = std::conditional_t<IsCastBeforeLDS, BDqDataType, BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
@@ -43,17 +46,16 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BDqPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDqDataType>>::PackedSize;
static constexpr index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
? 2
: ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
static constexpr index_t BLDSPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BLDSType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -90,8 +92,6 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -175,6 +175,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = PipelineImplBase;
static constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_b_row_major =
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
@@ -217,7 +222,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16
B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
@@ -233,7 +238,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16 ? 8 : 4;
B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
@@ -316,6 +321,139 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
});
}
template <typename TileType, typename CastTileType, typename ScaleTileType>
CK_TILE_DEVICE static void
ScaleTile(TileType& block_tile, CastTileType& block_tile_cast, ScaleTileType& scale_tile)
{
if constexpr(IsCastBeforeLDS)
{
constexpr auto b_block = TileType::get_distributed_spans();
constexpr auto idx1_js = tile_distributed_index<0>{};
// Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4
// on gfx950
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<BDqDataType, half_t>)
{
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
}
else if constexpr(std::is_same_v<BDqDataType, bf16_t>)
{
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
}
else
{
static_assert(false, "unsupported compute type");
}
};
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto scale = scale_tile(i_j_idx_scale);
auto b_scale_uint = uint32_t(scale.data) << 23;
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
if constexpr(idx1.impl_.at(0) % BPackedSize == 0)
{
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0)>{};
constexpr auto idx1_hi =
tile_distributed_index<idx1.impl_.at(0) + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = block_tile(i_j_idx);
auto cvt =
pk_mxfp4_to_compute_v2(b_pack, bit_cast<float>(b_scale_uint));
block_tile_cast(i_j_idx_lo) = cvt.x;
block_tile_cast(i_j_idx_hi) = cvt.y;
}
}
else
{
auto b_pack = block_tile(i_j_idx);
block_tile_cast(i_j_idx) = type_convert<BDqDataType>(
type_convert<float>(b_pack) * bit_cast<float>(b_scale_uint));
}
});
});
}
}
template <typename WindowType, typename TileType, typename ElementwiseFunc>
CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window,
const TileType& block_tile,
const ElementwiseFunc& element_func) const
{
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, block_tile);
Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func);
}
else
{
Base::LocalPrefill(lds_window, block_tile, element_func);
}
}
template <typename WindowType,
typename TileType,
typename TileTypeCast,
typename ElementwiseFunc>
CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window,
const TileType& block_tile,
const TileTypeCast& block_tile_cast,
const ElementwiseFunc& element_func) const
{
// Fill LDS and apply the scale if IsCastBeforeLDS
auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) {
if constexpr(IsCastBeforeLDS)
{
return b_block_tile_cast;
}
else
{
return b_block_tile_orig;
}
};
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BLDSType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast));
Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func);
}
else
{
Base::LocalPrefill(
lds_window, get_b_block_tile(block_tile, block_tile_cast), element_func);
}
}
template <typename BlockGemmType,
typename AWindowType,
typename BWindowType,
typename QTileType>
CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm,
const AWindowType& a_lds_window,
const BWindowType& b_lds_window,
const QTileType& q_block_tile) const
{
// Load from LDS
// It can apply the scale and cast if we scale after reading from LDS
if constexpr(IsCastBeforeLDS)
{
block_gemm.LocalPrefetch(a_lds_window, b_lds_window);
}
else
{
block_gemm.LocalPrefetch(a_lds_window, b_lds_window, q_block_tile);
}
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -331,6 +469,8 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
index_t num_loop,
void* p_smem) const
{
// -----------------------------------------------------------------------------------------
// Pipeline checks
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -340,11 +480,8 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
"A/B/BQ Dram block window should have the same data type as appropriate "
"([A|B|BQ]DataType) defined in Problem definition!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -393,6 +530,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){};
// This defines the scaled and casted block tile for B matrix.
// Effectively, it is used only if we scale and cast before writing to LDS.
auto bdq_block_tile = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeBRegTileDistribution<Problem>());
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
@@ -405,7 +547,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_fp4_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
@@ -419,137 +561,44 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
// auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){};
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// BDataType
auto b_block_tile = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeBRegTileDistribution<Problem>());
// prefetch stages
// Vmem -> Vgpr 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Vmem -> Vgpr 0 (Q matrix)
// Scale and cast tile before writing to LDS (if IsCastBeforeLDS)
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
constexpr auto idx1_js = tile_distributed_index<0>{};
constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans();
// Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 on
// gfx950
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<BDqDataType, half_t>)
{
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
}
else if constexpr(std::is_same_v<BDqDataType, bf16_t>)
{
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
}
else
{
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
}
};
auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) {
if constexpr(IsCastBeforeLDS)
{
return b_block_tile_cast;
}
else
{
return b_block_tile_orig;
}
};
auto apply_scale_func = [&]() {
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto scale = bq_block_tile(i_j_idx_scale);
auto b_scale_uint = uint32_t(scale.data) << 23;
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
if constexpr(idx1.impl_.at(0) % BPackedSize == 0)
{
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0)>{};
constexpr auto idx1_hi =
tile_distributed_index<idx1.impl_.at(0) + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = b_fp4_block_tile(i_j_idx);
auto cvt =
pk_mxfp4_to_compute_v2(b_pack, bit_cast<float>(b_scale_uint));
b_block_tile(i_j_idx_lo) = cvt.x;
b_block_tile(i_j_idx_hi) = cvt.y;
}
}
else
{
auto b_pack = b_fp4_block_tile(i_j_idx);
b_block_tile(i_j_idx) = type_convert<BDqDataType>(
type_convert<float>(b_pack) * bit_cast<float>(b_scale_uint));
}
});
});
};
if constexpr(IsCastBeforeLDS)
apply_scale_func();
// initialize C
// initialize C tile to zero
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
transpose_tile2d(b_shuffle_tmp, b_block_tile_);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func);
}
// Vgpr -> LDS 0
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
// Vmem -> Vgpr 1
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// If we scale and cast before writing to LDS,
// we need to read another tile of Q matrix from Vmem, then scale and cast tile
if constexpr(IsCastBeforeLDS)
{
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
apply_scale_func();
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_sync_lds();
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
block_sync_lds();
// LDS -> Vgpr 0
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
}
__builtin_amdgcn_sched_barrier(0);
// main body
@@ -560,58 +609,34 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
transpose_tile2d(b_shuffle_tmp, b_block_tile_);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func);
}
// Vgpr -> LDS
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
// Vmem -> Vgpr
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Vmem -> Vgpr (Q matrix)
// Scale and cast tile before writing to LDS (if IsCastBeforeLDS)
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
if constexpr(IsCastBeforeLDS)
apply_scale_func();
// Consume tile
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(IsCastBeforeLDS)
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
else
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
// LDS -> Vgpr
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
// b_block_stride +=1;
} while(i < (num_loop - 1));
}
// tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile);
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
@@ -621,50 +646,31 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
}
else
{
// If we scale and cast after reading from LDS,
// we didn't read the second tile of Q matrix from Vmem during prefetch stages,
// so we need to read the last tile here.
// This is not a problem because we have all block_gemm instructions to hide the
// latency.
if constexpr(!IsCastBeforeLDS)
{
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
}
// Consume second to last tile
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
transpose_tile2d(b_shuffle_tmp, b_block_tile_);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func);
}
// Vgpr -> LDS last tile
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
block_sync_lds();
if constexpr(IsCastBeforeLDS)
{
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
}
// LDS -> Vgpr last tile
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
// Consume last tile
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
}
@@ -690,13 +696,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
void* p_smem,
index_t n = 0) const
{
using BElementwise = std::conditional_t<IsCastBeforeLDS, BDqDataType, BDataType>;
ck_tile::ignore = n;
ck_tile::ignore = n;
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BElementwise& b) { return b; },
[](const BLDSType& b) { return b; },
bq_dram_block_window_tmp,
num_loop,
p_smem);