mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 08:50:09 +00:00
Improvements for: Groupwise scaling along M for FP8 gemm (#2095)
* fix blockwise fp8 kernels Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> * wip, < 128 not working Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> * fix < 128 Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> * reduce diff Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> * review comments Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> * support partial n blocks Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> * fix build errors Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> --------- Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -557,13 +557,13 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_m, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
|
||||
cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k)
|
||||
)
|
||||
);
|
||||
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_n, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
|
||||
cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k)
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
@@ -396,14 +396,17 @@ template <typename GroupScaleConfig>
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
using TileShape = typename GroupScaleConfig::TileShape;
|
||||
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
|
||||
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
|
||||
const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
|
||||
const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;
|
||||
|
||||
assert(options.m % ScaleGranularityM == 0);
|
||||
assert(options.n % ScaleGranularityN == 0);
|
||||
|
||||
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
|
||||
auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access.
|
||||
auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM;
|
||||
auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN;
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
@@ -575,6 +578,8 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile;
|
||||
const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile;
|
||||
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
@@ -582,6 +587,8 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM;
|
||||
auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN;
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
@@ -617,14 +624,14 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
|
||||
|
||||
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
|
||||
cute::make_shape(groupscale_m, blockscale_k, options.l),
|
||||
cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k)
|
||||
)
|
||||
);
|
||||
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
|
||||
cute::make_shape(groupscale_n, blockscale_k, options.l),
|
||||
cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k)
|
||||
)
|
||||
);
|
||||
|
||||
@@ -708,6 +715,31 @@ int run(Options<RasterOrderOptions> &options)
|
||||
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
|
||||
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
|
||||
|
||||
bool skip = false;
|
||||
|
||||
if (options.m % ScaleGranularityM != 0) {
|
||||
std::cout << "Skippig (m size: " << options.m << " less then ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
|
||||
skip = true;
|
||||
}
|
||||
|
||||
if (options.n % ScaleGranularityN != 0) {
|
||||
std::cout << "Skippig (n size: " << options.m << " less then ScaleGranularityN: " << ScaleGranularityM << "):" << std::endl;
|
||||
skip = true;
|
||||
}
|
||||
|
||||
if (options.k % size<2>(TileShape{}) != 0) {
|
||||
std::cout << "Skippig (k size: " << options.k << " less then TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
|
||||
skip = true;
|
||||
}
|
||||
|
||||
if (!skip) std::cout << "Running: " << std::endl;
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
|
||||
if (skip) return -1;
|
||||
|
||||
initialize<GroupScaleConfig>(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
@@ -768,10 +800,6 @@ int run(Options<RasterOrderOptions> &options)
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
|
||||
@@ -217,15 +217,19 @@ void gett_mainloop(
|
||||
}
|
||||
}
|
||||
|
||||
int64_t block_m = m / kBlockM;
|
||||
int64_t block_n = n / kBlockN;
|
||||
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
|
||||
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
|
||||
const int M = cute::size<0>(mainloop_params.A.layout());
|
||||
const int N = cute::size<0>(mainloop_params.B.layout());
|
||||
const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA);
|
||||
const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB);
|
||||
assert(ScaleGranularityM && M % ScaleGranularityM == 0
|
||||
&& "ScaleGranularityM must divide M");
|
||||
assert(ScaleGranularityN && N % ScaleGranularityN == 0
|
||||
&& "ScaleGranularityN must divide N");
|
||||
|
||||
const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
|
||||
const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
|
||||
assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
|
||||
assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
|
||||
cute::Tensor blockscale_A = domain_offset(
|
||||
make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l));
|
||||
cute::Tensor blockscale_B = domain_offset(
|
||||
make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l));
|
||||
|
||||
// Compute on this k-block
|
||||
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
||||
@@ -257,9 +261,12 @@ void gett_mainloop(
|
||||
}
|
||||
}
|
||||
|
||||
int m_size = std::min(static_cast<int64_t>(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m);
|
||||
int n_size = std::min(static_cast<int64_t>(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n);
|
||||
|
||||
// do compute
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
for (int m_b = 0; m_b < m_size; ++m_b) {
|
||||
for (int n_b = 0; n_b < n_size; ++n_b) {
|
||||
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
@@ -269,9 +276,9 @@ void gett_mainloop(
|
||||
// (b) Zero-out partial temporary (acc_temp),
|
||||
// (c) Update permanent (accu)
|
||||
if ((k+1) % kBlockK == 0) {
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int m_b = 0; m_b < m_size; ++m_b) {
|
||||
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
for (int n_b = 0; n_b < n_size; ++n_b) {
|
||||
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
|
||||
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
|
||||
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
|
||||
|
||||
@@ -118,8 +118,8 @@ struct CollectiveMma<
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// Two threads per CTA are producers (1 for operand tile and 1 for scales)
|
||||
static constexpr int NumProducerThreadEvents = 2;
|
||||
// Two threads per CTA are producers (1 for operand tile `tma`, and 32 for scales `cp.async`)
|
||||
static constexpr int NumProducerThreadEvents = 33;
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_;
|
||||
@@ -150,10 +150,9 @@ struct CollectiveMma<
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
// Block scaling gmem-to-smem copy atom
|
||||
using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
|
||||
using BlockScaleCopyTypeB = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>;
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeA>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeB>, ElementBlockScale>;
|
||||
// we can have partial tiles in M or N, so don't vectorize those loads
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
|
||||
// Block scaling smem layout
|
||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
@@ -217,7 +216,6 @@ struct CollectiveMma<
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
// Block scaling factors for A and B
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
@@ -263,7 +261,6 @@ struct CollectiveMma<
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.mma_promotion_interval,
|
||||
args.ptr_scale_A,
|
||||
args.ptr_scale_B
|
||||
};
|
||||
@@ -283,11 +280,15 @@ struct CollectiveMma<
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
|
||||
constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{});
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval);
|
||||
implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval);
|
||||
|
||||
// We expect full tiles in K
|
||||
implementable = implementable && (K % size<2>(TileShape{}) == 0);
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
@@ -331,18 +332,18 @@ struct CollectiveMma<
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
auto tK = get<3>(gA_mkl.shape());
|
||||
|
||||
// Make the tiled views of scale tensors
|
||||
auto scaleA_shape = make_shape(shape<2>(gA_mkl), Int<ScaleMsPerTile>{}, shape<3>(gA_mkl), shape<4>(gA_mkl)); // (m,ScaleMsPerTile,k,l)
|
||||
auto scaleB_shape = make_shape(shape<2>(gB_nkl), Int<ScaleNsPerTile>{}, shape<3>(gB_nkl), shape<4>(gB_nkl)); // (n,ScaleNsPerTile,k,l)
|
||||
auto scale_dA = compact_order(scaleA_shape, Step<_2,_0,_1,_3>{});
|
||||
auto scale_dB = compact_order(scaleB_shape, Step<_2,_0,_1,_3>{});
|
||||
auto scaleA_layout = make_layout(scaleA_shape, scale_dA);
|
||||
auto scaleB_layout = make_layout(scaleB_shape, scale_dB);
|
||||
auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
|
||||
auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
|
||||
auto scaleB_shape = make_shape(N / ScaleGranularityN, tK, L); // (scale_n,k,l)
|
||||
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_0, _1, _2>{});
|
||||
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,ScaleNsPerTile,k,l)
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (scale_n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
||||
}
|
||||
@@ -367,103 +368,134 @@ struct CollectiveMma<
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k)
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k)
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
|
||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
auto scales_m = get<0>(mScaleA_mkl.shape());
|
||||
auto scales_n = get<0>(mScaleB_nkl.shape());
|
||||
|
||||
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
||||
Tensor cScaleB_nkl = make_identity_tensor(mScaleB_nkl.shape());
|
||||
|
||||
Tensor gScaleA = local_tile(
|
||||
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
||||
Tensor cScaleA = local_tile(
|
||||
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord));
|
||||
Tensor gScaleB = local_tile(
|
||||
mScaleB_nkl, make_tile(Int<ScaleNsPerTile>{}),
|
||||
make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1)
|
||||
Tensor cScaleB = local_tile(
|
||||
cScaleB_nkl, make_tile(Int<ScaleNsPerTile>{}),
|
||||
make_coord(n_coord,_,l_coord));
|
||||
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32>>{}, Layout<Shape<_1>>{});
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_32>>{}, Layout<Shape<_1>>{});
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
||||
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
||||
|
||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
||||
Tensor tBcB_ScaleB = thr_scale_copy_b.partition_S(cScaleB);
|
||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
||||
Tensor tBpB_ScaleB = make_tensor<bool>(shape(tBsB_ScaleB(_,_,0)));
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
||||
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) <
|
||||
std::min(scales_m, (m_coord + 1) * ScaleMsPerTile);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tBpB_ScaleB); ++i) {
|
||||
tBpB_ScaleB(i) = get<0>(tBcB_ScaleB(i)) <
|
||||
std::min(scales_n, (n_coord + 1) * ScaleNsPerTile);
|
||||
}
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
int write_stage = smem_pipe_write.index();
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
// Copy operands A and B from global memory to shared memory
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy_if(scale_copy_b, tBpB_ScaleB, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
++k_tile_iter;
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
|
||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
|
||||
Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1)
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,_,l_coord); // (1,ScaleNsPerTile,k,1)
|
||||
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleMsPerTile>>>{}); // (1,ScaleMsPerTile,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleNsPerTile>>>{}); // (1,ScaleNsPerTile,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
||||
|
||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
int write_stage = smem_pipe_write.index();
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
// Copy operands A and B from global memory to shared memory
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user