Post-merge fix of PR 1300 (#1313)

* add FP8 GEMM with multiple D tensors for both row-wise and column-wise layouts

* change compute_type to fp8

* change tuning parameters in the example

* add rcr example

* post-merge fix

* fix

* reduce init range
This commit is contained in:
zjing14
2024-06-01 00:46:41 -05:00
committed by GitHub
parent 34f3dfdd61
commit 6fb1f4e03f
3 changed files with 14 additions and 14 deletions

View File

@@ -59,7 +59,7 @@ struct MultiplyMultiply
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::bhalf_t>(x0_f);
e = ck::type_convert<ck::half_t>(x0_f);
}
};
@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideB = K;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
@@ -164,10 +164,10 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});