mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Merge branch 'develop' into tianxing/unified-attention
This commit is contained in:
@@ -149,3 +149,7 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale)
|
||||
add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle)
|
||||
add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_bpreshuffle)
|
||||
|
||||
70
example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp
Normal file
70
example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F16;
|
||||
using ComputeTypeA = F16;
|
||||
using ComputeTypeB = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
static constexpr int KPack = 8; // int4 -> 32, fp8 -> 16, fp16 -> 8
|
||||
// clang-format off
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
128,
|
||||
32, 128, 128,
|
||||
8, 8,
|
||||
16, 16,
|
||||
2, 2,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 8>, S<4, 4, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_wmma_bpreshuffle_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
72
example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp
Normal file
72
example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/amd_ck_fp8.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ADataType = F8;
|
||||
using BDataType = F8;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F16;
|
||||
using ComputeTypeA = F8;
|
||||
using ComputeTypeB = F8;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
static constexpr int KPack = 16; // int4 -> 32, fp8 -> 16, fp16 -> 8
|
||||
// clang-format off
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
32, 128, 256,
|
||||
16, 16,
|
||||
16, 16,
|
||||
2, 1,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_wmma_bpreshuffle_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
206
example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc
Normal file
206
example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc
Normal file
@@ -0,0 +1,206 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{0, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b_k_n_preshuffled: " << b_k_n_preshuffled.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
// weight pre-shuffle
|
||||
int NPerWmma = device_op.GetPreShuffleParameters();
|
||||
int KLane = ck::get_warp_size() / NPerWmma;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NPerWmma
|
||||
// N, K -> N0 K0 KLane NPerWmma KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / NPerWmma;
|
||||
int n1 = n % NPerWmma;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NPerWmma * KLane * K0 + k0 * KPack * NPerWmma * KLane +
|
||||
k1 * KPack * NPerWmma + n1 * KPack + k2;
|
||||
|
||||
b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k);
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
auto argument =
|
||||
device_op.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 50, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size{3840, 4096, 4096, 4096, 4096, 4096, 1};
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
|
||||
}
|
||||
@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
|
||||
@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
|
||||
@@ -31,7 +31,7 @@ class SimpleAppArgs
|
||||
bool do_verification = true;
|
||||
int data_type = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -31,7 +31,7 @@ class SimpleAppArgs
|
||||
bool do_verification = true;
|
||||
int data_type = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -31,7 +31,7 @@ class SimpleAppArgs
|
||||
bool do_verification = true;
|
||||
int data_type = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -53,7 +53,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = true;
|
||||
init_method = 1;
|
||||
time_kernel = true;
|
||||
time_kernel = false;
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
|
||||
@@ -27,10 +27,11 @@ using ::ck::Tensor;
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ActivationOp = PassThrough;
|
||||
@@ -125,11 +126,11 @@ int main(int /* argc */, char* /* argv */[])
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
|
||||
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
@@ -90,7 +90,7 @@ struct ExecutionConfig final
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
int k_batch = 128;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType, DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using DsLayout = ck::Tuple<DLayout, DLayout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr int NumDs = 2;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
@@ -71,339 +71,6 @@ using DeviceGemmInstance =
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
|
||||
@@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=no, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)\n");
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
printf("arg6: k-batch count (default=1)\n");
|
||||
|
||||
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -268,7 +268,7 @@ int main()
|
||||
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
@@ -302,7 +302,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -108,7 +108,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -105,7 +105,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -81,7 +81,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// CGEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -65,7 +65,7 @@ class SimpleAppArgs
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -27,7 +27,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
@@ -69,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_G1_M2_N3_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b,
|
||||
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_gs_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M3_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -353,16 +217,18 @@ int main(int argc, char* argv[])
|
||||
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
|
||||
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchedContraction_G1_M2_N3_K1<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
@@ -399,7 +265,13 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
|
||||
bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::make_ParallelTensorFunctor;
|
||||
@@ -67,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
template <ck::index_t NumDimG,
|
||||
ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 1 && NumDimM == 3 && NumDimN == 2 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_G1_M3_N2_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a,
|
||||
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_gs_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_G1_M3_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -353,17 +219,18 @@ int main(int argc, char* argv[])
|
||||
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
|
||||
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchedContraction_G1_M3_N2_K1<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
@@ -400,7 +267,13 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
|
||||
bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -3,3 +3,4 @@
|
||||
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_wmma_v3_fp16 batched_gemm_bias_e_permute_wmma_v3_fp16.cpp)
|
||||
|
||||
@@ -106,352 +106,5 @@ using DeviceOpInstanceKKNN =
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimG,
|
||||
ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a,
|
||||
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b,
|
||||
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_G2_M2_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
ck::index_t G0 = 1;
|
||||
ck::index_t G1 = 2;
|
||||
|
||||
ck::index_t M0 = 4;
|
||||
ck::index_t M1 = 128;
|
||||
|
||||
ck::index_t N0 = 16;
|
||||
ck::index_t N1 = 256;
|
||||
|
||||
ck::index_t K0 = 2048;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
G0 = std::stoi(argv[4]);
|
||||
G1 = std::stoi(argv[5]);
|
||||
M0 = std::stoi(argv[6]);
|
||||
M1 = std::stoi(argv[7]);
|
||||
N0 = std::stoi(argv[8]);
|
||||
N1 = std::stoi(argv[9]);
|
||||
K0 = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// A[G0, G1, M0, M1, K0]
|
||||
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
|
||||
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
|
||||
// B[G0, G1, N0, N1, K0]
|
||||
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
|
||||
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
|
||||
|
||||
// D[G0, G1, M0, N0, M1, N1]
|
||||
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
|
||||
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
|
||||
// E[G0, G1, M0, N0, M1, N1]
|
||||
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
|
||||
std::vector<ck::index_t> e_gs_ms_ns_strides{
|
||||
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
|
||||
|
||||
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{});
|
||||
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{});
|
||||
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{});
|
||||
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
|
||||
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
|
||||
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
|
||||
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) *
|
||||
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
|
||||
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
|
||||
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
|
||||
|
||||
// set zero
|
||||
e_device_buf.SetZero();
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// device operation
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b_gs_ns_ks_lengths,
|
||||
b_gs_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
|
||||
e_gs_ms_ns_lengths,
|
||||
e_gs_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
ck::index_t G =
|
||||
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t M = ck::accumulate_n<ck::index_t>(
|
||||
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t N = ck::accumulate_n<ck::index_t>(
|
||||
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t K = ck::accumulate_n<ck::index_t>(
|
||||
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
|
||||
std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
|
||||
std::size_t flop = std::size_t(2) * G * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
|
||||
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(
|
||||
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
|
||||
{
|
||||
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
|
||||
{
|
||||
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
|
||||
++n1)
|
||||
{
|
||||
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
|
||||
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
|
||||
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
#include "run_batched_gemm_bias_e_permute_example.inc"
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); }
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::make_ParallelTensorFunctor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
static constexpr ck::index_t NumDimG = 2;
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 1;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
|
||||
static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
|
||||
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
|
||||
|
||||
using DeviceOpInstanceKKNN =
|
||||
ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3<
|
||||
NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
ASpec,
|
||||
BSpec,
|
||||
DESpec,
|
||||
128,
|
||||
64,
|
||||
64,
|
||||
64,
|
||||
4,
|
||||
4,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
4,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
false,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
false,
|
||||
1,
|
||||
1,
|
||||
S<1, 64, 1, 2>,
|
||||
S<8, 8>>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
#include "run_batched_gemm_bias_e_permute_example.inc"
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); }
|
||||
@@ -67,340 +67,5 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimG,
|
||||
ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a,
|
||||
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b,
|
||||
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_G2_M2_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t G0 = 1;
|
||||
ck::index_t G1 = 2;
|
||||
|
||||
ck::index_t M0 = 4;
|
||||
ck::index_t M1 = 256;
|
||||
|
||||
ck::index_t N0 = 16;
|
||||
ck::index_t N1 = 128;
|
||||
|
||||
ck::index_t K0 = 64;
|
||||
|
||||
// A[G0, G1, M0, M1, K0]
|
||||
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
|
||||
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
|
||||
// B[G0, G1, N0, N1, K0]
|
||||
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
|
||||
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
|
||||
|
||||
// D[G0, G1, M0, N0, M1, N1]
|
||||
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
|
||||
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
|
||||
// E[G0, G1, M0, N0, M1, N1]
|
||||
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
|
||||
std::vector<ck::index_t> e_gs_ms_ns_strides{
|
||||
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{});
|
||||
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{});
|
||||
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{});
|
||||
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
|
||||
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
|
||||
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
|
||||
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
|
||||
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) *
|
||||
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
|
||||
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
|
||||
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
|
||||
|
||||
// set zero
|
||||
e_device_buf.SetZero();
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// device operation
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b_gs_ns_ks_lengths,
|
||||
b_gs_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
|
||||
e_gs_ms_ns_lengths,
|
||||
e_gs_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
ck::index_t G =
|
||||
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t M = ck::accumulate_n<ck::index_t>(
|
||||
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t N = ck::accumulate_n<ck::index_t>(
|
||||
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t K = ck::accumulate_n<ck::index_t>(
|
||||
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
|
||||
|
||||
std::size_t flop = std::size_t(2) * G * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
|
||||
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(
|
||||
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
|
||||
{
|
||||
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
|
||||
{
|
||||
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
|
||||
++n1)
|
||||
{
|
||||
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
|
||||
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
|
||||
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
#include "run_batched_gemm_bias_e_permute_example.inc"
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); }
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimG,
|
||||
ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a,
|
||||
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b,
|
||||
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_G2_M2_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int run_batched_gemm_bias_e_permute_example(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t G0 = 1;
|
||||
ck::index_t G1 = 2;
|
||||
|
||||
ck::index_t M0 = 4;
|
||||
ck::index_t M1 = 128;
|
||||
|
||||
ck::index_t N0 = 16;
|
||||
ck::index_t N1 = 256;
|
||||
|
||||
ck::index_t K0 = 2048;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
G0 = std::stoi(argv[4]);
|
||||
G1 = std::stoi(argv[5]);
|
||||
M0 = std::stoi(argv[6]);
|
||||
M1 = std::stoi(argv[7]);
|
||||
N0 = std::stoi(argv[8]);
|
||||
N1 = std::stoi(argv[9]);
|
||||
K0 = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// A[G0, G1, M0, M1, K0]
|
||||
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
|
||||
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
|
||||
// B[G0, G1, N0, N1, K0]
|
||||
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
|
||||
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
|
||||
|
||||
// D[G0, G1, M0, N0, M1, N1]
|
||||
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
|
||||
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
|
||||
// E[G0, G1, M0, N0, M1, N1]
|
||||
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
|
||||
std::vector<ck::index_t> e_gs_ms_ns_strides{
|
||||
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
|
||||
|
||||
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{});
|
||||
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{});
|
||||
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{});
|
||||
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
|
||||
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
|
||||
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
|
||||
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) *
|
||||
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
|
||||
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
|
||||
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
|
||||
|
||||
// set zero
|
||||
e_device_buf.SetZero();
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// device operation
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b_gs_ns_ks_lengths,
|
||||
b_gs_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
|
||||
e_gs_ms_ns_lengths,
|
||||
e_gs_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
ck::index_t G =
|
||||
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t M = ck::accumulate_n<ck::index_t>(
|
||||
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t N = ck::accumulate_n<ck::index_t>(
|
||||
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t K = ck::accumulate_n<ck::index_t>(
|
||||
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
|
||||
std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
|
||||
std::size_t flop = std::size_t(2) * G * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
|
||||
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(
|
||||
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
|
||||
{
|
||||
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
|
||||
{
|
||||
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
|
||||
++n1)
|
||||
{
|
||||
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
|
||||
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
|
||||
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -92,7 +92,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParam \
|
||||
|
||||
@@ -92,7 +92,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParam \
|
||||
|
||||
@@ -40,7 +40,7 @@ class SimpleAppArgs
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
SimpleAppArgs()
|
||||
|
||||
@@ -44,7 +44,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
template <ck::index_t... Is>
|
||||
|
||||
@@ -56,7 +56,7 @@ template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInsta
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t num_rows = 65536;
|
||||
constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};
|
||||
|
||||
@@ -195,7 +195,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -86,7 +86,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -87,7 +87,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -90,7 +90,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
|
||||
ck::index_t C = 128;
|
||||
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
bool log_kernel = true;
|
||||
|
||||
if(argc == 1)
|
||||
|
||||
@@ -53,7 +53,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -49,7 +49,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -121,7 +121,7 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input,
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
const float scale = 2.f;
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ void host_elementwise2D(HostTensorC& C,
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t M = 48 * 256;
|
||||
ck::index_t N = 1024;
|
||||
|
||||
@@ -31,8 +31,9 @@ using S = ck::Sequence<Is...>;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using A0DataType = F16;
|
||||
using B0DataType = F16;
|
||||
@@ -139,11 +140,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
|
||||
@@ -193,7 +193,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
|
||||
@@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an
|
||||
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
|
||||
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
@@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
|
||||
#else
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
@@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -182,12 +184,14 @@ int main(int argc, char* argv[])
|
||||
bool time_kernel = true;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t N = 1536;
|
||||
ck::index_t K = 4096;
|
||||
// ck::index_t N = 4096;
|
||||
// ck::index_t K = 6144;
|
||||
// ck::index_t N = 128;
|
||||
// ck::index_t K = 512;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t topk = 2;
|
||||
ck::index_t experts = 16;
|
||||
ck::index_t topk = 8;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 208;
|
||||
@@ -196,9 +200,9 @@ int main(int argc, char* argv[])
|
||||
// ck::index_t sorted_tile_num = 259;
|
||||
// ck::index_t valid_tile_num = 256;
|
||||
// ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 2;
|
||||
ck::index_t valid_tile_num = 2;
|
||||
ck::index_t tokens = 32;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 16;
|
||||
ck::index_t tokens = 4;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
@@ -209,7 +213,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t sorted_tile_num = 261;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
#endif
|
||||
ck::index_t KBatch = 6;
|
||||
ck::index_t KBatch = 1;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
|
||||
@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -185,7 +185,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// tokens = 1
|
||||
// topk = 1
|
||||
|
||||
@@ -164,7 +164,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
@@ -64,11 +66,11 @@ int run_gemm_example(int argc, char* argv[])
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return ck::HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -178,7 +178,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -208,7 +208,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -204,7 +204,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
else if(argc == 11)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
@@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
@@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
else if(argc == 11)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
@@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
@@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,35 @@ include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
# On Windows, HIP uses -nostdlib which prevents C runtime linking
|
||||
# We need legacy_stdio_definitions.lib to provide vfprintf and other legacy C functions
|
||||
# This is mainly needed for the getopt library.
|
||||
set(LEGACY_STDIO_SEARCH_PATHS)
|
||||
|
||||
# Try to use Visual C++ Tools environment variable (if build executes from Visual Studio Developer Command Prompt)
|
||||
if(DEFINED ENV{VCToolsInstallDir})
|
||||
list(APPEND LEGACY_STDIO_SEARCH_PATHS "$ENV{VCToolsInstallDir}/lib/x64")
|
||||
endif()
|
||||
|
||||
# Fallback: Search common Visual Studio installation locations
|
||||
file(GLOB MSVC_LIB_PATHS "C:/Program Files/Microsoft Visual Studio/*/*/VC/Tools/MSVC/*/lib/x64")
|
||||
list(APPEND LEGACY_STDIO_SEARCH_PATHS ${MSVC_LIB_PATHS})
|
||||
|
||||
# Use find_library to locate the library
|
||||
find_library(LEGACY_STDIO_LIB legacy_stdio_definitions
|
||||
PATHS ${LEGACY_STDIO_SEARCH_PATHS}
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
|
||||
if(LEGACY_STDIO_LIB)
|
||||
message(STATUS "Found legacy_stdio_definitions.lib: ${LEGACY_STDIO_LIB}")
|
||||
add_link_options("SHELL:-Xlinker \"${LEGACY_STDIO_LIB}\"")
|
||||
else()
|
||||
message(WARNING "Could not find legacy_stdio_definitions.lib - examples may fail to link.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
|
||||
@@ -216,6 +245,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
|
||||
@@ -36,6 +36,19 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
|
||||
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
|
||||
}
|
||||
KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
{F_page_size},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
|
||||
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
trait.kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_lookup_table
|
||||
],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -604,23 +647,42 @@ class KernelComponentFactory:
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
for (
|
||||
logits,
|
||||
mask,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -672,69 +734,75 @@ def get_fwd_blobs(
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ class FmhaFwdApiTrait:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
@@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
|
||||
and kernel_ctx.tile.F_bm0 != 128
|
||||
)
|
||||
or (
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
|
||||
and kernel_ctx.pipeline.tag != "qr_async"
|
||||
and kernel_ctx.tile.F_bk0 == 64
|
||||
)
|
||||
):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
@@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
|
||||
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int init_sink_value = arg_parser.get_int("init_sink");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
init_sink_value,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
@@ -230,6 +230,7 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -519,6 +523,7 @@ struct fmha_batch_prefill_args
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -529,14 +534,25 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
// SGLang-style page table
|
||||
int32_t num_total_pages;
|
||||
void* kv_indptr;
|
||||
void* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
void* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#endif
|
||||
// KV cache page table fields (kv_lookup_table selects interpretation):
|
||||
// - SGLANG_PAGE_TABLE_1D:
|
||||
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
|
||||
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
|
||||
// kv_last_page_lens: per-batch last page lengths [batch]
|
||||
// - VLLM_BLOCK_TABLE_2D:
|
||||
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
|
||||
// batch_stride_block_table: row stride for block_table
|
||||
// seqlen_k_ptr: per-batch seqlen_k [batch]
|
||||
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
|
||||
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
|
||||
kv_memory_layout; // KV memory layout (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
|
||||
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
|
||||
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
|
||||
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
|
||||
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
|
||||
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
@@ -627,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -677,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -837,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
args.min_seqlen_q,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -882,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -949,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -997,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1113,6 +1135,22 @@ template <typename FmhaKernel>
|
||||
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
|
||||
const PageTableKargs page_table = [&]() {
|
||||
if constexpr(FmhaKernel::kKVLookupTable ==
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
|
||||
}
|
||||
else
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
args.batch_stride_block_table,
|
||||
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
|
||||
}
|
||||
}();
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
@@ -1133,12 +1171,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1164,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1184,12 +1219,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1220,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1281,6 +1313,65 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
ck_tile::index_t kPageBlockSize_ = 1,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
kM0_,
|
||||
kN0_,
|
||||
kK0_,
|
||||
kN1_,
|
||||
kK1_,
|
||||
kK0BlockLength_,
|
||||
kIsVLayoutRowMajor_,
|
||||
FmhaPipelineEnum_,
|
||||
kHasLogitsSoftCap_,
|
||||
FmhaMask_,
|
||||
BiasEnum_,
|
||||
kStoreLse_,
|
||||
kHasDropout_,
|
||||
QScaleEnum_,
|
||||
kPadS_,
|
||||
kPadSK_,
|
||||
kPadD_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
@@ -1527,7 +1618,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
using fmha_batch_prefill_traits = fmha_fwd_traits;
|
||||
struct fmha_batch_prefill_traits : public fmha_fwd_traits
|
||||
{
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
int page_size = 1;
|
||||
};
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits,
|
||||
fmha_batch_prefill_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename SMPLComputeDataType>
|
||||
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
|
||||
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
|
||||
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t real_seqlen_q,
|
||||
ck_tile::index_t real_seqlen_k)
|
||||
{
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
|
||||
}
|
||||
// Append sink token at the end of each row
|
||||
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
int do_validation,
|
||||
int init_sink_value,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
||||
bias_host);
|
||||
}
|
||||
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
// sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range
|
||||
// for close to rowmax values.
|
||||
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
|
||||
sink_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
if(init_sink_value != 0)
|
||||
args.sink_ptr = sink_buf.GetDeviceBuffer();
|
||||
else
|
||||
args.sink_ptr = nullptr;
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q; // unused in group mode
|
||||
args.hdim_q = hdim_q;
|
||||
@@ -1351,8 +1387,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
else if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_o_host};
|
||||
else
|
||||
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
|
||||
if(lse)
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
// Create extended tensor with sink token
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
// Copy original attention scores and append sink values
|
||||
copy_attention_scores_with_sink(
|
||||
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);
|
||||
|
||||
// Compute softmax on extended tensor
|
||||
ck_tile::HostTensor<PDataType> p_extended(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func);
|
||||
}
|
||||
|
||||
// Extract only the original columns (exclude sink token column)
|
||||
p_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
// No sink tokens - compute softmax directly
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
|
||||
@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
@@ -69,107 +69,88 @@ struct BasicInvoker
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize =
|
||||
(total_elements + elements_per_block - 1) / elements_per_block;
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
args.stride_E);
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&]() {
|
||||
// use SET operation since each K-split writes to separate memory
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
return Run();
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -460,12 +460,6 @@ inline auto create_args()
|
||||
return arg_parser;
|
||||
}
|
||||
|
||||
// Type aliases for memory operation integral constants
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -57,114 +57,95 @@ struct WeightPreshuffleInvoker
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time =
|
||||
ck_tile::launch_kernel_time_mask(s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -60,112 +60,94 @@ struct UniversalInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Threadwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Blockwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "19", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = float;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
// Operations: one doing a sum reduction, the other computing the mean square
|
||||
// In the case of mean square:
|
||||
// 1. The element wise operation squares each element before reduction
|
||||
// 2. The reduction operation sum the squared element
|
||||
// 3. The accumulator element wise operation divides the result by the total number of reduced
|
||||
// elements (intra block operation)
|
||||
// 4. The partial result is updated across blocks using inter block reduction, a sum.
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions
|
||||
auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise
|
||||
// ops
|
||||
auto accumulator_elementwise_ops = ck_tile::make_tuple(
|
||||
ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block
|
||||
auto inter_block_reduce_ops = ck_tile::make_tuple(
|
||||
ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceMultiblock<Problem>;
|
||||
|
||||
// Determine block group size for multi-block reduction
|
||||
// block_group_size records how many blocks participate to a reduction (input data dependent)
|
||||
// , for efficiency reasons this size if limited to a maximum of 128. If this is not sufficient
|
||||
// to process the whole reduction, each thread will to process multiple thread tile
|
||||
// a num_block_tile_iterations times
|
||||
auto [num_block_tile_iterations, block_group_size] =
|
||||
typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams();
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
ck_tile::index_t kGridSize =
|
||||
((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size;
|
||||
|
||||
std::cout << "Block group size: " << block_group_size
|
||||
<< ", Num block tile iterations: " << num_block_tile_iterations
|
||||
<< ", Reduce total length: " << reduce_total_length << std::endl;
|
||||
std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl;
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
// Init the output data with identity values respective to each reduce op
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
constexpr auto op = reduce_ops.at(i);
|
||||
const auto identity_val = op.template GetIdentityValue<YDataType>();
|
||||
const auto output_number_elements = N * C;
|
||||
std::fill(h.begin() + i * output_number_elements,
|
||||
h.begin() + (i + 1) * output_number_elements,
|
||||
identity_val);
|
||||
});
|
||||
|
||||
auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); };
|
||||
|
||||
float ave_time = launch_kernel_time_mask(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
clear_output_buffer,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops)
|
||||
|
||||
);
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops,
|
||||
block_group_size);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
std::cout << "Checking operation " << i << ": " << std::endl;
|
||||
|
||||
bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
|
||||
if(pass_op)
|
||||
{
|
||||
std::cout << "✅ valid results for this operation" << std::endl;
|
||||
}
|
||||
pass &= pass_op;
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "7", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
// Two operations: one do a sum reduction, the other computing the mean square
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions ops
|
||||
auto elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise ops
|
||||
auto accumulator_elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementiwise ops on reduction,
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceThreadWise<Problem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user