Revert "[BlockScale GEMM] FP8 Blockscale GEMM optimization and ckProfiler (#1913)" (#1933)

This reverts commit 06e1eee9bbb737f0ee3b2374f1857838c2b8ef3f.

[ROCm/composable_kernel commit: ef16010273]
This commit is contained in:
asleepzzz
2025-03-03 23:17:39 +08:00
committed by GitHub
parent a81cf05757
commit de1a2544c6
18 changed files with 663 additions and 1018 deletions

View File

@@ -55,7 +55,7 @@ using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_M = 128;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
16, 128,
256, 16, 16,
128, 128,
128, 16, 16,
16, 16,
1, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 16, 1, 16>, S<8>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
int main(int argc, char* argv[])
@@ -80,12 +80,11 @@ int main(int argc, char* argv[])
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool flush_cache = true;
// GEMM shape
ck::index_t M = 128;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
@@ -101,7 +100,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 8)
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -111,19 +110,16 @@ int main(int argc, char* argv[])
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
flush_cache = std::stoi(argv[7]);
StrideA = K;
StrideB = K;
StrideE = N;
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: M, N, K\n");
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0);
}
@@ -186,15 +182,9 @@ int main(int argc, char* argv[])
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 4:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
default:
@@ -204,16 +194,6 @@ int main(int argc, char* argv[])
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
}
#endif
#if 0
for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){
float row_sum = .0;
for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){
printf("%lf ",a1_m_k(im, ik));
row_sum += a1_m_k(im, ik);
}
printf("sum: %lf\n", row_sum * 128);
}
#endif
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
@@ -259,24 +239,12 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float ave_time = .0;
if(flush_cache)
{
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
ave_time = invoker.Run(argument,
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
}
else
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;