debug save

This commit is contained in:
mtgu0705
2025-05-14 09:33:24 -05:00
parent 102151ebcf
commit efdd420742

View File

@@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 16;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -163,14 +163,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, 128,
ScaleBlockSize, 64,
MPerBlock, 16, 128,
32, 32,
16, 16,
8, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>,
1, 1,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
// clang-format on
#endif
@@ -408,12 +408,12 @@ int main(int argc, char* argv[])
if(k % 2 == 0)
{
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%f ", ck::type_convert<float>(f4));
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%f ", ck::type_convert<float>(f4));
printf("%.2f ", ck::type_convert<float>(f4));
}
}
printf("\n");
@@ -428,7 +428,7 @@ int main(int argc, char* argv[])
{
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
printf("%f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
printf("%.2f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
}
printf("\n");
}
@@ -446,12 +446,12 @@ int main(int argc, char* argv[])
if(k % 2 == 0)
{
ck::f4_t f4 = f4x2 >> 4 & 0xf;
printf("%f ", ck::type_convert<float>(f4));
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = f4x2 >> 0 & 0xf;
printf("%f ", ck::type_convert<float>(f4));
printf("%.2f ", ck::type_convert<float>(f4));
}
}
printf("\n");
@@ -466,7 +466,7 @@ int main(int argc, char* argv[])
{
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
printf("%f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
printf("%.2f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
}
printf("\n");
}
@@ -598,7 +598,7 @@ int main(int argc, char* argv[])
{
for(int n = 0; n < N; ++n)
{
printf("%f ", ck::type_convert<float>(e_t_n_device_result(t, n)));
printf("%.2f ", ck::type_convert<float>(e_t_n_device_result(t, n)));
}
printf("\n");
}
@@ -608,7 +608,7 @@ int main(int argc, char* argv[])
{
for(int n = 0; n < N; ++n)
{
printf("%f ", ck::type_convert<float>(e_t_n_host_result(t, n)));
printf("%.2f ", ck::type_convert<float>(e_t_n_host_result(t, n)));
}
printf("\n");
}