mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8
This commit is contained in:
@@ -77,6 +77,9 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
|
||||
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp8_streamk_v3 gemm_xdl_fp8_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_streamk_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ struct ProblemSizeStreamK final
|
||||
ck::index_t StrideB = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
ck::index_t NumSKBlocks = -1; // number of stream-k blocks
|
||||
};
|
||||
struct ProblemSizeStreamK_universal final
|
||||
{
|
||||
@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
|
||||
int do_verification = 3;
|
||||
int do_verification = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
@@ -43,6 +43,17 @@ using DeviceGemmV2_Streamk_Instance =
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example_streamk_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); }
|
||||
|
||||
58
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
Executable file
58
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
Executable file
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2_Streamk_Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
128, 256,
|
||||
128, 16, 16,
|
||||
16, 16,
|
||||
4, 8,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 1,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example_streamk_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); }
|
||||
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
|
||||
break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
|
||||
40
example/01_gemm/run_gemm_example_streamk_v2.inc
Normal file → Executable file
40
example/01_gemm/run_gemm_example_streamk_v2.inc
Normal file → Executable file
@@ -176,6 +176,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
@@ -196,6 +197,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) *
|
||||
c_m_n_device_ref_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
@@ -240,6 +243,13 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
return true;
|
||||
}
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
if(workspace_size != 0)
|
||||
{
|
||||
workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if((config.do_verification == 1) || (config.do_verification == 3))
|
||||
{
|
||||
@@ -271,6 +281,36 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
#endif
|
||||
}
|
||||
|
||||
if((config.do_verification == 2) || (config.do_verification == 3))
|
||||
{
|
||||
// GPU verification
|
||||
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
|
||||
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
|
||||
|
||||
auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
std::cout << "Running verification on GPU." << std::endl;
|
||||
ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{});
|
||||
|
||||
c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data());
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_device_ref_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
@@ -261,7 +261,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4});
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
|
||||
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDMatrices; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDMatrices; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
gemm.SetKBatchSize(argument, config.k_batch);
|
||||
gemm.SetKBatchSize(&argument, config.k_batch);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
@@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDs>;
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<D0DataType, 1>{});
|
||||
}
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
|
||||
@@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
gemm.SetDeviceKernelArgs(&argument, gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,7 +239,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,9 +171,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
|
||||
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
|
||||
DeviceMem gemm_workspace, gemm_kargs;
|
||||
|
||||
// The following is necessary since TwoStage kernel is using additional memory both
|
||||
// for Workspace and kernel arguments.
|
||||
if(kargs_size > 0)
|
||||
{
|
||||
gemm_kargs.Realloc(kargs_size);
|
||||
gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer());
|
||||
}
|
||||
if(workspace_size > 0 && workspace_size != kargs_size)
|
||||
{
|
||||
gemm_workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -175,8 +175,8 @@ int main(int argc, char* argv[])
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
|
||||
c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -118,7 +118,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -153,7 +153,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -178,7 +178,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -152,7 +152,7 @@ int run(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -173,7 +173,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -377,7 +377,7 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
|
||||
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -41,7 +41,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParams \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
|
||||
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
|
||||
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
|
||||
@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
// values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 6});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
|
||||
a1_device_buf.ToDevice(a1_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
|
||||
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any DPP examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
|
||||
message("removing dpp example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
|
||||
2
example/README.md
Normal file
2
example/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
[Back to the main page](../README.md)
|
||||
# Composable Kernel examples
|
||||
@@ -2,10 +2,17 @@
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp8" : "ck_tile::fp8_t"
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
"fp16": "FmhaBwdFp16",
|
||||
"bf16": "FmhaBwdBf16"
|
||||
}
|
||||
|
||||
MASK_IMPL = {
|
||||
|
||||
@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
|
||||
@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
gen = list()
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_bm0,
|
||||
F_bn0 = self.F_bn0,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
|
||||
@@ -282,7 +282,7 @@ class FmhaFwdApiPool:
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
@@ -301,7 +301,7 @@ class FmhaFwdTileSize:
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipelines that need to load Q at once (or repeatedly load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
@@ -339,7 +339,7 @@ class FmhaFwdKernel:
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
|
||||
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bs = self.F_tile.F_bs,
|
||||
F_bsk = self.F_tile.F_bsk,
|
||||
F_bd = self.F_tile.F_bd,
|
||||
@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# rope/paged-kv is not supported
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
gen = list()
|
||||
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
|
||||
@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
{F_hdim},
|
||||
{F_bm0},
|
||||
{F_bm0},
|
||||
{F_bn1},
|
||||
{F_mode},
|
||||
fmha_trait>;
|
||||
@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
|
||||
if(s.log_level_ > 0)
|
||||
std::cout
|
||||
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
|
||||
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
|
||||
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
if (t.has_lse) {{
|
||||
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
|
||||
return -1;
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -421,11 +431,11 @@ class FmhaFwdSplitKVApiPool:
|
||||
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
@@ -462,7 +472,7 @@ class FmhaFwdSplitKVKernel:
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
@@ -482,7 +492,7 @@ class FmhaFwdSplitKVKernel:
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
@@ -542,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
@@ -614,27 +624,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
else:
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitrary hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitrary hdim
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitrary hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitrary hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/paged-kv kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -642,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
gen = list()
|
||||
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
@@ -655,9 +667,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if pipeline.F_pagedkv == 't':
|
||||
# we only use batch mode kernels to handle (paged-) kvcache problems
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
@@ -705,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
|
||||
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataType>;
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -14,11 +14,19 @@
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaBwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<ck_tile::half_t>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ref/naive_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "rotary.hpp"
|
||||
#include "utils.hpp"
|
||||
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[])
|
||||
"-1 to choose s_knew in [1, s] randomly.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed. same as xformer kv_padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
@@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(seqlen_knew != 0)
|
||||
{
|
||||
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
|
||||
std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
|
||||
<< std::endl;
|
||||
seqlen_knew = 0;
|
||||
}
|
||||
#endif
|
||||
@@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
|
||||
std::is_same_v<DataType, ck_tile::bf16_t>))
|
||||
if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
|
||||
{
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
@@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rotary_dim = 0;
|
||||
}
|
||||
#endif
|
||||
// to use fmha_fwd_appendkv(), make sure it's in batch mode
|
||||
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
if(need_append_kvcache && mode == mode_enum::group)
|
||||
{
|
||||
std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
@@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
#else
|
||||
if(use_cache_batch_idx)
|
||||
{
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
else if(mode == mode_enum::group)
|
||||
{
|
||||
std::cerr << "group mode will not use cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if(0 < page_block_size && use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
// the input tensor layout for kvcache is same as batch mode
|
||||
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
|
||||
if(use_kvcache && mode != mode_enum::batch)
|
||||
{
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
|
||||
auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
|
||||
decode_seqlen(mode,
|
||||
@@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_str("s_k"),
|
||||
arg_parser.get_str("s_kpad"),
|
||||
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
||||
use_kvcache);
|
||||
need_append_kvcache);
|
||||
// compute kvcache seqlen_k (before appending knew/vnew)
|
||||
auto cache_seqlen_ks = seqlen_ks;
|
||||
std::transform(cache_seqlen_ks.begin(),
|
||||
@@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return atoi(squant_str.c_str()) != 0 ? true : false;
|
||||
}();
|
||||
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
float range_o = arg_parser.get_float("range_o");
|
||||
|
||||
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
|
||||
if(squant)
|
||||
{
|
||||
scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
|
||||
scale_p = dtype_max / range_p;
|
||||
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
|
||||
scale_o = range_p * range_v / range_o / dtype_max;
|
||||
}
|
||||
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
|
||||
@@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
bool s_randval = false;
|
||||
if(p_drop > 0.0f && do_validation)
|
||||
if(p_drop > 0.0f && do_validation != 0)
|
||||
{
|
||||
s_randval = true;
|
||||
}
|
||||
@@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataType>;
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
@@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
float range_o = arg_parser.get_float("range_o");
|
||||
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float p_dtype_max = v_dtype_max; // assume p and v are the same type
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
|
||||
if(squant)
|
||||
{
|
||||
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
|
||||
scale_p = p_dtype_max / range_p;
|
||||
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
|
||||
}
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
@@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else if(init_method == "ufq" || init_method == "uf:q" ||
|
||||
init_method == "3") // suitable for fp8 quantization
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
|
||||
|
||||
// bias_fp8 = qscale_bias * bias_fp32
|
||||
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
|
||||
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
|
||||
// Assume bias is in [-1.f, 1.f] in original fp32
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
|
||||
}
|
||||
@@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(
|
||||
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
||||
0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
ck_tile::DeviceMem cache_seqlen_k_buf(
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
@@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
@@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
@@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
args.is_gappy = false; // use 'false' for flash-attention integration
|
||||
|
||||
args.cache_batch_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
@@ -1100,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
|
||||
if(!do_validation)
|
||||
if(do_validation == 0)
|
||||
{
|
||||
std::cout << std::flush << std::endl;
|
||||
return true;
|
||||
}
|
||||
if(do_validation == 2)
|
||||
{
|
||||
// NOTE: use gpu to do validation
|
||||
ck_tile::naive_attention_fwd_traits naive_t;
|
||||
naive_t.q_type = data_type;
|
||||
naive_t.k_type = data_type;
|
||||
naive_t.v_type = data_type;
|
||||
naive_t.o_type = data_type;
|
||||
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.variation = 0; // TODO?
|
||||
|
||||
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::naive_attention_fwd_args naive_a;
|
||||
naive_a.q_ptr = q_buf.GetDeviceBuffer();
|
||||
naive_a.k_ptr = k_buf.GetDeviceBuffer();
|
||||
naive_a.v_ptr = v_buf.GetDeviceBuffer();
|
||||
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
|
||||
naive_a.scale_s = scale_s;
|
||||
naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
|
||||
naive_a.page_table_ptr =
|
||||
nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
|
||||
naive_a.hdim = hdim_q;
|
||||
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
|
||||
naive_a.batch_q = batch;
|
||||
naive_a.batch_kv = batch;
|
||||
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
|
||||
naive_a.seqlen_q = seqlen_qs[0];
|
||||
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
|
||||
naive_a.nhead_q = nhead;
|
||||
naive_a.nhead_kv = nhead_k;
|
||||
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
|
||||
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
|
||||
|
||||
ck_tile::stream_config naive_s{};
|
||||
|
||||
naive_attention_fwd(naive_t, naive_a, naive_s);
|
||||
|
||||
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
|
||||
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
||||
|
||||
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool pass_ = ck_tile::check_err(
|
||||
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
||||
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
|
||||
return pass_;
|
||||
}
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
// DataTypeConfig is a tag type (run<FmhaFwdFp8> is what main() instantiates), never
// ck_tile::fp8_t itself, so the fp8 branch must be selected by comparing against the tag.
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
return ck_tile::scales{scale_p};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
// DataTypeConfig is a tag type (run<FmhaFwdFp8> is what main() instantiates), never
// ck_tile::fp8_t itself; compare against the tag so the fp8 saturate + scale_o path is used.
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
else
|
||||
@@ -1168,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding(
|
||||
@@ -1184,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
|
||||
});
|
||||
} else {
|
||||
} else {
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
|
||||
});
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
#endif
|
||||
{
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
@@ -1211,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
|
||||
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding(
|
||||
@@ -1233,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(0 < page_block_size) {
|
||||
if(is_v_rowmajor) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
|
||||
});
|
||||
} else {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
else
|
||||
{
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
|
||||
});
|
||||
} else {
|
||||
@@ -1440,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
pass &= cur_pass;
|
||||
@@ -1497,15 +1569,15 @@ int main(int argc, char* argv[])
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -16,11 +16,35 @@
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// Empty tag type naming the fp16 forward-FMHA configuration. It keys the
// FmhaFwdTypeConfig<FmhaFwdFp16> specialization (Q/K data types are ck_tile::half_t)
// and is the template argument used when "prec" is "fp16".
struct FmhaFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
// Empty tag type naming the bf16 forward-FMHA configuration. It keys the
// FmhaFwdTypeConfig<FmhaFwdBf16> specialization (Q/K data types are ck_tile::bf16_t)
// and is the template argument used when "prec" is "bf16".
struct FmhaFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
// Empty tag type naming the fp8 forward-FMHA configuration. It keys the
// FmhaFwdTypeConfig<FmhaFwdFp8> specialization (Q/K data types are ck_tile::fp8_t)
// and is the template argument used when "prec" is "fp8".
struct FmhaFwdFp8
|
||||
{
|
||||
};
|
||||
|
||||
// Empty tag type naming the bf8 forward-FMHA configuration. It keys the
// FmhaFwdTypeConfig<FmhaFwdBf8> specialization (Q/K data types are ck_tile::bf8_t).
struct FmhaFwdBf8
|
||||
{
|
||||
};
|
||||
|
||||
// Empty tag type. NOTE(review): presumably names a mixed fp8/fp16 configuration; no
// FmhaFwdTypeConfig specialization or use of it is visible here - confirm intended meaning.
struct FmhaFwdFp8Fp16
|
||||
{
|
||||
};
|
||||
|
||||
// Empty tag type. NOTE(review): presumably names a mixed fp8/bf16 configuration; no
// FmhaFwdTypeConfig specialization or use of it is visible here - confirm intended meaning.
struct FmhaFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
{
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
@@ -165,6 +189,8 @@ struct fmha_fwd_splitkv_args
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
|
||||
// nullptr.
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
@@ -173,9 +199,21 @@ struct fmha_fwd_splitkv_args
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// kvcache mode (use same kernel as batch mode):
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
//
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
//
|
||||
// when is_gappy=true:
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// seqstart_k_ptr[b] now store local offset of each batch
|
||||
//
|
||||
// when is_gappy=false:
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
@@ -251,7 +289,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx;
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -278,87 +316,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -389,6 +427,10 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.is_gappy,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
|
||||
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
|
||||
std::string k_val,
|
||||
std::string k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min = 0,
|
||||
bool use_kvcache = false,
|
||||
bool need_append_kvcache = false,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
|
||||
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
|
||||
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
|
||||
|
||||
if(1 < batch && use_kvcache)
|
||||
if(1 < batch && need_append_kvcache)
|
||||
{
|
||||
// to keep the original s_k value, we always use seqlen_k_max in first batch
|
||||
randints(std::next(seqlen_ks.begin()),
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
|
||||
add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
|
||||
@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
|
||||
@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Gemm{MemBoundPipeline}"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
|
||||
// TODO: add different init types
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
|
||||
@@ -202,14 +199,16 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
|
||||
// work.
|
||||
// else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(a_layout == "C" && b_layout == "R")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
|
||||
@@ -14,12 +14,34 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
|
||||
#endif
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// ToDo: This will be modified by the codegen code later.
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = true;
|
||||
constexpr bool kPadN = true;
|
||||
constexpr bool kPadK = true;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -49,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
|
||||
#endif
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
|
||||
#endif
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
Traits,
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
ck_tile::GemmPipelineScheduler::Interwave,
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
#endif
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
@@ -86,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
@@ -174,8 +212,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "When there's no hot loop, this tail number \"" << tail_num
|
||||
<< "\" is not supported! " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
|
||||
@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
|
||||
{
|
||||
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
|
||||
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
|
||||
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
};
|
||||
|
||||
// assume this is B matrix, originally we have batch*n*k
|
||||
@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
|
||||
// b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
|
||||
constexpr index_t Kv = Alignment;
|
||||
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return tmp_1;
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{i_n * NPerBlock, i_k * KPerBlock},
|
||||
get_dst_dist());
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
|
||||
@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
|
||||
{
|
||||
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
|
||||
endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
|
||||
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
|
||||
add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp)
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
|
||||
|
||||
@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using ComputeDataType = float;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
x_stride,
|
||||
y_stride};
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
|
||||
@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
using YDataType = ComputeDataType;
|
||||
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto n_) {
|
||||
@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
qy_buf.FromDevice(qy_host_dev.data());
|
||||
auto [rtol, atol] = get_elimit<QYDataType>();
|
||||
|
||||
if(stride == n)
|
||||
if(y_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(qy_host_dev,
|
||||
qy_host_ref,
|
||||
@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
|
||||
qy_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
|
||||
qy_host_ref.begin() + i_r * stride + n);
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
|
||||
qy_host_dev.begin() + i_r * y_stride +
|
||||
n);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
|
||||
qy_host_ref.begin() + i_r * y_stride +
|
||||
n);
|
||||
pass &= ck_tile::check_err(qy_host_dev_row,
|
||||
qy_host_ref_row,
|
||||
std::string("qy[") + std::to_string(i_r) +
|
||||
@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= n);
|
||||
assert(x_stride >= n);
|
||||
|
||||
using TypeConfig = SmoothquantTypeConfig<DataType>;
|
||||
|
||||
@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
<< std::flush;
|
||||
|
||||
smoothquant_traits traits{data_type};
|
||||
|
||||
@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
x_stride,
|
||||
y_stride};
|
||||
|
||||
float ave_time = smoothquant(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
using YDataType = ComputeDataType;
|
||||
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto n_) {
|
||||
@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
qy_buf.FromDevice(qy_host_dev.data());
|
||||
auto [rtol, atol] = get_elimit<QYDataType>();
|
||||
|
||||
if(stride == n)
|
||||
if(y_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(qy_host_dev,
|
||||
qy_host_ref,
|
||||
@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
|
||||
qy_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
|
||||
qy_host_ref.begin() + i_r * stride + n);
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
|
||||
qy_host_dev.begin() + i_r * y_stride +
|
||||
n);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
|
||||
qy_host_ref.begin() + i_r * y_stride +
|
||||
n);
|
||||
pass &= ck_tile::check_err(qy_host_dev_row,
|
||||
qy_host_ref_row,
|
||||
std::string("qy[") + std::to_string(i_r) +
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/moe_sorting.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
|
||||
25
example/ck_tile/14_moe_smoothquant/CMakeLists.txt
Normal file
25
example/ck_tile/14_moe_smoothquant/CMakeLists.txt
Normal file
@@ -0,0 +1,25 @@
|
||||
# add_moe_smoothquant_example(TARGET_NAME MAIN_SRC [extra sources...])
#
# Creates the executable target TARGET_NAME from MAIN_SRC, attaches every
# additional argument (ARGN) to it as a private source, and applies the
# example-specific warning flags assembled below.
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
# add_example_executable() is deliberately not used here: the target is created
|
||||
# with EXCLUDE_FROM_ALL so that it is not part of "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
# let the example sources include headers that live next to this CMakeLists.txt
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# collect every argument passed after MAIN_SRC as an extra source file
# NOTE(review): INSTANCE_SRCS is not cleared inside this function, so a variable
# of the same name already set by the caller would presumably be inherited and
# its entries listed twice -- confirm whether that is intended.
foreach(source IN LISTS ARGN)
|
||||
list(APPEND INSTANCE_SRCS ${source})
|
||||
endforeach()
|
||||
|
||||
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
# start from an empty option list, then append the warning suppressions
set(COMPILE_OPTIONS)
|
||||
# NOTE: undefined-func-template is turned off so the sources compile without explicitly declaring the function specializations
|
||||
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
|
||||
endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
|
||||
add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})
|
||||
|
||||
15
example/ck_tile/14_moe_smoothquant/README.md
Normal file
15
example/ck_tile/14_moe_smoothquant/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# moe-smoothquant
|
||||
|
||||
This folder contains an example of moe-smoothquant using the ck_tile tile-programming implementation.
|
||||

|
||||
|
||||
Unlike the standard smoothquant op, the input scale comes from different experts (`[expert, hidden]`). We need to reuse the `topk-id` from the previous `topk-softmax`, select the corresponding `expert` from the current topk, and expand the output/per-token-scale by `topk`.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942...
|
||||
make tile_example_moe_smoothquant -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_smoothquant`
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ,false>>(const S&, A);
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8,true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4,true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2,true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1,true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "moe_smoothquant.hpp"
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = moe_smoothquant_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename data_type>
|
||||
float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
|
||||
moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
if(a.hidden_size <= 64) {
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 128) {
|
||||
if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 256) {
|
||||
if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 512) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 768) {
|
||||
if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 1024) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 1536) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 2048) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 3072) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 4096) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size > 4096) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true>>(s, a);
|
||||
}
|
||||
return r;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// Public entry point: route to the dispatcher matching the precision named in t.data_type.
// Supported names are "fp16" and "bf16"; any other value throws std::runtime_error.
float moe_smoothquant(moe_smoothquant_traits t,
                      moe_smoothquant_args a,
                      const ck_tile::stream_config& s)
{
    if(t.data_type == "fp16")
        return moe_smoothquant_dispatch<ck_tile::fp16_t>(t, a, s);

    if(t.data_type == "bf16")
        return moe_smoothquant_dispatch<ck_tile::bf16_t>(t, a, s);

    throw std::runtime_error("Without supported instances!");
}
|
||||
@@ -0,0 +1,62 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "moe_smoothquant.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = moe_smoothquant_args;
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = moe_smoothquant_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float moe_smoothquant_(const S& s, A a)
|
||||
{
|
||||
using DataType = typename Traits_::DataType;
|
||||
|
||||
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
|
||||
typename MoeSmoothquantTypeConfig<DataType>::XDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::XScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::ComputeDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::YScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::QYDataType,
|
||||
typename Traits_::Shape,
|
||||
Traits_::kPadN,
|
||||
Traits_::kTwoPass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Kernel = ck_tile::MoeSmoothquant<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
BIN
example/ck_tile/14_moe_smoothquant/misc/moe-sm.png
Normal file
BIN
example/ck_tile/14_moe_smoothquant/misc/moe-sm.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 202 KiB |
264
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
Normal file
264
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
Normal file
@@ -0,0 +1,264 @@
|
||||
#include "ck_tile/host.hpp"
#include "moe_smoothquant.hpp"
#include <cstring>
#include <set>
#include <stdexcept>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-5;
|
||||
double atol = 1e-5;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-5;
|
||||
double atol = 1e-5;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
// due to rounding, int8 quantization might have 1 abs error
|
||||
double rtol = 1;
|
||||
double atol = 1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// Fill host_tensor[0 .. tokens * topk) with pseudo-random expert ids in [0, num_expert).
//
// The output is treated as `tokens` consecutive groups of `topk` entries; ids are unique
// within each group. The generator is seeded with std::srand(seed), so the same arguments
// always produce the same ids (and the global std::rand state is reset as a side effect).
//
// Throws std::invalid_argument when tokens or topk is negative, when host_tensor is too
// small to hold tokens * topk entries, or when topk > num_expert (a group of unique ids
// would be impossible, and the rejection loop below would never terminate).
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    if(tokens < 0 || topk < 0)
        throw std::invalid_argument("topid_unique_gen: tokens and topk must be non-negative");

    // widen before multiplying so a large tokens * topk cannot overflow int
    const size_t total_size = static_cast<size_t>(topk) * static_cast<size_t>(tokens);

    if(total_size > 0 && topk > num_expert)
        throw std::invalid_argument("topid_unique_gen: topk must not exceed num_expert");
    if(host_tensor.size() < total_size)
        throw std::invalid_argument("topid_unique_gen: host_tensor is smaller than tokens * topk");

    std::srand(seed);
    std::set<IndexType> used_ids;
    for(size_t i = 0; i < total_size; i++)
    {
        // a new token starts every topk entries: forget the ids used by the previous one
        if(i % topk == 0)
        {
            used_ids.clear();
        }

        // rejection sampling: redraw until the id is new for this token
        IndexType candidate;
        do
        {
            candidate = std::rand() % num_expert;
        } while(!used_ids.insert(candidate).second);

        host_tensor[i] = candidate;
    }
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "3328", "tokens dimension")
|
||||
.insert("h", "4096", "hidden_size dimension")
|
||||
.insert("e", "32", "experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= hidden_size);
|
||||
|
||||
using TypeConfig = MoeSmoothquantTypeConfig<DataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using XScaleDataType = typename TypeConfig::XScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({experts * hidden_size});
|
||||
ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({topk * tokens}, {1});
|
||||
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({topk * tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({topk * tokens, hidden_size}, {stride, 1});
|
||||
|
||||
topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
topk_ids_buf.ToDevice(topk_ids_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
|
||||
<< ", experts:" << experts << ", topk:" << topk << std::flush;
|
||||
|
||||
moe_smoothquant_traits traits{data_type};
|
||||
|
||||
moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
hidden_size,
|
||||
experts,
|
||||
topk,
|
||||
stride,
|
||||
stride};
|
||||
|
||||
float ave_time = moe_smoothquant(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * tokens * hidden_size + sizeof(XScaleDataType) * topk * hidden_size +
|
||||
sizeof(YScaleDataType) * topk * tokens + sizeof(QYDataType) * topk * tokens * hidden_size;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
using YDataType = ComputeDataType;
|
||||
ck_tile::HostTensor<ComputeDataType> y_host({topk * tokens, hidden_size}, {stride, 1});
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto i_token) {
|
||||
for(int i_topk = 0; i_topk < topk; i_topk++)
|
||||
{
|
||||
auto i_expert = topk_ids_host(i_token, i_topk);
|
||||
|
||||
for(int i_h = 0; i_h < hidden_size; ++i_h)
|
||||
{
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(
|
||||
xscale_host(i_expert * hidden_size + i_h));
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
|
||||
// y_host(i_token * topk + i_topk, i_h) = v_x * v_xscale;
|
||||
y_host(i_topk * tokens + i_token, i_h) = v_x * v_xscale;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, tokens)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
// yscale
|
||||
{
|
||||
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({topk * tokens});
|
||||
|
||||
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
|
||||
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
|
||||
y_host, y_rowwise_amax_host, ReduceAmax{});
|
||||
|
||||
auto op = [](const auto& v0) {
|
||||
return v0 /
|
||||
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
|
||||
};
|
||||
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
|
||||
y_rowwise_amax_host, yscale_host_ref, op);
|
||||
|
||||
yscale_buf.FromDevice(yscale_host_dev.mData.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<YScaleDataType>();
|
||||
pass &= ck_tile::check_err(yscale_host_dev,
|
||||
yscale_host_ref,
|
||||
std::string("yscale Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
// rowwise quantization
|
||||
{
|
||||
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
|
||||
y_host, yscale_host_ref, qy_host_ref);
|
||||
|
||||
qy_buf.FromDevice(qy_host_dev.data());
|
||||
auto [rtol, atol] = get_elimit<QYDataType>();
|
||||
|
||||
if(stride == hidden_size)
|
||||
{
|
||||
pass = ck_tile::check_err(qy_host_dev,
|
||||
qy_host_ref,
|
||||
std::string("qy Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < topk * tokens; i_r++)
|
||||
{
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
|
||||
qy_host_dev.begin() + i_r * stride +
|
||||
hidden_size);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
|
||||
qy_host_ref.begin() + i_r * stride +
|
||||
hidden_size);
|
||||
pass &= ck_tile::check_err(qy_host_dev_row,
|
||||
qy_host_ref_row,
|
||||
std::string("qy[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
114
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
Normal file
114
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
Normal file
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/smoothquant.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
struct MoeSmoothquantTypeConfig;
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
struct moe_smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct moe_smoothquant_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
|
||||
37
example/ck_tile/14_moe_smoothquant/script/perf_test.sh
Executable file
37
example/ck_tile/14_moe_smoothquant/script/perf_test.sh
Executable file
@@ -0,0 +1,37 @@
#!/bin/sh
# Performance sweep for the moe_smoothquant example: runs the example binary with CPU
# validation enabled (-v=1) and 1000 timed iterations over a list of hidden sizes,
# first for bf16 and then for fp16.
EXE=build/bin/tile_example_moe_smoothquant

# single-token, single-element case (bf16 only)
$EXE -t=1 -h=1 -v=1 -prec=bf16 -repeat=1000

for prec in bf16 fp16 ; do
for h in 80 128 144 168 184 256 288 344 376 448 512 924 1024 1078 1996 4080 ; do
$EXE -t=700 -h=$h -v=1 -prec=$prec -repeat=1000
done
done
|
||||
30
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
Executable file
30
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/sh
|
||||
EXE=build/bin/tile_example_moe_smoothquant
|
||||
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
$EXE -prec=$pr_i -t=99 -h=13
|
||||
$EXE -prec=$pr_i -t=17 -h=16
|
||||
$EXE -prec=$pr_i -t=1 -h=100
|
||||
$EXE -prec=$pr_i -t=4 -h=128
|
||||
$EXE -prec=$pr_i -t=80 -h=127
|
||||
$EXE -prec=$pr_i -t=22 -h=255 -stride=256
|
||||
$EXE -prec=$pr_i -t=7 -h=599
|
||||
$EXE -prec=$pr_i -t=19 -h=512
|
||||
$EXE -prec=$pr_i -t=33 -h=313 -stride=1000
|
||||
$EXE -prec=$pr_i -t=11 -h=510
|
||||
$EXE -prec=$pr_i -t=171 -h=676 -stride=818
|
||||
$EXE -prec=$pr_i -t=91 -h=636
|
||||
$EXE -prec=$pr_i -t=12 -h=768 -stride=800
|
||||
$EXE -prec=$pr_i -t=100 -h=766 -stride=812
|
||||
$EXE -prec=$pr_i -t=31 -h=1024
|
||||
$EXE -prec=$pr_i -t=64 -h=1000 -stride=1004
|
||||
$EXE -prec=$pr_i -t=8 -h=1501
|
||||
$EXE -prec=$pr_i -t=3 -h=1826
|
||||
$EXE -prec=$pr_i -t=5 -h=2040
|
||||
$EXE -prec=$pr_i -t=7 -h=2734
|
||||
$EXE -prec=$pr_i -t=1 -h=3182
|
||||
$EXE -prec=$pr_i -t=9 -h=4096
|
||||
$EXE -prec=$pr_i -t=3 -h=8192
|
||||
$EXE -prec=$pr_i -t=1 -h=10547
|
||||
$EXE -prec=$pr_i -t=3 -h=17134
|
||||
done
|
||||
19
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
19
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
|
||||
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
|
||||
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
|
||||
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS})
|
||||
69
example/ck_tile/15_fused_moe/README.md
Normal file
69
example/ck_tile/15_fused_moe/README.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# fused-moe
|
||||
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance.
|
||||

|
||||
|
||||
The benefits of this fused-moe:
|
||||
* 1.5~2x perf boost compared with current vllm solution
|
||||
* zero workspace to reduce memory footprint
|
||||
* far fewer kernel instances, so it is easier to maintain
|
||||
|
||||
# Implementation and feature support
|
||||
## moe-sorting
|
||||
This is a common pre-processing step before the actual moe-gemm. Its purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup works on a single expert (B matrix). In addition, we extend this op to zero the output buffer (which is then used as the reduce buffer with atomics).
|
||||
|
||||
## moe-gemm
|
||||
`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of each input token comes from another buffer. A naive view of fused-moe is token-by-token, as in the picture below:
|
||||

|
||||
After `moe-sorting`, we can view this algorithm as expert-by-expert, as below:
|
||||

|
||||
|
||||
## optimization
|
||||
summary of the key design of this fused-moe operator:
|
||||
* fuse 2 group-gemm + activation + `topk-weight` multiply into a single kernel, using atomics for the 2nd gemm accumulation
|
||||
* fuse buffer-zeroing into `moe-sorting`, so the user no longer needs to call an extra torch.zero() for the out buffer
|
||||
* fused scatter-gather for row index(same as vllm)
|
||||
* pre-shuffle the B matrix (weight) to maximize memory throughput. The input (activation) keeps its original layout `[batch, hidden]`.
|
||||
* extremely optimized pipeline using block-inline-asm (we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck
|
||||
|
||||
##
|
||||
```
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
|
||||
```
|
||||
52
example/ck_tile/15_fused_moe/fused_moe.hpp
Normal file
52
example/ck_tile/15_fused_moe/fused_moe.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
#include "fused_moegemm.hpp"
|
||||
|
||||
struct fused_moe_args
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* a_scale_ptr; // [m, 1], token scale
|
||||
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
|
||||
const void* topk_ids_ptr; // [tokens, topk]
|
||||
const void* topk_weight_ptr; // [tokens, topk]
|
||||
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
|
||||
void* sorted_weight_ptr; // [max_num_tokens_padded]
|
||||
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
void* num_sorted_tiles_ptr; // [1]
|
||||
|
||||
ck_tile::index_t block_m; // block_m, used to devide the input
|
||||
ck_tile::index_t hidden_size; // k
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
||||
ck_tile::index_t num_experts; // number of groups
|
||||
ck_tile::index_t topk; // need this?
|
||||
|
||||
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
// Compile-time-like selectors (given as strings/ints at runtime) used to pick a kernel instance.
struct fused_moe_traits
{
    // precisions, spelled as strings such as "bf16" / "fp32"
    std::string prec_i;  ///< input precision
    std::string prec_w;  ///< weight precision
    std::string prec_o;  ///< output precision
    std::string prec_st; ///< token scale data type
    std::string prec_sw; ///< weight scale data type
    std::string prec_sq; ///< smooth quant scale
    std::string prec_kw; ///< topk-weight data type

    int block_m;
    int gate_only;
    int fused_quant; ///< 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
||||
|
||||
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
|
||||
84
example/ck_tile/15_fused_moe/fused_moegemm.hpp
Normal file
84
example/ck_tile/15_fused_moe/fused_moegemm.hpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
#include <string>
|
||||
|
||||
// this is only a convenient structure for creating an example
|
||||
// this is not part of the host API
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig;
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using GDataType = ck_tile::bf16_t;
|
||||
using DDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::fp16_t;
|
||||
using GDataType = ck_tile::fp16_t;
|
||||
using DDataType = ck_tile::fp16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::fp16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using GDataType = ck_tile::int8_t;
|
||||
using DDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
// Example-side argument type for fused_moegemm(): inherits publicly from
// ck_tile::FusedMoeGemmHostArgs and adds no members of its own.
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
// Selectors used by fused_moegemm() to pick a kernel instance.
struct fused_moegemm_traits
{
    // precisions, spelled as strings such as "bf16" / "fp32"
    std::string prec_i;  ///< input precision
    std::string prec_w;  ///< weight precision
    std::string prec_o;  ///< output precision
    std::string prec_st; ///< token scale data type
    std::string prec_sw; ///< weight scale data type
    std::string prec_sq; ///< smooth quant scale
    std::string prec_kw; ///< topk-weight data type

    int block_m;
    int gate_only;
    int fused_quant; ///< 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
||||
|
||||
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
|
||||
20
example/ck_tile/15_fused_moe/fused_moesorting.hpp
Normal file
20
example/ck_tile/15_fused_moe/fused_moesorting.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
// Selectors used by fused_moesorting() to pick a sorting kernel.
struct fused_moesorting_trait
{
    std::string index_type;  ///< index data type, given as a string
    std::string weight_type; ///< currently always float
};
|
||||
|
||||
// Example-side argument type for fused_moesorting(): inherits publicly from
// ck_tile::MoeSortingHostArgs and adds no members of its own.
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s);
|
||||
80
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
Normal file
80
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
Normal file
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
// Run moe-sorting followed by moe-gemm through a single ck_tile::launch_kernel() call.
//
// t : precision / tiling selectors, forwarded to fused_moegemm()
// a : runtime pointers and sizes, split below into sorting args and gemm args
// s : stream config used for the outer launch
//
// Returns the value reported by ck_tile::launch_kernel(), or -1 if either inner call
// reported a negative (unsupported) result.
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
{
    // Config handed to the two inner calls: same stream id and log level as `s`.
    // NOTE(review): the remaining fields (false, 0, 1) presumably disable inner timing so that
    // only the outer launch is timed -- confirm against ck_tile::stream_config.
    const auto inner_cfg = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};

    // Size in bytes of one output element; int8/fp8 and any unrecognized string map to 1.
    int out_elem_bytes = 1;
    if(t.prec_o == "fp32")
        out_elem_bytes = 4;
    else if(t.prec_o == "fp16" || t.prec_o == "bf16")
        out_elem_bytes = 2;

    const auto sort_traits = fused_moesorting_trait{"int32", "fp32"};
    const auto sort_args   = fused_moesorting_args{
        a.topk_ids_ptr,                                // const void* p_topk_ids;
        a.topk_weight_ptr,                             // const void* p_weights;
        a.sorted_token_ids_ptr,                        // void* p_sorted_token_ids;
        a.sorted_weight_ptr,                           // void* p_sorted_weights;
        a.sorted_expert_ids_ptr,                       // void* p_sorted_expert_ids;
        a.num_sorted_tiles_ptr,                        // void* p_total_tokens_post_pad;
        a.o_ptr,                                       // void* p_moe_buf;
        a.num_tokens,                                  // index_t tokens;
        a.block_m,                                     // index_t unit_size;
        a.num_experts,                                 // index_t num_experts;
        a.topk,                                        // index_t topk;
        a.num_tokens * a.stride_token * out_elem_bytes // index_t moe_buf_bytes;
    };

    const auto gemm_traits = fused_moegemm_traits{t.prec_i,
                                                  t.prec_w,
                                                  t.prec_o,
                                                  t.prec_st,
                                                  t.prec_sw,
                                                  t.prec_sq,
                                                  t.prec_kw,
                                                  t.block_m,
                                                  t.gate_only,
                                                  t.fused_quant};
    const auto gemm_args   = fused_moegemm_args{
        a.a_ptr,                 // const void* a_ptr;
        a.a_scale_ptr,           // const void* a_scale_ptr;
        a.g_ptr,                 // const void* g_ptr;
        a.d_ptr,                 // const void* d_ptr;
        a.g_scale_ptr,           // const void* g_scale_ptr;
        a.d_scale_ptr,           // const void* d_scale_ptr;
        a.y_smooth_scale_ptr,    // const void* y_smooth_scale_ptr;
        a.o_ptr,                 // void* o_ptr;
        a.sorted_token_ids_ptr,  // const void* sorted_token_ids_ptr;
        a.sorted_weight_ptr,     // const void* sorted_weight_ptr;
        a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
        a.num_sorted_tiles_ptr,  // const void* num_sorted_tiles_ptr;
        a.hidden_size,           // index_t hidden_size;
        a.intermediate_size,     // index_t intermediate_size;
        a.num_tokens,            // index_t num_tokens;
        a.num_experts,           // index_t num_experts;
        a.topk,                  // index_t topk;
        a.stride_token           // index_t stride_token;
    };

    // Results of the two inner calls; they stay negative if a stage is unsupported.
    float sort_result = -1;
    float gemm_result = -1;

    const float total = ck_tile::launch_kernel(
        s,
        [sort_traits, sort_args, inner_cfg, &sort_result](const ck_tile::stream_config&) {
            sort_result = fused_moesorting(sort_traits, sort_args, inner_cfg);
        },
        [gemm_traits, gemm_args, inner_cfg, &gemm_result](const ck_tile::stream_config&) {
            gemm_result = fused_moegemm(gemm_traits, gemm_args, inner_cfg);
        });

    // keep unsupported case return negative
    const bool both_supported = !(sort_result < 0 || gemm_result < 0);
    return both_supported ? total : -1;
}
|
||||
33
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
33
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
// Dispatch a moe-gemm request to one of the explicitly instantiated kernels.
//
// t : requested precisions / tiling; must match one of the instances below exactly
// a : runtime arguments, forwarded unchanged
// s : stream config, forwarded unchanged
//
// Returns the result of the selected fused_moegemm_<> instance, or -1 when no instance matches.
// Both instances are built with GateOnly_ = 1 and FusedQuant_ = 0, so a request with any other
// fused_quant value is reported as unsupported instead of silently running a non-quant kernel.
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
    // clang-format off
    float r = -1;
    if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 &&
       t.fused_quant == 0)
    {
        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
        r = fused_moegemm_<t_>(s, a);
    }
    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 &&
            t.fused_quant == 0)
    {
        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
        r = fused_moegemm_<t_>(s, a);
    }
    // clang-format on
    return r;
}
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
|
||||
template <typename Ts_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
{
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
|
||||
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
|
||||
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a);
|
||||
constexpr dim3 blocks = f_kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
static int printed = 0;
|
||||
|
||||
auto kargs = f_kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0 && printed == 0)
|
||||
{
|
||||
std::cout << ", " << f_kernel::GetName() << std::flush;
|
||||
printed = 1;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename I,
|
||||
typename W,
|
||||
typename O,
|
||||
typename ST,
|
||||
typename SW,
|
||||
typename SQ,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
{
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
|
||||
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
|
||||
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
|
||||
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
|
||||
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
|
||||
static constexpr ck_tile::index_t BI_ =
|
||||
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
|
||||
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
|
||||
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
|
||||
|
||||
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "fused_moegemm_api_internal.hpp"
|
||||
|
||||
// clang-format off
|
||||
// Explicit instantiation of fused_moegemm_<> for the bf16/bf16/bf16 traits with float scale
// types, block tile S<32, 512, 128, 128>, GateOnly_ = 1 and FusedQuant_ = 0.
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "fused_moegemm_api_internal.hpp"
|
||||
|
||||
// clang-format off
|
||||
// Explicit instantiation of fused_moegemm_<> for the fp16/fp16/fp16 traits with float scale
// types, block tile S<32, 512, 128, 128>, GateOnly_ = 1 and FusedQuant_ = 0.
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
// Body of one dispatch case in fused_moesorting(): builds the MoeSortingProblem /
// MoeSortingKernel types for the given unroll factor, creates kernel args, grid, block and LDS
// sizes from `a`, launches the kernel with stream config `s`, and returns the launch result.
// The expansion site must provide `index_t`, `ms_weight_type`, `a` and `s`.
// The macro ends with `return`, so nothing after an expansion is executed.
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
// Launch the moe-sorting kernel.
//
// t : requested types; only weight_type "fp32" with index_type "int32" is handled
// a : host arguments (tokens, topk, num_experts, moe_buf_bytes, pointers, ...)
// s : stream config used for the launch
//
// Returns the result of ck_tile::launch_kernel(), or -1 when the type combination is not
// handled, num_experts is larger than 127, or moe_buf_bytes is not a multiple of 16.
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
        if(a.num_experts > 127)
        {
            // the check rejects only values above 127, so 127 itself is accepted
            printf("lds size exceed, only support experts <= 127\n");
            return -1;
        }
        if(a.moe_buf_bytes % 16)
        {
            printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
            return -1;
        }
        using index_t        = ck_tile::index_t;
        using ms_weight_type = float;

        // unroll factor = ceil(tokens * topk / 64); values outside 1..11 use the default case
        index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);

        // Every case returns from inside MOE_SORTING_DISPATCH, so there is no fall-through.
        // The braces give each expansion its own scope for the names the macro declares.
        switch(smem_io_unroll_num)
        {
        case(1): {
            MOE_SORTING_DISPATCH(1);
        }
        case(2): {
            MOE_SORTING_DISPATCH(2);
        }
        case(3): {
            MOE_SORTING_DISPATCH(3);
        }
        case(5): {
            MOE_SORTING_DISPATCH(5);
        }
        case(6): {
            MOE_SORTING_DISPATCH(6);
        }
        case(7): {
            MOE_SORTING_DISPATCH(7);
        }
        case(8): {
            MOE_SORTING_DISPATCH(8);
        }
        case(9): {
            MOE_SORTING_DISPATCH(9);
        }
        case(10): {
            MOE_SORTING_DISPATCH(10);
        }
        case(11): {
            MOE_SORTING_DISPATCH(11);
        }
        default: {
            MOE_SORTING_DISPATCH(4);
        }
        }
    }
    return -1;
}
|
||||
603
example/ck_tile/15_fused_moe/main.cpp
Normal file
603
example/ck_tile/15_fused_moe/main.cpp
Normal file
@@ -0,0 +1,603 @@
|
||||
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <set>
#include <stdexcept>
#include <unordered_set>
#include <vector>

