No need to specify SrcDataType in load_and_convert_tile as WarpWindow knows its DataType

This commit is contained in:
Sami Aario
2025-12-16 10:24:23 +00:00
parent 514035e6cf
commit 3d55a1e682
14 changed files with 75 additions and 88 deletions

View File

@@ -28,15 +28,14 @@ struct ConverterLoader
}
};
template <typename SrcDataType,
typename DstDataType,
template <typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
if constexpr(std::is_same_v<typename WarpWindow::Base::DataType, pk_int4_t>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
ConverterLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);

View File

@@ -228,10 +228,10 @@ struct BlockUniversalGemmAsBsCr
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
load_and_convert_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_and_convert_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_and_convert_tile<BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -294,10 +294,10 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_and_convert_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_and_convert_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_and_convert_tile<BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B
@@ -425,10 +425,10 @@ struct BlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
load_and_convert_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_lds_gemm_window);
load_and_convert_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_lds_gemm_window);
load_and_convert_tile<ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_and_convert_tile<BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
}
// C += A * B

View File

@@ -64,8 +64,7 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename SrcDataType,
typename DstDataType,
template <typename DstDataType,
index_t UnaryOpSize = 8,
typename DstBlockTile,
typename SrcTileWindow,
@@ -74,8 +73,7 @@ struct GemmPipelineAgBgCrImplBase
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_and_convert_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile,
dram_tile_window);
load_and_convert_tile<DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}

View File

@@ -627,7 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// // Prefetch A0
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::template GlobalPrefetch<BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
// Prefill A0
@@ -652,7 +652,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
do
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::template GlobalPrefetch<BTypeToUse, UnaryOpSize_>(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
Base::GlobalPrefetch(
@@ -666,7 +666,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
HotLoopScheduler();
}
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::template GlobalPrefetch<BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
Base::GlobalPrefetch(
@@ -687,7 +687,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
if constexpr(TailNum == TailNumber::Even)
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::template GlobalPrefetch<BTypeToUse, UnaryOpSize_>(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
block_weight_preshuffle(

View File

@@ -261,10 +261,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_and_convert_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
load_and_convert_tile<ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_and_convert_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
load_and_convert_tile<ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}

View File

@@ -248,10 +248,10 @@ struct AQuantBlockUniversalGemmAsBsCr
// while ADatatype might not be the same as BDataType at the time of problem
// initialization, we can safely use BDataType here because when A would be int4 we will
// ensure A is converted to BDataType prior to loading
load_and_convert_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_and_convert_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<ComputeDataType, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_and_convert_tile<ComputeDataType, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B

View File

@@ -258,11 +258,11 @@ struct BQuantBlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_and_convert_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_and_convert_tile<ComputeDataType, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_and_convert_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<ComputeDataType, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B

View File

@@ -198,10 +198,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
const ADramWindow& a_dram_window)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_and_convert_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile,
a_dram_window);
load_and_convert_tile<DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
}
template <typename BDramWindow, typename BBlockTile_>
@@ -209,10 +207,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
const BDramWindow& b_dram_window)
{
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_and_convert_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile,
b_dram_window);
load_and_convert_tile<DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <bool HasHotLoop,
@@ -347,9 +343,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(
Base::template GlobalPrefetch<AQDataType>(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
Base::template GlobalPrefetch<BQDataType, BQDataType>(
Base::template GlobalPrefetch<BQDataType>(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -436,10 +432,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(aq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<AQDataType>(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
Base::template GlobalPrefetch<BQDataType, BQDataType>(bq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<BQDataType>(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
@@ -471,10 +467,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
}
else
{
Base::template GlobalPrefetch<AQDataType, AQDataType>(aq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<AQDataType>(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
Base::template GlobalPrefetch<BQDataType, BQDataType>(bq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<BQDataType>(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
block_gemm(c_block_tile,

View File

@@ -175,10 +175,8 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
const DramTileWindowStep& dram_tile_window_step)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_and_convert_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile,
a_dram_window);
load_and_convert_tile<DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
move_tile_window(a_dram_window, dram_tile_window_step);
}
@@ -286,9 +284,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// Global prefetch initialization - DRAM to VGPRs
LoadAndConvertATile(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BDataType>(
Base::template GlobalPrefetch<BDataType>(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(
Base::template GlobalPrefetch<AQDataType>(
aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -321,10 +319,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
LoadAndConvertATile(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BDataType>(b_block_tiles.get(number<prefetch_idx>{}),
Base::template GlobalPrefetch<BDataType>(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(aq_block_tiles.get(number<prefetch_idx>{}),
Base::template GlobalPrefetch<AQDataType>(aq_block_tiles.get(number<prefetch_idx>{}),
aq_copy_dram_window,
aq_dram_tile_window_step);
});
@@ -381,10 +379,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
LoadAndConvertATile(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BDataType>(b_block_tiles.get(number<prefetch_idx>{}),
Base::template GlobalPrefetch<BDataType>(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(aq_block_tiles.get(number<prefetch_idx>{}),
Base::template GlobalPrefetch<AQDataType>(aq_block_tiles.get(number<prefetch_idx>{}),
aq_copy_dram_window,
aq_dram_tile_window_step);
});

View File

@@ -169,10 +169,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
const ADramWindow& a_dram_window)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_and_convert_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile,
a_dram_window);
load_and_convert_tile<DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
}
template <bool HasHotLoop,
@@ -277,8 +275,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// DRAM prefetch (global read 0)
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(
Base::template GlobalPrefetch<BDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType>(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -309,7 +307,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
@@ -352,8 +350,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType, AQDataType>(aq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<BDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<AQDataType>(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
@@ -379,7 +377,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
else
{
Base::template GlobalPrefetch<AQDataType, AQDataType>(aq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<AQDataType>(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
block_gemm(

View File

@@ -183,10 +183,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
const BDramWindow& b_dram_window)
{
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_and_convert_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile,
b_dram_window);
load_and_convert_tile<DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
@@ -202,7 +200,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
else
{
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<OverrideBDataType>(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
}
}
@@ -312,10 +310,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
: make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<OverrideBDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(
Base::template GlobalPrefetch<OverrideBDataType>(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -345,7 +343,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<OverrideBDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
@@ -389,10 +387,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<OverrideBDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(bq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<OverrideBDataType>(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
@@ -418,7 +416,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
else
{
Base::template GlobalPrefetch<ADataType, OverrideBDataType>(bq_block_tile[(currIdx + 1) % 2],
Base::template GlobalPrefetch<OverrideBDataType>(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
block_gemm(

View File

@@ -419,8 +419,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
// prefetch
// global read 0
// auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){};
Base::template GlobalPrefetch<ADataType, BDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<ADataType, BDataType>(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType>(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// BDataType
auto b_block_tile = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeBRegTileDistribution<Problem>());
@@ -480,8 +480,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::template GlobalPrefetch<ADataType, BDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<ADataType, BDataType>(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType>(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
@@ -544,8 +544,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::template GlobalPrefetch<ADataType, BDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<ADataType, BDataType>(
Base::template GlobalPrefetch<BDataType>(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType>(
b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
bq_block_tile = load_tile(bq_copy_dram_window);

View File

@@ -349,7 +349,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -430,7 +430,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -455,7 +455,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -503,7 +503,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});

View File

@@ -335,8 +335,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<ADataType, UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
@@ -421,7 +421,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -458,7 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -516,7 +516,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<BDataType, ADataType, UnaryOpSize_>(
load_and_convert_tile<ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});