mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Fix for direct loads on a non-last dimension
This commit is contained in:
@@ -140,9 +140,9 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
"Direct load transfer does not support datatypes conversion. Source and "
|
||||
"destination data types must be the same.");
|
||||
|
||||
static_assert(
|
||||
DstVectorDim == nDim - 1,
|
||||
"Direct load transfer requires the destination vector dimension to be the last one.");
|
||||
// static_assert(
|
||||
// DstVectorDim == nDim - 1,
|
||||
// "Direct load transfer requires the destination vector dimension to be the last one.");
|
||||
|
||||
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
|
||||
"When loading more than one element per thread at once, the contiguous "
|
||||
|
||||
@@ -933,12 +933,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
#endif
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(DirectLoad)
|
||||
if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
}
|
||||
else if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// bank conflict when writing the data into LDS, but don't worry, we have whole entire
|
||||
@@ -1633,7 +1639,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcVectorDim, // enforced earlier
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
|
||||
|
||||
Reference in New Issue
Block a user