mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
WMMA support for GEMM reduce (#2823)
Added gemm + reduce instance library for RDNA4. This includes:
- New device implementation running GEMM and reduction kernel
- instances for wmma (xdl parity)
- examples for wmma (xdl parity)
- tests for existing xdl and wmma
[ROCm/composable_kernel commit: b25d4d684a]
This commit is contained in:
committed by
GitHub
parent
14d52a943c
commit
5e10274417
@@ -27,3 +27,16 @@ add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_bf16 gemm_xdl_spli
|
||||
add_example_executable(example_gemm_xdl_splitk_reduce_bf16A_i8B gemm_xdl_splitk_reduce_bf16A_i8B.cpp)
|
||||
|
||||
add_example_executable(example_gemm_xdl_splitk_reduce_bfp16 gemm_xdl_splitk_reduce_bf16.cpp)
|
||||
|
||||
add_custom_target(example_splitK_gemm_wmma)
|
||||
add_example_executable(example_gemm_wmma_splitk_reduce_bf16 gemm_wmma_splitk_reduce_bf16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16)
|
||||
|
||||
add_example_executable(example_gemm_wmma_splitk_reduce_bf16A_i8B gemm_wmma_splitk_reduce_bf16A_i8B.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16A_i8B)
|
||||
|
||||
add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_bf16 gemm_wmma_splitk_reduce_multi_d_bf16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_bf16)
|
||||
|
||||
add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_fp16 gemm_wmma_splitk_reduce_multi_d_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_fp16)
|
||||
|
||||
@@ -99,3 +99,85 @@ bool parse_cmd_args(int argc,
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
59
example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp
Normal file
59
example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
using ReduceDataType = ck::bhalf_t;
|
||||
using D0DataType = ck::bhalf_t;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
using D0Layout = CLayout;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceWmmaGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
|
||||
ALayout, BLayout, DsLayout, CLayout,
|
||||
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128, 32,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_wmma_splitk_reduce_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); }
|
||||
59
example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp
Normal file
59
example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = int8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
using ReduceDataType = float;
|
||||
using D0DataType = ck::bhalf_t;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceWmmaGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
|
||||
ALayout, BLayout, DsLayout, CLayout,
|
||||
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128, 32,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_wmma_splitk_reduce_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); }
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
using ReduceDataType = float;
|
||||
using D0DataType = ck::bhalf_t;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
using D0Layout = CLayout;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
|
||||
ALayout, BLayout, DsLayout, CLayout,
|
||||
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128, 32,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); }
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using ReduceDataType = float;
|
||||
using D0DataType = ck::half_t;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
using D0Layout = CLayout;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
|
||||
ALayout, BLayout, DsLayout, CLayout,
|
||||
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmDefault,
|
||||
256,
|
||||
128, 256, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); }
|
||||
@@ -3,88 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
|
||||
191
example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc
Normal file
191
example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc
Normal file
@@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_wmma_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "init method: " << config.init_method << std::endl;
|
||||
std::cout << "KBatch: " << KBatch << std::endl;
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// device GEMM
|
||||
auto device_op = DeviceWmmaGemmInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
auto argument =
|
||||
device_op.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 0>{}, // empty D tensors
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 0>{}, // empty D strides
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
// Allocate workspace for split-K reduction if needed
|
||||
size_t workspace_size = device_op.GetWorkSpaceSize(argument.get());
|
||||
DeviceMem workspace_buf(workspace_size);
|
||||
std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl;
|
||||
if(workspace_size > 0)
|
||||
{
|
||||
argument->p_workspace_ = workspace_buf.GetDeviceBuffer();
|
||||
std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl;
|
||||
}
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "The runtime argument is not supported!" << std::endl;
|
||||
std::cout << "Debug info:" << std::endl;
|
||||
std::cout << " M=" << M << ", N=" << N << ", K=" << K << ", KBatch=" << KBatch
|
||||
<< std::endl;
|
||||
std::cout << " StrideA=" << StrideA << ", StrideB=" << StrideB << ", StrideC=" << StrideC
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
float ave_time = 0;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, false});
|
||||
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData,
|
||||
c_m_n_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E12 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E9 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_wmma_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config);
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ProblemSize>
|
||||
bool run_wmma_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto StrideD0 = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(
|
||||
f_host_tensor_descriptor(problem_size.M, problem_size.K, problem_size.StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(
|
||||
f_host_tensor_descriptor(problem_size.K, problem_size.N, problem_size.StrideB, BLayout{}));
|
||||
Tensor<D0DataType> d0_m_n(
|
||||
f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, D0Layout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(
|
||||
f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "init method: " << config.init_method << std::endl;
|
||||
std::cout << "KBatch: " << KBatch << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
constexpr auto kNum_DTensors = DsDataType::Size();
|
||||
const std::array<const void*, kNum_DTensors> p_ds = {d0_m_n_device_buf.GetDeviceBuffer()};
|
||||
const std::array<ck::index_t, kNum_DTensors> d_strides = {problem_size.StrideC};
|
||||
|
||||
auto argument =
|
||||
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
problem_size.M,
|
||||
problem_size.N,
|
||||
problem_size.K,
|
||||
problem_size.StrideA,
|
||||
problem_size.StrideB,
|
||||
d_strides,
|
||||
problem_size.StrideC,
|
||||
problem_size.KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto workspace_size = gemm.GetWorkSpaceSize(argument.get());
|
||||
DeviceMem workspace_device_buf(workspace_size);
|
||||
|
||||
std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl;
|
||||
std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl;
|
||||
|
||||
if(workspace_size > 0)
|
||||
{
|
||||
argument->p_workspace_ = workspace_device_buf.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstanceMultiD = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstanceMultiD{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
c_m_n_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { c_element_op(self(idx), self(idx), d0_m_n(idx)); });
|
||||
}
|
||||
|
||||
std::cout << "init method: " << config.init_method << std::endl;
|
||||
std::cout << "KBatch: " << problem_size.KBatch << std::endl;
|
||||
|
||||
float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * problem_size.M * problem_size.N * problem_size.K;
|
||||
std::size_t num_btype = sizeof(ADataType) * problem_size.M * problem_size.K +
|
||||
sizeof(BDataType) * problem_size.K * problem_size.N +
|
||||
sizeof(CDataType) * problem_size.M * problem_size.N +
|
||||
sizeof(D0DataType) * problem_size.M * problem_size.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
double rtol = get_rtol<CDataType>();
|
||||
double atol = get_atol<CDataType>();
|
||||
|
||||
return ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", rtol, atol);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int run_gemm_splitk_multi_d_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config);
|
||||
}
|
||||
@@ -129,5 +129,10 @@ inline bool is_gfx103_supported()
|
||||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
|
||||
}
|
||||
|
||||
inline bool is_wmma_supported()
|
||||
{
|
||||
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,562 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <typeinfo>
|
||||
#include <memory>
|
||||
#include <array>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ReduceDataType = CDataType,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
ReduceDataType,
|
||||
Tuple<>,
|
||||
ReduceDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
PassThrough,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false,
|
||||
false>;
|
||||
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
const ::std::array<const void*, NumDTensor> p_ds_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
const ::std::array<index_t, NumDTensor> stride_ds_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
::std::array<const void*, 0>{},
|
||||
reinterpret_cast<ReduceDataType*>(p_c_grid_),
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
std::array<index_t, 0>{},
|
||||
StrideC_,
|
||||
KBatch_,
|
||||
a_element_op_,
|
||||
b_element_op_,
|
||||
PassThrough{},
|
||||
true),
|
||||
p_c_grid(p_c_grid_),
|
||||
c_element_op(c_element_op_),
|
||||
p_ds(p_ds_),
|
||||
StrideDs(stride_ds_)
|
||||
{
|
||||
}
|
||||
|
||||
CDataType* p_c_grid;
|
||||
CElementwiseOperation c_element_op;
|
||||
const ::std::array<const void*, NumDTensor> p_ds;
|
||||
::std::array<index_t, NumDTensor> StrideDs;
|
||||
};
|
||||
|
||||
using ReduceAdd = ck::reduce::Add;
|
||||
using OutElementwiseOperation = CElementwiseOperation;
|
||||
|
||||
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(is_same<CLayout, DLayout>::value)
|
||||
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
|
||||
else
|
||||
return Number<1>{};
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
using DeviceReduceInstance = DeviceReduceThreadWiseMultiD<
|
||||
ReduceDataType, // InDataType
|
||||
DsDataType, // DsDatatype
|
||||
GemmAccDataType, // AccDataType
|
||||
CDataType, // OutDataType
|
||||
3, // Rank
|
||||
1, // NumReduceDim
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
OutElementwiseOperation,
|
||||
256, // BlockSize_
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_
|
||||
1, // KThreadSliceSize_
|
||||
0, // InSrcVectorDim_
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
|
||||
decltype(DsVectorLengthSequence)>;
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
static constexpr index_t NumInDim = 3;
|
||||
static constexpr index_t NumOutDim = 2;
|
||||
|
||||
::std::array<index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
|
||||
::std::array<index_t, NumOutDim> out_lengths = {arg.M, arg.N};
|
||||
|
||||
::std::array<index_t, NumInDim> in_strides;
|
||||
::std::array<index_t, NumOutDim> out_strides;
|
||||
if constexpr(is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
in_strides = {arg.M * arg.N, arg.N, 1};
|
||||
out_strides = {arg.N, 1};
|
||||
}
|
||||
else
|
||||
{
|
||||
in_strides = {arg.M * arg.N, 1, arg.M};
|
||||
out_strides = {1, arg.M};
|
||||
}
|
||||
|
||||
::std::array<int, 1> reduce_dims{0};
|
||||
|
||||
::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
|
||||
::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
DsLengths[i] = out_lengths;
|
||||
|
||||
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
DsStrides[i] = {arg.StrideDs[i], 1};
|
||||
}
|
||||
else
|
||||
{
|
||||
DsStrides[i] = {1, arg.StrideDs[i]};
|
||||
}
|
||||
});
|
||||
|
||||
auto reduce = DeviceReduceInstance{};
|
||||
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
DsLengths,
|
||||
DsStrides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
reduce_dims,
|
||||
arg.p_workspace_,
|
||||
arg.p_ds,
|
||||
arg.p_c_grid,
|
||||
PassThrough{},
|
||||
OutElementwiseOperation{});
|
||||
|
||||
auto invoker_ptr = reduce.MakeInvokerPointer();
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(reduce.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ::std::runtime_error(
|
||||
"The runtime parameters are not supported by the device instance.");
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
|
||||
|
||||
// workspace required when doing two-kernel reduce or Ds present
|
||||
const bool need_workspace = !(!(arg.IsReduceAdd() || NumDTensor > 0) &&
|
||||
is_same<CDataType, ReduceDataType>::value);
|
||||
if(need_workspace)
|
||||
{
|
||||
if(arg.p_workspace_ == nullptr)
|
||||
{
|
||||
throw ::std::runtime_error("using reduce, but empty workspace!");
|
||||
}
|
||||
arg.p_e_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw ::std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
::std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
::ck::kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
::ck::kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
|
||||
}
|
||||
|
||||
if(need_workspace)
|
||||
{
|
||||
ave_time += RunReduce(arg_, stream_config);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_wmma_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(
|
||||
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return GridwiseGemm::CalculateGridSize(M, N, KBatch);
|
||||
}
|
||||
|
||||
static constexpr index_t GetBlockSize() { return BlockSize; }
|
||||
|
||||
static size_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
return GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const ::std::array<const void*, NumDTensor> p_ds,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const ::std::array<index_t, NumDTensor> stride_ds,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stride_ds,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
::std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return ::std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// Polymorphic interfaces
|
||||
::std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
::std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
::std::array<index_t, NumDTensor> DsStrides,
|
||||
index_t StrideC,
|
||||
index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return ::std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
DsStrides,
|
||||
StrideC,
|
||||
KSplit,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
::std::string GetTypeString() const override
|
||||
{
|
||||
auto str = ::std::stringstream();
|
||||
|
||||
auto BlkGemmPipelineSchedulerToString = [](BlockGemmPipelineScheduler s) {
|
||||
switch(s)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::Intrawave: return ::std::string("Intrawave");
|
||||
case BlockGemmPipelineScheduler::Interwave: return ::std::string("Interwave");
|
||||
}
|
||||
return ::std::string("?");
|
||||
};
|
||||
|
||||
auto BlkGemmPipelineVersionToString = [](BlockGemmPipelineVersion v) {
|
||||
switch(v)
|
||||
{
|
||||
case BlockGemmPipelineVersion::v1: return ::std::string("v1");
|
||||
case BlockGemmPipelineVersion::v2: return ::std::string("v2");
|
||||
case BlockGemmPipelineVersion::v3: return ::std::string("v3");
|
||||
case BlockGemmPipelineVersion::v4: return ::std::string("v4");
|
||||
case BlockGemmPipelineVersion::v5: return ::std::string("v5");
|
||||
}
|
||||
return ::std::string("v?");
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmWmmaUniversalReduce"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< ::std::string(ALayout::name)[0]
|
||||
<< ::std::string(BLayout::name)[0]
|
||||
<< ::std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WmmaTile: "
|
||||
<< MPerWmma<<"x"<<NPerWmma << ", "
|
||||
<< "WmmaRepeat: "
|
||||
<< MRepeat<<"x" << NRepeat<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString(BlkGemmPipeSched) << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString(BlkGemmPipelineVer) << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = *dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
// Need workspace if using split-K or have D tensors
|
||||
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && is_same<CDataType, ReduceDataType>::value))
|
||||
{
|
||||
return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
@@ -1049,6 +1054,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
{
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
|
||||
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
|
||||
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
@@ -20,6 +21,7 @@ namespace instance {
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
@@ -326,7 +328,54 @@ void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadd
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
|
||||
void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -373,6 +422,7 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
@@ -395,6 +445,12 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -406,6 +462,7 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
@@ -420,6 +477,12 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -430,6 +493,7 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
@@ -444,6 +508,12 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GEMM_UNIVERSAL_REDUCE_INSTANCES)
|
||||
|
||||
# XDL instances
|
||||
list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES
|
||||
device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
@@ -30,4 +31,11 @@ list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
)
|
||||
|
||||
# WMMA instances
|
||||
list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES
|
||||
device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_gemm_universal_reduce_instance ${GEMM_UNIVERSAL_REDUCE_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename DsLayout = ck::Tuple<>,
|
||||
typename DsDataType = ck::Tuple<>>
|
||||
using device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce|
|
||||
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType|
|
||||
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | |
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
if(ck::is_gfx12_supported())
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmDefault,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmKPadding,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmMNPadding,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmMNKPadding,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = int8_t;
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename DsLayout = ck::Tuple<>,
|
||||
typename DsDataType = ck::Tuple<>>
|
||||
using device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce|
|
||||
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType|
|
||||
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | |
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 4, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = int8_t;
|
||||
using BF16 = bhalf_t;
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
if(ck::is_gfx12_supported())
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmDefault,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmKPadding,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmMNPadding,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmMNKPadding,
|
||||
DsLayout,
|
||||
DsDataType>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename DsLayout = ck::Tuple<>,
|
||||
typename DsDataType = ck::Tuple<>>
|
||||
using device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce|
|
||||
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType|
|
||||
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | |
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
|
||||
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
using Add = element_wise::Add;
|
||||
|
||||
using DsLayout_F16 = ck::Tuple<>;
|
||||
using DsDataType_F16 = ck::Tuple<>;
|
||||
|
||||
void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
|
||||
Row,
|
||||
DsLayout_F16,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType_F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
if(ck::is_gfx12_supported())
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmDefault,
|
||||
DsLayout_F16,
|
||||
DsDataType_F16>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmKPadding,
|
||||
DsLayout_F16,
|
||||
DsDataType_F16>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmMNPadding>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp"
|
||||
@@ -86,10 +87,21 @@ bool profile_gemm_universal_reduce_impl(int do_verification,
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
@@ -68,7 +68,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
@@ -90,6 +89,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
@@ -185,7 +185,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
@@ -221,6 +220,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
|
||||
@@ -248,6 +248,7 @@ add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_b_scale)
|
||||
add_subdirectory(gemm_universal_streamk)
|
||||
add_subdirectory(gemm_reduce)
|
||||
add_subdirectory(gemm_universal_reduce)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(batched_gemm_reduce)
|
||||
add_subdirectory(batched_gemm_gemm)
|
||||
|
||||
14
test/gemm_universal_reduce/CMakeLists.txt
Normal file
14
test/gemm_universal_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
add_gtest_executable(test_gemm_universal_reduce_bf16_wmma test_gemm_universal_reduce_bf16_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_reduce_bf16_wmma PRIVATE utility device_gemm_universal_reduce_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_reduce_fp16_wmma test_gemm_universal_reduce_fp16_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_reduce_fp16_wmma PRIVATE utility device_gemm_universal_reduce_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_reduce_bf16A_i8_wmma test_gemm_universal_reduce_bf16A_i8_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_reduce_bf16A_i8_wmma PRIVATE utility device_gemm_universal_reduce_instance)
|
||||
endif()
|
||||
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "profiler/profile_gemm_universal_reduce_impl.hpp"
|
||||
|
||||
TEST(GemmUniversalReduce, BF16A_I8)
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
int M = 512;
|
||||
int N = 256;
|
||||
int K = 128;
|
||||
int KBatch = 1;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = pass && ck::profiler::profile_gemm_universal_reduce_impl<ck::bhalf_t,
|
||||
int8_t,
|
||||
ck::Tuple<>,
|
||||
float,
|
||||
ck::bhalf_t,
|
||||
Row,
|
||||
Row,
|
||||
ck::Tuple<>,
|
||||
Row>(
|
||||
true, 3, false, true, M, N, K, K, N, N, KBatch, 1, 10);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "profiler/profile_gemm_universal_reduce_impl.hpp"
|
||||
|
||||
TEST(GemmUniversalReduce, BF16)
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
int M = 512;
|
||||
int N = 256;
|
||||
int K = 128;
|
||||
int KBatch = 1;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = pass && ck::profiler::profile_gemm_universal_reduce_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::Tuple<>,
|
||||
float,
|
||||
ck::bhalf_t,
|
||||
Row,
|
||||
Row,
|
||||
ck::Tuple<>,
|
||||
Row>(
|
||||
true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "profiler/profile_gemm_universal_reduce_impl.hpp"
|
||||
|
||||
TEST(GemmUniversalReduce, FP16)
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
int M = 512;
|
||||
int N = 256;
|
||||
int K = 128;
|
||||
int KBatch = 1;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = pass && ck::profiler::profile_gemm_universal_reduce_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::Tuple<>,
|
||||
float,
|
||||
ck::half_t,
|
||||
Row,
|
||||
Row,
|
||||
ck::Tuple<>,
|
||||
Row>(
|
||||
true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
Reference in New Issue
Block a user