fix for directloads on non last dim

This commit is contained in:
Jakub Piasecki
2026-01-26 11:38:23 +00:00
parent 391e06e070
commit 6fda7ab9bb
2 changed files with 11 additions and 5 deletions

View File

@@ -140,9 +140,9 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
// static_assert(
// DstVectorDim == nDim - 1,
// "Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "

View File

@@ -933,12 +933,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
#endif
// B matrix in LDS memory, dst of blockwise copy
if constexpr(DirectLoad)
if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 2)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
}
else if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 1)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
}
else if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// bank conflict when writting the data into LDS, but don't worry, we have whole entire
@@ -1633,7 +1639,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcVectorDim, // enforcer earlier
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),