#include "ck_tile/host.hpp"
#include "fused_moe.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
// TODO: padding?
|
||||
template <typename T>
|
||||
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
{
|
||||
assert(t.get_lengths().size() == 3);
|
||||
int b_ = t.get_lengths()[0];
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[2];
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
// Fill `host_tensor` with `tokens` consecutive rows of `topk` expert ids each.
// Ids are drawn with std::rand() (seeded by `seed`) from [0, num_expert) and are distinct
// within one row; the same id may appear in different rows.
//
// host_tensor : output, must hold at least tokens * topk elements
// tokens/topk : row count and row length; nothing is written if either is <= 0
// num_expert  : number of distinct ids available, must be >= topk
//
// Throws std::invalid_argument if topk > num_expert (a row of distinct ids cannot exist and
// the rejection loop would never finish) or if host_tensor is too small.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    std::srand(seed);

    if(tokens <= 0 || topk <= 0)
        return;
    if(topk > num_expert)
        throw std::invalid_argument("topid_unique_gen: topk must not exceed num_expert");

    // widen before multiplying so the element count cannot overflow int
    const std::size_t row_len    = static_cast<std::size_t>(topk);
    const std::size_t total_size = row_len * static_cast<std::size_t>(tokens);
    if(host_tensor.size() < total_size)
        throw std::invalid_argument("topid_unique_gen: host_tensor is smaller than tokens * topk");

    std::set<IndexType> unique_set;
    for(std::size_t i = 0; i < total_size; i++)
    {
        // a new row starts: ids only need to be unique within one token
        if(i % row_len == 0)
        {
            unique_set.clear();
        }
        // rejection sampling: redraw until the id is new for this row
        IndexType current_v;
        do
        {
            current_v = std::rand() % num_expert;
        } while(!unique_set.insert(current_v).second);
        host_tensor[i] = current_v;
    }
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "128", "num input tokens")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
.insert("bm", "32", "blocking factor for sorted tokens")
|
||||
.insert("tp", "8", "tensor parallel size")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "bf16", "input precision")
|
||||
.insert("prec_w", "bf16", "weight precision")
|
||||
.insert("prec_o", "bf16", "output precision")
|
||||
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
|
||||
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
|
||||
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
|
||||
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
|
||||
.insert("balance",
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"2",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type,
|
||||
// SQ:smooth-quant-type, KW:topk-weight-type
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int gate_only = arg_parser.get_int("gate_only");
|
||||
int api = arg_parser.get_int("api");
|
||||
int balance = arg_parser.get_int("balance");
|
||||
int tp = arg_parser.get_int("tp");
|
||||
int init = arg_parser.get_int("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
|
||||
// w0 (Gate+Up or Gate only, N size)
|
||||
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
base_str += "x" + prec_w;
|
||||
if(prec_i != prec_o)
|
||||
base_str += "=" + prec_o;
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
auto api_str = [&]() {
|
||||
if(api == 0)
|
||||
return std::string("fmoe");
|
||||
else if(api == 1)
|
||||
return std::string("moeg");
|
||||
else if(api == 2)
|
||||
return std::string("moes");
|
||||
return std::string("");
|
||||
}();
|
||||
|
||||
auto stride_str = [&]() {
|
||||
if(stride == hidden_size)
|
||||
return std::string("");
|
||||
else
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
|
||||
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using GDataType = typename TypeConfig::GDataType;
|
||||
using DDataType = typename TypeConfig::DDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using AScaleDataType = typename TypeConfig::AScaleDataType;
|
||||
using GScaleDataType = typename TypeConfig::GScaleDataType;
|
||||
using DScaleDataType = typename TypeConfig::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
|
||||
using IndexDataType = typename TypeConfig::IndexDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
|
||||
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
|
||||
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
|
||||
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
|
||||
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
|
||||
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
|
||||
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
|
||||
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
|
||||
|
||||
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
|
||||
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
|
||||
{(max_num_tokens_padded + block_m - 1) / block_m});
|
||||
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
|
||||
|
||||
if(init == 0)
|
||||
{
|
||||
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
|
||||
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
|
||||
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
|
||||
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
|
||||
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
|
||||
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
|
||||
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
|
||||
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
|
||||
}
|
||||
else if(init == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
|
||||
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
|
||||
topk_weight_host);
|
||||
}
|
||||
else if(init == 2)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
|
||||
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
|
||||
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
|
||||
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
|
||||
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
|
||||
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
|
||||
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
|
||||
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
|
||||
}
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
|
||||
// do moe sorting
|
||||
if(balance)
|
||||
{
|
||||
int e_cnt = 0;
|
||||
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
|
||||
{
|
||||
topk_ids_host.mData[i] = e_cnt;
|
||||
e_cnt++;
|
||||
if(e_cnt >= experts)
|
||||
e_cnt = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
|
||||
}
|
||||
|
||||
// leave it here for future debug purpose
|
||||
#if 0
|
||||
a_host.loadtxt("../../ater/input_torch.txt");
|
||||
|
||||
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
|
||||
// topk_ids_host.savetxt("topk_ids_2.txt");
|
||||
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
g_host.loadtxt("../../ater/w1_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
d_host.loadtxt("../../ater/w2_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
|
||||
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
|
||||
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
|
||||
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
|
||||
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
|
||||
#endif
|
||||
auto cal_tflops = [&](auto ms) {
|
||||
double flop_gemm_0 =
|
||||
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
|
||||
double flop_gemm_1 =
|
||||
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
|
||||
return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
// TODO: this method we use expert-by-expert view, just for reference
|
||||
auto cal_tbps = [&](auto ms) {
|
||||
double token_bytes =
|
||||
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
|
||||
double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
|
||||
sizeof(GDataType);
|
||||
double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
|
||||
sizeof(DDataType);
|
||||
double o_bytes =
|
||||
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
|
||||
double topk_weights_bytes = static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
|
||||
// ignore index, they are too small
|
||||
|
||||
return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
|
||||
(static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
if(api == 0)
|
||||
{
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
|
||||
ck_tile::DeviceMem topk_weight_buf(topk_weight_host);
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(
|
||||
sorted_token_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moe_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
block_m,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
float ave_time = fused_moe(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
return pass;
|
||||
}
|
||||
else if(api == 1)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
//
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
||||
|
||||
fused_moegemm_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
BIN
example/ck_tile/15_fused_moe/misc/moe-0.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 75 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-1.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 90 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-2.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 124 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-3.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user