mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK][Examples] Extending support for rdna3/4 part 2:
-example_batched_gemm_xdl_int8 -example_batched_gemm_xdl_fp8_rowwise_v3 -example_batched_gemm_xdl_fp32 -example_batched_gemm_xdl_bf16 -example_batched_gemm_xdl_bf16_v3 -example_batched_gemm_xdl_fp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 *fixing return value to return 0 as success in above examples. Fixing cmdline parameters in: -example_sparse_embedding3_forward_layernorm -example_elementwise_binary_4D_fp16 -elementwise_scale_permute_amax_2D_fp16_fp8 Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
This commit is contained in:
committed by
Michał Kulikowski
parent
1d4db30af9
commit
7259b9c4db
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
@@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
|
||||
@@ -68,10 +68,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -89,11 +89,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
S<8>, // CDEShuffleBlockTransferScalarPerVectors
|
||||
S<4>, // CDEShuffleBlockTransferScalarPerVectors
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler
|
||||
ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion
|
||||
>;
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
@@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
@@ -50,9 +52,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
|
||||
@@ -74,10 +74,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
64, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -95,7 +95,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors
|
||||
S<4, 4, 1>, // CDEShuffleBlockTransferScalarPerVectors
|
||||
ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler
|
||||
ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion
|
||||
F8 // ComputeTypeA
|
||||
@@ -103,4 +103,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
|
||||
#include "run_batched_gemm_example_rowwise.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_rowwise_example(argc, argv); }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
@@ -96,4 +98,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
#define BUILD_INT4_EXAMPLE
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
@@ -48,9 +50,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <random>
|
||||
|
||||
#pragma once
|
||||
@@ -243,7 +245,7 @@ bool run_batched_gemm_example(int argc, char* argv[])
|
||||
problem_size.N,
|
||||
problem_size.K,
|
||||
problem_size.batch_count);
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
problem_size.stride_A = problem_size.K;
|
||||
|
||||
@@ -346,7 +346,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
@@ -566,7 +566,7 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
problem_size.stride_A = problem_size.K;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <random>
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -77,17 +77,21 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// Use default value
|
||||
}
|
||||
else if(argc == 4)
|
||||
else if(argc == 5)
|
||||
{
|
||||
num_rows = atoi(argv[1]);
|
||||
dim_mask = strtol(argv[2], nullptr, 0);
|
||||
index_length = atoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[1]);
|
||||
num_rows = std::stoi(argv[2]);
|
||||
dim_mask = strtol(argv[3], nullptr, 0);
|
||||
index_length = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1-3: num_rows dim_mask index_length" << std::endl;
|
||||
std::cout << "arg1: time kernel (0=no, 1=yes)" << std::endl;
|
||||
std::cout << "arg2-4: num_rows dim_mask index_length" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ck::static_for<0, dims.Size(), 1>{}([&](auto I) {
|
||||
if(dim_mask & (1 << I.value))
|
||||
{
|
||||
@@ -160,11 +164,10 @@ int main(int argc, char* argv[])
|
||||
<< std::endl
|
||||
<< std::flush;
|
||||
|
||||
bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get());
|
||||
|
||||
if(!is_supported)
|
||||
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << "Runtime parameters are not supported" << std::endl;
|
||||
std::cerr << device_instance.GetTypeString() << " does not support this problem"
|
||||
<< std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -56,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
|
||||
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
// clang-format on
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
@@ -51,6 +51,8 @@ int main(int argc, char* argv[])
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
@@ -60,30 +62,21 @@ int main(int argc, char* argv[])
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
nchw[0] = std::stoi(argv[3]);
|
||||
nchw[1] = std::stoi(argv[4]);
|
||||
nchw[2] = std::stoi(argv[5]);
|
||||
nchw[3] = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 5)
|
||||
{
|
||||
nchw[0] = std::stoi(argv[1]);
|
||||
nchw[1] = std::stoi(argv[2]);
|
||||
nchw[2] = std::stoi(argv[3]);
|
||||
nchw[3] = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1 to 4: N, C, H, W" << std::endl;
|
||||
|
||||
return 1;
|
||||
printf("arg3-6: N, C, H, W (default 16, 128, 32, 64)\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
@@ -119,6 +119,11 @@ int main(int argc, char* argv[])
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
const float scale = 2.f;
|
||||
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
@@ -128,22 +133,19 @@ int main(int argc, char* argv[])
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
}
|
||||
else if(argc == 5)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
M = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const float scale = 2.f;
|
||||
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
if(argc == 3)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
K = std::stoi(argv[2]);
|
||||
printf("arg3-4: M(default=1024), K(default=1024)\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::array<ck::index_t, 2> dims = {M, K};
|
||||
|
||||
Reference in New Issue
Block a user