tune fp8 example

This commit is contained in:
lalala-sh
2025-05-06 08:46:38 +00:00
parent 0ab978584d
commit abff33eaab

View File

@@ -157,12 +157,15 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 256;
static constexpr ck::index_t MXDLPerWave = 16;
static constexpr ck::index_t NXDLPerWave = 4;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 256;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
@@ -190,7 +193,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
// clang-format on
@@ -308,7 +311,7 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.1, 0.1});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});