really padding N for B matrix

This commit is contained in:
letaoqin
2025-03-10 10:52:59 +00:00
parent e7f8544bcd
commit 61172ea4bb
4 changed files with 37 additions and 13 deletions

View File

@@ -129,7 +129,16 @@ void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl
}
}
}
int GetPreShufflePadded(int K) { return (K + ShufflePadded - 1) / ShufflePadded * ShufflePadded; }
int GetKPreShufflePadded(int K)
{
return (K + KShufflePadded - 1) / KShufflePadded * KShufflePadded;
}
int GetNPreShufflePadded(int N)
{
return (N + NShufflePadded - 1) / NShufflePadded * NShufflePadded;
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
@@ -268,12 +277,14 @@ int main(int argc, char* argv[])
}
};
auto Knew = GetPreShufflePadded(K);
auto Knew = GetKPreShufflePadded(K);
auto StrideBnew = Knew;
auto Nnew = GetNPreShufflePadded(N);
std::cout << "Knew: " << Knew << " Nnew: " << Nnew << std::endl;
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B0DataType> b0_preshuffled(
f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use laout only for size
f_host_tensor_descriptor(Knew, Nnew, StrideBnew, B0Layout{})); // use laout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
@@ -289,8 +300,8 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;

View File

@@ -146,8 +146,8 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
virtual int GetPreShuffleParameters() = 0;
};
#define ShufflePadded 256
#define KShufflePadded 256
#define NShufflePadded 128
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -214,7 +214,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto CalculateBKShufflePadded(index_t K)
{
return (K + ShufflePadded - 1) / ShufflePadded * ShufflePadded;
return (K + KShufflePadded - 1) / KShufflePadded * KShufflePadded;
}
__host__ __device__ static auto CalculateBNShufflePadded(index_t N)
{
return (N + NShufflePadded - 1) / NShufflePadded * NShufflePadded;
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -604,7 +609,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)},
BN0Shuffled{CalculateBN0Shuffled((N + 128 - 1) / 128 * 128)},
BN0Shuffled{CalculateBN0Shuffled(CalculateBNShufflePadded(N_))},
BK0Shuffled{CalculateBK0Shuffled(CalculateBKShufflePadded(K_))}
{
}

View File

@@ -57,7 +57,14 @@ void preShuffleBuffer(
}
}
int GetPreShufflePadded(int K) { return (K + ShufflePadded - 1) / ShufflePadded * ShufflePadded; }
int GetKPreShufflePadded(int K)
{
return (K + KShufflePadded - 1) / KShufflePadded * KShufflePadded;
}
int GetNPreShufflePadded(int N)
{
return (N + NShufflePadded - 1) / NShufflePadded * NShufflePadded;
}
template <typename ADataType,
typename BDataType,
@@ -103,14 +110,15 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto Knew = GetPreShufflePadded(K);
auto Knew = GetKPreShufflePadded(K);
auto StrideBnew = Knew;
auto Nnew = GetNPreShufflePadded(N);
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_preshuffled_mfma16(
f_host_tensor_descriptor(Knew, N, StrideBnew, BLayout{})); // use layout only for size
f_host_tensor_descriptor(Knew, Nnew, StrideBnew, BLayout{})); // use layout only for size
Tensor<BDataType> b_preshuffled_mfma32(
f_host_tensor_descriptor(Knew, N, StrideBnew, BLayout{})); // use layout only for size
f_host_tensor_descriptor(Knew, Nnew, StrideBnew, BLayout{})); // use layout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));