Improvements for: Groupwise scaling along M for FP8 gemm (#2095)

* fix blockwise fp8 kernels

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* wip, < 128 not working

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* fix < 128

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* reduce diff

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* review comments

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* support partial n blocks

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* fix build errors

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Lucas Wilkinson
2025-02-27 22:39:29 -05:00
committed by GitHub
parent ca4fdbea70
commit df18f5e4f5
4 changed files with 198 additions and 131 deletions

View File

@@ -557,13 +557,13 @@ bool verify(const Options<RasterOrderOptions> &options) {
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
cute::make_shape(blockscale_m, blockscale_k, options.l),
- cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
+ cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
cute::make_shape(blockscale_n, blockscale_k, options.l),
- cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
+ cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k)
)
);
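
The new strides make the per-group scales contiguous along M (for A) and N (for B) within each K block, instead of contiguous along K. Below is a minimal plain-C++ sketch of the offset arithmetic those strides imply; the helper name and parameters are illustrative, not part of the example code.

```cpp
#include <cassert>
#include <cstdint>

// Offset of the scale for M-group `m_group`, K-block `k_block`, batch `batch`
// under strides (1, blockscale_m, blockscale_m * blockscale_k).
int64_t scale_offset_A(int64_t m_group, int64_t k_block, int64_t batch,
                       int64_t blockscale_m, int64_t blockscale_k) {
  assert(m_group < blockscale_m && k_block < blockscale_k);
  return m_group                                  // M-groups are contiguous
       + k_block * blockscale_m                   // next K-block starts after all M-groups
       + batch * blockscale_m * blockscale_k;     // batch stride
}
```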

View File

@@ -396,14 +396,17 @@ template <typename GroupScaleConfig>
void initialize(const Options<RasterOrderOptions> &options) {
using TileShape = typename GroupScaleConfig::TileShape;
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
+ const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
+ const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;
+ assert(options.m % ScaleGranularityM == 0);
+ assert(options.n % ScaleGranularityN == 0);
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
- auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
- auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access.
+ auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM;
+ auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN;
auto blockscale_k = cute::get<2>(blockscale_shape);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
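
With this change the host-side scale tensors are sized exactly from the problem shape and the scale granularity rather than padded out to whole CTA tiles, which is why the new asserts require M and N to be divisible by the granularities. A small sketch contrasting the two computations; the function names and example constants are illustrative only.

```cpp
#include <cassert>

// Old sizing: pad to whole tiles so edge CTAs never read out of bounds.
int groupscale_m_padded(int m, int tile_m, int scale_ms_per_tile) {
  int blockscale_m = (m + tile_m - 1) / tile_m;  // number of M tiles (ceil-div)
  return blockscale_m * scale_ms_per_tile;
}

// New sizing: one scale per ScaleGranularityM rows, no padding, divisibility required.
int groupscale_m_exact(int m, int scale_granularity_m) {
  assert(m % scale_granularity_m == 0);
  return m / scale_granularity_m;
}
```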
@@ -575,6 +578,8 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
//
// Compute reference output
//
+ const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile;
+ const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile;
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
@@ -582,6 +587,8 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
+ auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM;
+ auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN;
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
@@ -617,14 +624,14 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
- cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
- cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
+ cute::make_shape(groupscale_m, blockscale_k, options.l),
+ cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
- cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
- cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
+ cute::make_shape(groupscale_n, blockscale_k, options.l),
+ cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k)
)
);
@@ -708,6 +715,31 @@ int run(Options<RasterOrderOptions> &options)
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
+ bool skip = false;
+ if (options.m % ScaleGranularityM != 0) {
+ std::cout << "Skipping (m size: " << options.m << " not divisible by ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
+ skip = true;
+ }
+ if (options.n % ScaleGranularityN != 0) {
+ std::cout << "Skipping (n size: " << options.n << " not divisible by ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
+ skip = true;
+ }
+ if (options.k % size<2>(TileShape{}) != 0) {
+ std::cout << "Skipping (k size: " << options.k << " not divisible by TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
+ skip = true;
+ }
+ if (!skip) std::cout << "Running: " << std::endl;
+ std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
+ std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
+ std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
+ std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
+ if (skip) return -1;
initialize<GroupScaleConfig>(options);
// Instantiate CUTLASS kernel depending on templates
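
The added checks simply reject problem shapes the kernel cannot currently handle. A standalone sketch of the same predicate, assuming a hypothetical helper name:

```cpp
#include <iostream>

bool shape_is_supported(int m, int n, int k,
                        int scale_granularity_m, int scale_granularity_n, int tile_k) {
  bool ok = true;
  if (m % scale_granularity_m != 0) { std::cout << "m not divisible by ScaleGranularityM\n"; ok = false; }
  if (n % scale_granularity_n != 0) { std::cout << "n not divisible by ScaleGranularityN\n"; ok = false; }
  if (k % tile_k != 0)              { std::cout << "k not divisible by TileShape_K\n";       ok = false; }
  return ok;
}
```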
@@ -768,10 +800,6 @@ int run(Options<RasterOrderOptions> &options)
raster = "Along M";
}
- std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
- std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
- std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
- std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;

View File

@@ -217,15 +217,19 @@ void gett_mainloop(
}
}
- int64_t block_m = m / kBlockM;
- int64_t block_n = n / kBlockN;
- cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
- cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
+ const int M = cute::size<0>(mainloop_params.A.layout());
+ const int N = cute::size<0>(mainloop_params.B.layout());
+ const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA);
+ const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB);
+ assert(ScaleGranularityM && M % ScaleGranularityM == 0
+        && "ScaleGranularityM must divide M");
+ assert(ScaleGranularityN && N % ScaleGranularityN == 0
+        && "ScaleGranularityN must divide N");
- const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
- const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
- assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
- assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
+ cute::Tensor blockscale_A = domain_offset(
+     make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l));
+ cute::Tensor blockscale_B = domain_offset(
+     make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l));
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
@@ -257,9 +261,12 @@ void gett_mainloop(
}
}
+ int m_size = std::min(static_cast<int64_t>(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m);
+ int n_size = std::min(static_cast<int64_t>(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n);
// do compute
- for (int m_b = 0; m_b < kBlockM; ++m_b) {
- for (int n_b = 0; n_b < kBlockN; ++n_b) {
+ for (int m_b = 0; m_b < m_size; ++m_b) {
+ for (int n_b = 0; n_b < n_size; ++n_b) {
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
}
}
@@ -269,9 +276,9 @@ void gett_mainloop(
// (b) Zero-out partial temporary (acc_temp),
// (c) Update permanent (accu)
if ((k+1) % kBlockK == 0) {
- for (int m_b = 0; m_b < kBlockM; ++m_b) {
+ for (int m_b = 0; m_b < m_size; ++m_b) {
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
- for (int n_b = 0; n_b < kBlockN; ++n_b) {
+ for (int n_b = 0; n_b < n_size; ++n_b) {
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
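
The reference mainloop now clamps the per-tile loops to m_size/n_size so edge tiles in M and N are handled, and still folds the groupwise scales into the accumulator once per K block. Below is a plain-C++ sketch of that step, with illustrative containers and names (the actual code operates on CuTe fragments):

```cpp
#include <vector>

// One k-step of the reference tile computation: (a) unscaled FMA into a temporary
// accumulator over the valid sub-tile only, then (b) at the end of a K block,
// apply the per-group scales, promote into the permanent accumulator, and reset.
void scaled_tile_step(const std::vector<float>& a_frag,   // kBlockM values
                      const std::vector<float>& b_frag,   // kBlockN values
                      const std::vector<float>& scale_a,  // per-M-group scales
                      const std::vector<float>& scale_b,  // per-N-group scales
                      std::vector<std::vector<float>>& acc_temp,
                      std::vector<std::vector<float>>& acc,
                      int m_size, int n_size,
                      int ScaleGranularityM, int ScaleGranularityN,
                      bool promote_now) {
  for (int m_b = 0; m_b < m_size; ++m_b)
    for (int n_b = 0; n_b < n_size; ++n_b)
      acc_temp[m_b][n_b] += a_frag[m_b] * b_frag[n_b];

  if (promote_now) {
    for (int m_b = 0; m_b < m_size; ++m_b) {
      float sa = scale_a[m_b / ScaleGranularityM];
      for (int n_b = 0; n_b < n_size; ++n_b) {
        float sb = scale_b[n_b / ScaleGranularityN];
        acc[m_b][n_b] += acc_temp[m_b][n_b] * sa * sb;
        acc_temp[m_b][n_b] = 0.f;   // zero out the partial temporary
      }
    }
  }
}
```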

View File

@@ -118,8 +118,8 @@ struct CollectiveMma<
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
- // Two threads per CTA are producers (1 for operand tile and 1 for scales)
- static constexpr int NumProducerThreadEvents = 2;
+ // 33 threads per CTA are producers (1 thread for the operand tiles via `tma`, and 32 threads for the scales via `cp.async`)
+ static constexpr int NumProducerThreadEvents = 33;
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_;
@@ -150,10 +150,9 @@ struct CollectiveMma<
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
// Block scaling gmem-to-smem copy atom
- using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
- using BlockScaleCopyTypeB = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>;
- using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeA>, ElementBlockScale>;
- using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeB>, ElementBlockScale>;
+ // we can have partial tiles in M or N, so don't vectorize those loads
+ using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
+ using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
// Block scaling smem layout
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
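
Dropping the uint_byte_t vector types means each scale element is moved by its own cp.async, which is what allows per-element predication when a tile only partially covers M or N. A rough plain-C++ stand-in for that behaviour (names are illustrative):

```cpp
// Copy `scale_ms_per_tile` scales element by element; only the first `valid`
// entries exist when the CTA's tile is partial in M.
void copy_scales_predicated(const float* gmem_scales, float* smem_scales,
                            int scale_ms_per_tile, int valid) {
  for (int i = 0; i < scale_ms_per_tile; ++i) {
    if (i < valid) smem_scales[i] = gmem_scales[i];  // predicate true: copy one element
    // predicate false: skipped (as with copy_if); that scale group is never consumed
  }
}
```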
@@ -217,7 +216,6 @@ struct CollectiveMma<
uint32_t tma_transaction_bytes = TmaTransactionBytes;
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
- uint32_t mma_promotion_interval = 4;
// Block scaling factors for A and B
ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_B;
@@ -263,7 +261,6 @@ struct CollectiveMma<
transaction_bytes,
transaction_bytes_mk,
transaction_bytes_nk,
- args.mma_promotion_interval,
args.ptr_scale_A,
args.ptr_scale_B
};
@@ -283,11 +280,15 @@ struct CollectiveMma<
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
- implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval);
+ constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{});
+ implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval);
+ // We expect full tiles in K
+ implementable = implementable && (K % size<2>(TileShape{}) == 0);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
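
The new check derives the number of MMAs issued per mainloop iteration from the tile and MMA shapes instead of trusting the argument directly. A compile-time sketch with assumed example values:

```cpp
// Assumed example values: a 128-deep K tile and an MMA covering 32 elements of K.
constexpr int tile_k = 128;              // stand-in for size<2>(TileShape{})
constexpr int mma_k  = 32;               // stand-in for tile_size<2>(TiledMma{})
constexpr int pipe_k = tile_k / mma_k;   // MMAs issued per mainloop iteration
constexpr int mma_promotion_interval = 4;
static_assert(pipe_k % 4 == 0 && pipe_k <= mma_promotion_interval,
              "mainloop must issue a multiple of 4 MMAs, at most one promotion interval");
```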
@@ -331,18 +332,18 @@ struct CollectiveMma<
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
+ auto tK = get<3>(gA_mkl.shape());
// Make the tiled views of scale tensors
- auto scaleA_shape = make_shape(shape<2>(gA_mkl), Int<ScaleMsPerTile>{}, shape<3>(gA_mkl), shape<4>(gA_mkl)); // (m,ScaleMsPerTile,k,l)
- auto scaleB_shape = make_shape(shape<2>(gB_nkl), Int<ScaleNsPerTile>{}, shape<3>(gB_nkl), shape<4>(gB_nkl)); // (n,ScaleNsPerTile,k,l)
- auto scale_dA = compact_order(scaleA_shape, Step<_2,_0,_1,_3>{});
- auto scale_dB = compact_order(scaleB_shape, Step<_2,_0,_1,_3>{});
- auto scaleA_layout = make_layout(scaleA_shape, scale_dA);
- auto scaleB_layout = make_layout(scaleB_shape, scale_dB);
+ auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
+ auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
+ auto scaleB_shape = make_shape(N / ScaleGranularityN, tK, L); // (scale_n,k,l)
+ auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_0, _1, _2>{});
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
- Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l)
- Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,ScaleNsPerTile,k,l)
+ Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
+ Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (scale_n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
}
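
The ordered layout makes the scale index along M the fastest-varying mode, so each CTA's ScaleMsPerTile groups are consecutive in memory for every K tile. A plain-C++ sketch of how a CTA could locate its slice under that layout (the helper name and raw-pointer form are illustrative):

```cpp
#include <cstdint>

const float* cta_scale_A_slice(const float* scale_A,
                               int64_t scales_m, int64_t k_tiles,
                               int64_t m_coord, int64_t k_tile, int64_t l,
                               int64_t scale_ms_per_tile) {
  int64_t first_group = m_coord * scale_ms_per_tile;  // first M-group of this CTA tile
  return scale_A + first_group
                 + k_tile * scales_m                   // stride of the K-tile mode
                 + l * scales_m * k_tiles;             // stride of the batch mode
}
```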
@@ -367,103 +368,134 @@ struct CollectiveMma<
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k)
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k)
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
Tensor mScaleA_mkl = get<2>(load_inputs);
Tensor mScaleB_nkl = get<3>(load_inputs);
auto scales_m = get<0>(mScaleA_mkl.shape());
auto scales_n = get<0>(mScaleB_nkl.shape());
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
Tensor cScaleB_nkl = make_identity_tensor(mScaleB_nkl.shape());
Tensor gScaleA = local_tile(
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
Tensor cScaleA = local_tile(
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
make_coord(m_coord,_,l_coord));
Tensor gScaleB = local_tile(
mScaleB_nkl, make_tile(Int<ScaleNsPerTile>{}),
make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1)
Tensor cScaleB = local_tile(
cScaleB_nkl, make_tile(Int<ScaleNsPerTile>{}),
make_coord(n_coord,_,l_coord));
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
Layout<Shape<_32>>{}, Layout<Shape<_1>>{});
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
Layout<Shape<_32>>{}, Layout<Shape<_1>>{});
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
Tensor tBcB_ScaleB = thr_scale_copy_b.partition_S(cScaleB);
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
Tensor tBpB_ScaleB = make_tensor<bool>(shape(tBsB_ScaleB(_,_,0)));
#pragma unroll
for (int i = 0; i < size(tApA_ScaleA); ++i) {
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) <
std::min(scales_m, (m_coord + 1) * ScaleMsPerTile);
}
#pragma unroll
for (int i = 0; i < size(tBpB_ScaleB); ++i) {
tBpB_ScaleB(i) = get<0>(tBcB_ScaleB(i)) <
std::min(scales_n, (n_coord + 1) * ScaleNsPerTile);
}
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Prepare the TMA loads for A and B
// Copy gmem to smem for *k_tile_iter
//
int write_stage = smem_pipe_write.index();
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
// Copy operands A and B from global memory to shared memory
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
// Copy scale tensors from global memory to shared memory
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
copy_if(scale_copy_b, tBpB_ScaleB, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
++k_tile_iter;
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
Tensor mScaleA_mkl = get<2>(load_inputs);
Tensor mScaleB_nkl = get<3>(load_inputs);
Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1)
Tensor gScaleB = mScaleB_nkl(n_coord,_,_,l_coord); // (1,ScaleNsPerTile,k,1)
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleMsPerTile>>>{}); // (1,ScaleMsPerTile,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleNsPerTile>>>{}); // (1,ScaleNsPerTile,1)
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
int write_stage = smem_pipe_write.index();
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
// Copy operands A and B from global memory to shared memory
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
// Copy scale tensors from global memory to shared memory
copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
// Advance smem_pipe_write
++smem_pipe_write;
}
}
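
Taken together, the load path now has one thread issuing the TMA copies and 32 threads issuing predicated cp.async copies for the scales, which is what NumProducerThreadEvents = 33 above accounts for. A plain-C++ stand-in for the scale side of one k-tile iteration; the 32-thread loop only simulates the thread-to-element mapping, and all names are illustrative:

```cpp
#include <algorithm>
#include <vector>

void load_scales_for_tile(const std::vector<float>& gScaleA,  // this k-tile's scale slice
                          std::vector<float>& sScaleA,        // smem staging buffer
                          int scale_ms_per_tile,
                          int first_group, int scales_m) {
  const int kThreads = 32;                         // Layout<Shape<_32>> thread tile
  int valid = std::min<int>(scales_m, first_group + scale_ms_per_tile) - first_group;
  for (int tid = 0; tid < kThreads; ++tid) {       // stand-in for threadIdx.x
    for (int i = tid; i < scale_ms_per_tile; i += kThreads) {
      if (i < valid) sScaleA[i] = gScaleA[i];      // copy_if: predicate true
      // predicate false: element skipped; that scale group is never consumed
    }
  }
}
```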