mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Merge remote-tracking branch 'origin/develop' into andriy/ck_tile/basic-tutorials
This commit is contained in:
@@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8, 32, 1>,
|
||||
S<8, 16, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -174,6 +174,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
|
||||
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
|
||||
|
||||
@@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8, 32, 1>,
|
||||
S<8, 16, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -170,6 +170,28 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
|
||||
|
||||
@@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<4, 64, 1>,
|
||||
S<4, 16, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -233,6 +233,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideD = f_get_default_stride(M, N, StrideD, DLayout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
@@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8, 32, 1>,
|
||||
S<8, 16, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
@@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -173,6 +173,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
|
||||
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
|
||||
@@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convinvscale)
|
||||
add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# WMMA
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
add_custom_target(example_convnd_activ_wmma_convinvscale)
|
||||
add_example_executable(example_convnd_fwd_wmma_convinvscale_fp8 convnd_fwd_wmma_convinvscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convinvscale example_convnd_fwd_wmma_convinvscale_fp8)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convinvscale_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = ck::f8_t;
|
||||
using BComputeDataType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvInvscale;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
DsLayout, // DsLayout (empty tuple for ConvInvScale)
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
DsDataType, // DsDataType (empty tuple)
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convinvscale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -15,3 +15,19 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8)
|
||||
endif()
|
||||
|
||||
# WMMA
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
add_custom_target(example_convnd_activ_wmma_convscale)
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_fp8 convnd_fwd_wmma_convscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_bf8 convnd_fwd_wmma_convscale_bf8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_fp8_bf8 convnd_fwd_wmma_convscale_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8_bf8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_bf8_fp8 convnd_fwd_wmma_convscale_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8_fp8)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::bf8_t;
|
||||
using WeiDataType = ck::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = InDataType;
|
||||
using BComputeDataType = AComputeDataType;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScale;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
DsLayout, // DsLayout (empty tuple for ConvScale)
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
DsDataType, // DsDataType (empty tuple)
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convscale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::bf8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = ck::bf8_t;
|
||||
using BComputeDataType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScale;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
DsLayout, // DsLayout (empty tuple for ConvScale)
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
DsDataType, // DsDataType (empty tuple)
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convscale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = ck::f8_t;
|
||||
using BComputeDataType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScale;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
DsLayout, // DsLayout (empty tuple for ConvScale)
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
DsDataType, // DsDataType (empty tuple)
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convscale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convscale_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = ck::f8_t;
|
||||
using BComputeDataType = ck::bf8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScale;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
DsLayout, // DsLayout (empty tuple for ConvScale)
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
DsDataType, // DsDataType (empty tuple)
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convscale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_add)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# WMMA
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
add_custom_target(example_convnd_activ_wmma_convscale_add)
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_add_fp8 convnd_fwd_wmma_convscale_add_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale_add example_convnd_fwd_wmma_convscale_add_fp8)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "convnd_fwd_convscale_add_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = float;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = ck::f8_t;
|
||||
using BComputeDataType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScaleAdd;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
ck::Tuple<DsLayout>, // DsLayout
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
ck::Tuple<DsDataType>, // DsDataType
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convscale_add_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -8,4 +8,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# WMMA
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
add_custom_target(example_convnd_activ_wmma_convscale_reduce)
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_amax_fp8 convnd_fwd_wmma_convscale_amax_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale_reduce example_convnd_fwd_wmma_convscale_amax_fp8)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convscale_reduce_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using ConvOutDataType = float; // data type of convolution result
|
||||
using OutDataType = ck::f8_t; // data type of final result
|
||||
using AComputeDataType = ck::f8_t;
|
||||
using BComputeDataType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScale;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ConvOutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -6,3 +6,10 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8)
|
||||
endif()
|
||||
|
||||
# WMMA
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
add_custom_target(example_convnd_activ_wmma_convscale_relu)
|
||||
add_example_executable(example_convnd_fwd_wmma_convscale_relu_fp8 convnd_fwd_wmma_convscale_relu_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_wmma_convscale_relu example_convnd_fwd_wmma_convscale_relu_fp8)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_convscale_relu_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using OutDataType = ck::f8_t;
|
||||
using AComputeDataType = ck::f8_t;
|
||||
using BComputeDataType = ck::f8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ConvScaleRelu;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
DsLayout, // DsLayout (empty tuple for ConvScaleRelu)
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
DsDataType, // DsDataType (empty tuple)
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
#include "run_convnd_fwd_convscale_relu_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
@@ -37,4 +37,10 @@ add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fw
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16)
|
||||
# Logistic
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16)
|
||||
|
||||
# WMMA
|
||||
add_custom_target(example_convnd_activ_dynamic_unary_wmma)
|
||||
# PassThrough
|
||||
add_example_executable(example_convnd_fwd_wmma_dynamic_passthrough_fp16 convnd_fwd_wmma_dynamic_passthrough_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_passthrough_fp16)
|
||||
|
||||
@@ -0,0 +1,245 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
constexpr ck::index_t NDimSpatial = 3;
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AComputeDataType = ck::half_t;
|
||||
using BComputeDataType = ck::half_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Use correct tensor layouts for WMMA (matching working tests)
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DynamicElementOp = ck::tensor_operation::element_wise::DynamicUnaryOp;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial, // NDimSpatial
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
DynamicElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
true, // UseThreadTileTransfer
|
||||
AComputeDataType, // AComputeDataType
|
||||
BComputeDataType, // BComputeDataType
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename DeviceConvNDFwdInstance>
|
||||
bool run_grouped_conv(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
const ck::HostTensorDescriptor& in_g_n_c_wis_desc,
|
||||
const ck::HostTensorDescriptor& wei_g_k_c_xs_desc,
|
||||
const ck::HostTensorDescriptor& out_g_n_k_wos_desc,
|
||||
const InElementOp& in_element_op,
|
||||
const WeiElementOp& wei_element_op,
|
||||
const OutElementOp& out_element_op)
|
||||
{
|
||||
ck::Tensor<InDataType> in(in_g_n_c_wis_desc);
|
||||
ck::Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
|
||||
ck::Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
|
||||
ck::Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
|
||||
|
||||
std::cout << "in: " << in.mDesc << std::endl;
|
||||
std::cout << "wei: " << wei.mDesc << std::endl;
|
||||
std::cout << "out: " << out_host.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
|
||||
}
|
||||
|
||||
ck::DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
|
||||
|
||||
in_device_buf.ToDevice(in.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
// do Conv
|
||||
auto conv = DeviceConvNDFwdInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 0>{},
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error("The device op with the specified compilation parameters does "
|
||||
"not support this convolution problem.");
|
||||
}
|
||||
|
||||
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< conv.GetTypeString() << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in,
|
||||
wei,
|
||||
out_host,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
out_device_buf.FromDevice(out_device.mData.data());
|
||||
|
||||
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-3, 0.1);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp"
|
||||
|
||||
#include "../run_convnd_activ_dynamic_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough out_element_op;
|
||||
return !run_convnd_example(argc, argv, out_element_op);
|
||||
}
|
||||
@@ -47,6 +47,12 @@ bool run_convnd_example(int argc, char* argv[], const OutElementOp& out_element_
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
if(std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::SoftRelu> &&
|
||||
init_method != 2)
|
||||
{
|
||||
std::cout << "Running SoftRelu op with int initialization. Risk of overflow.\n\n";
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ struct bias_info
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
||||
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
|
||||
@@ -78,12 +78,14 @@ QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -630,6 +630,7 @@ class KernelComponentFactory:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
return {
|
||||
@@ -676,7 +677,7 @@ class KernelComponentFactory:
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
["pertensor", "kv_blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
@@ -739,6 +740,10 @@ def get_fwd_blobs(
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
|
||||
@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
|
||||
// KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
|
||||
// k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
|
||||
// k_descale_ptr: [num_block, num_kv_head] - points to k block descale
|
||||
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ struct mask_info
|
||||
return area;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
|
||||
@@ -8,12 +8,16 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
// keep sync with BlockAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale,
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -28,6 +32,8 @@ struct quant_scale_info
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
else if(type == quant_scale_enum::kv_blockscale)
|
||||
os << "kvbs";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -45,6 +51,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else if(str == "kvbs" || str == "3")
|
||||
{
|
||||
info.type = quant_scale_enum::kv_blockscale;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
@@ -58,3 +68,4 @@ struct quant_scale_info
|
||||
return os;
|
||||
}
|
||||
};
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -59,7 +59,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -202,7 +203,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -44,7 +44,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -210,7 +211,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -21,7 +21,6 @@ if(has_supported_gpu)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
|
||||
add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
@@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack = 16 * packed_size; // fp4:32 or fp8:16
|
||||
int NLane = N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
int KPack =
|
||||
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
int NLane = N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
@@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
}
|
||||
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
|
||||
{
|
||||
throw std::runtime_error("fp6xfp6 is not supported.");
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
|
||||
ck_tile::pk_fp6x16_t,
|
||||
ck_tile::fp16_t,
|
||||
MXfp6_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
|
||||
@@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp6_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp8_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
|
||||
@@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST)
|
||||
set(C_LAYOUT ROW)
|
||||
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")
|
||||
|
||||
# foreach(PERSISTENT false true)
|
||||
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
|
||||
foreach(PERSISTENT false)
|
||||
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8)
|
||||
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
|
||||
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
|
||||
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
|
||||
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
using FP4 = ck_tile::pk_fp4_t;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using FP6 = ck_tile::pk_fp6x16_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
|
||||
@@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
|
||||
{
|
||||
auto a_buffer_bytes = a_host.get_element_space_size_in_bytes();
|
||||
auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes();
|
||||
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
|
||||
std::vector<int8_t> random_bufA(a_buffer_bytes);
|
||||
std::vector<int8_t> random_bufB(b_buffer_bytes);
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<int> dis(1, 4);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
|
||||
for(size_t i = 0; i < a_buffer_bytes; ++i)
|
||||
random_bufA[i] = static_cast<int8_t>(dis(gen));
|
||||
|
||||
for(size_t i = 0; i < b_buffer_bytes; ++i)
|
||||
random_bufB[i] = static_cast<int8_t>(dis(gen));
|
||||
|
||||
memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes);
|
||||
memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! Unexpected init_method");
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! Unexpected init_method");
|
||||
}
|
||||
}
|
||||
|
||||
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
|
||||
|
||||
@@ -134,5 +134,65 @@ static auto _ = []() {
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"or bf8i4; for ABQuant: fp8, bf8")
|
||||
"or bf8i4; for ABQuant: fp8, bf8, fp4")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -80,7 +80,8 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool APreshuffleQuant = false;
|
||||
static constexpr bool BPreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
@@ -157,7 +158,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool APreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -187,7 +189,7 @@ template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
|
||||
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -218,7 +220,7 @@ template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -272,7 +274,7 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -33,11 +34,11 @@ template <typename GemmConfig,
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::ADataType>;
|
||||
constexpr bool transpose_c =
|
||||
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
|
||||
// Use automatically determined compute type from
|
||||
using ComputeDataType = void;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -50,14 +51,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
GemmConfig::APreshuffleQuant,
|
||||
GemmConfig::BPreshuffleQuant,
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout, // for AQLayout
|
||||
BQLayout, // for BQLayout
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
transpose_c,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
@@ -73,12 +75,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>>;
|
||||
|
||||
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
@@ -146,7 +151,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
std::conditional_t<GemmConfig::APreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
@@ -180,30 +185,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
@@ -212,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
// Split-K validation is handled by Kernel::IsSupportedArgument
|
||||
// Split-K is only supported for BQuantGrouped without preshuffle
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
@@ -390,8 +390,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false")
|
||||
<< " : "
|
||||
<< " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false")
|
||||
<< " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
@@ -536,21 +538,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
// Create BQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
@@ -561,8 +555,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
@@ -598,18 +591,26 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
@@ -657,184 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
// Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...)
|
||||
for(ck_tile::index_t row = 0;
|
||||
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
|
||||
++row)
|
||||
{
|
||||
for(ck_tile::index_t col = 0;
|
||||
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
|
||||
++col)
|
||||
{
|
||||
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
|
||||
}
|
||||
}
|
||||
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
@@ -870,7 +693,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
if constexpr(GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
|
||||
@@ -929,7 +752,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
|
||||
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
@@ -940,7 +763,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
|
||||
}
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
else if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
@@ -988,10 +811,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
std::cout << "Performing CPU verification..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Track start time for reference operation
|
||||
auto start_reference_tick = std::chrono::high_resolution_clock::now();
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
@@ -1055,6 +882,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
|
||||
// Track where we stop reference calculation, and start verification
|
||||
auto start_verification_tick = std::chrono::high_resolution_clock::now();
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
@@ -1065,6 +895,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
// "Stop" our timer
|
||||
auto verification_finished_tick = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
@@ -1072,6 +905,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
// Calculate and display reference timing
|
||||
using DurationType = std::chrono::duration<double>;
|
||||
double reference_sec = std::chrono::duration_cast<DurationType>(verification_finished_tick -
|
||||
start_reference_tick)
|
||||
.count();
|
||||
double verification_sec = std::chrono::duration_cast<DurationType>(
|
||||
verification_finished_tick - start_verification_tick)
|
||||
.count();
|
||||
float reference_msec = static_cast<float>(reference_sec * 1e3);
|
||||
float verification_msec = static_cast<float>(verification_sec * 1e3);
|
||||
|
||||
std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took "
|
||||
<< reference_msec << "ms, verification took " << verification_msec << "ms."
|
||||
<< std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
@@ -1102,6 +950,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_fp4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf16_t>)
|
||||
@@ -1121,7 +970,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
@@ -1142,7 +991,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
|
||||
!GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
|
||||
156
example/ck_tile/50_sparse_attn/CMakeLists.txt
Normal file
156
example/ck_tile/50_sparse_attn/CMakeLists.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# CMakeLists.txt for sparse attention (Jenga and VSA)
|
||||
|
||||
# Use SUPPORTED_GPU_TARGETS directly
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}")
|
||||
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
|
||||
if(NOT INST_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found")
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")
|
||||
|
||||
# Code generation scripts
|
||||
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
)
|
||||
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
|
||||
|
||||
# ============================================================================
|
||||
# Jenga Sparse Attention
|
||||
# ============================================================================
|
||||
set(SPARSE_ATTN_JENGA_CODE_GEN_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api fwd_jenga
|
||||
--receipt 600
|
||||
)
|
||||
|
||||
# Generate list of Jenga kernels (at configure time, only list)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to generate Jenga kernel list")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt SPARSE_ATTN_JENGA_GEN_BLOBS)
|
||||
|
||||
# Generate Jenga kernel source files at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${SPARSE_ATTN_JENGA_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile Jenga Sparse Attention kernels"
|
||||
)
|
||||
|
||||
message(STATUS "Jenga kernel files to be generated: ${SPARSE_ATTN_JENGA_GEN_BLOBS}")
|
||||
|
||||
# Jenga Instances
|
||||
set(SPARSE_ATTN_JENGA_INSTANCES "tile_sparse_attn_jenga_instances")
|
||||
|
||||
add_library(${SPARSE_ATTN_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${SPARSE_ATTN_JENGA_GEN_BLOBS}
|
||||
${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp
|
||||
)
|
||||
target_include_directories(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(${SPARSE_ATTN_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${SPARSE_ATTN_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE
|
||||
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
-DCK_TILE_FMHA_FWD_FAST_EXP2
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Jenga Example executable
|
||||
set(EXAMPLE_JENGA_SPARSE_ATTN "tile_example_jenga_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# VSA Sparse Attention
|
||||
# ============================================================================
|
||||
set(SPARSE_ATTN_VSA_CODE_GEN_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api fwd_vsa
|
||||
--receipt 600
|
||||
)
|
||||
|
||||
# Generate list of VSA kernels (at configure time, only list)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to generate VSA kernel list")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt SPARSE_ATTN_VSA_GEN_BLOBS)
|
||||
|
||||
# Generate VSA kernel source files at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${SPARSE_ATTN_VSA_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate CK Tile VSA Sparse Attention kernels"
|
||||
)
|
||||
|
||||
message(STATUS "VSA kernel files to be generated: ${SPARSE_ATTN_VSA_GEN_BLOBS}")
|
||||
|
||||
# VSA Instances
|
||||
set(SPARSE_ATTN_VSA_INSTANCES "tile_sparse_attn_vsa_instances")
|
||||
|
||||
add_library(${SPARSE_ATTN_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${SPARSE_ATTN_VSA_GEN_BLOBS}
|
||||
${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp
|
||||
)
|
||||
target_include_directories(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE
|
||||
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
-DCK_TILE_FMHA_FWD_FAST_EXP2
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# VSA Example executable
|
||||
set(EXAMPLE_VSA_SPARSE_ATTN "tile_example_vsa_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
3
example/ck_tile/50_sparse_attn/codegen/__init__.py
Normal file
3
example/ck_tile/50_sparse_attn/codegen/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
73
example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
Normal file
73
example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16": "FmhaSparseFwdFp16",
|
||||
"bf16": "FmhaSparseFwdBf16",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no": "FmhaMasks::NoMask",
|
||||
"causal": "FmhaMasks::CausalMask",
|
||||
"generic": "FmhaMasks::GenericMask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no": "t.mask_type == mask_enum::no_mask",
|
||||
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic": "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no": "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask": "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_check_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_CHECK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_CHECK_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
MODE_MAP = {"batch": "false"}
|
||||
|
||||
LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
|
||||
"qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t": "true",
|
||||
"f": "false",
|
||||
True: "true",
|
||||
False: "false",
|
||||
}
|
||||
3
example/ck_tile/50_sparse_attn/codegen/ops/__init__.py
Normal file
3
example/ck_tile/50_sparse_attn/codegen/ops/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
867
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py
Normal file
867
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py
Normal file
@@ -0,0 +1,867 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
|
||||
#include "kernel/fmha_fwd_jenga_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_jenga_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
tr_load: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
if self.bm0 == 128:
|
||||
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_trload: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
n += "_nskip"
|
||||
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_trload == "t":
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
|
||||
|
||||
per_tr_load = str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case = str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits = [
|
||||
t
|
||||
for t in self.pool[dtype][(hdim, hdim_v)]
|
||||
if tr_load == t.tr_load
|
||||
]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
# F_logits removed - hardcoded to false (NOT supported)
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
|
||||
F_if="if",
|
||||
F_trload_cond=tr_load_cond_map[tr_load],
|
||||
F_dtype_case=per_dtypes,
|
||||
)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=CppConstraint)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
# kernel_body removed - unused
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
# F_logits removed - hardcoded to false in template (NOT supported)
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||
F_kernel_name=self.name,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
    """Default provider of tile-size tables and pipeline lists for the jenga
    fwd kernel generator."""

    # TODO: design a more practical way to do it
    # Currently supported tile sizes, keyed by (hdim_q, hdim_v).
    # (Other hdim entries such as (32,32), (64,64), (96,128), (160,160),
    # (192,128), (192,192) and (256,256) existed historically and are
    # intentionally disabled for now.)
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype not in ("fp16", "bf16"):
            return None
        # FmhaFwdTileSize field order:
        #   bm0, bn0, bk0, bn1, bk1, bk0max,
        #   rm0, rn0, rk0, rm1, rn1, rk1,
        #   wm0, wn0, wk0, wm1, wn1, wk1, occupancy
        return {
            (128, 128): [
                FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            ],
        }

    # TODO: we don't support tuning yet, so pick up one value for
    # vlayout/pipeline/pad; support this in future.
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Populate the list of candidate pipelines.

        The order of the returned list matters: later entries are also
        checked later by the generated dispatch code.
        NOTE: logits soft-cap is NOT supported by Jenga sparse attention
        (enforced by static_assert in the C++ side).
        """
        if dtype not in ("fp16", "bf16"):
            assert False
        pipelines = []
        for logits, mask in itertools.product(
            ["f"],  # logits soft-cap NOT supported, always false
            get_mask_map(mask_impl).keys(),
        ):
            if hdim == 256 and hdim_v == 256:
                # jenga fmha only supports dim <= 192 for now.
                continue
            # Two variants per mask: skpad disabled first, then enabled.
            # FmhaFwdPipeline arg order: tag, vlayout, spad, skpad, dpad,
            # dvpad, logits, mask, trload.
            for skpad in ("f", "t"):
                pipelines.append(
                    FmhaFwdPipeline("qr_async", "row", "t", skpad, "t", "t", logits, mask, "f")
                )
        return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
    """Factory variant (enabled via CK_TILE_FMHA_FWD_CUSTOM_FACTORY=1) that
    prepends a 64-row tile guarded by a CU-utilization constraint."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and (128, 128) in result:
            extra_tile = FmhaFwdTileSize(
                64, 128, 64, 128, 64, 128,
                4, 1, 1, 4, 1, 1,
                16, 16, 16, 16, 16, 16,
                -1,
                # only dispatch this tile when the 128-row grid would
                # under-utilize the available compute units
                CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"),
            )
            # prepend so it is preferred over the default tiles
            result[(128, 128)].insert(0, extra_tile)
        return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate the jenga fwd kernel instances to generate.

    Returns ``(api_pool, kernels)``: the dispatch-table pool plus the list
    of kernel blobs that survived the receipt/filter/optdim screening.
    """
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )

    # Only generate fp16/bf16 kernels for now.
    # NOTE: Jenga sparse attention only supports batch mode (group mode NOT
    # supported, enforced by static_assert on the C++ side).
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                # only the 128x128 qr_async variants are emitted today
                if tile.F_bm0 != 128 or tile.F_bn0 != 128:
                    continue
                if pipeline.tag != "qr_async":
                    continue
                k = FmhaFwdKernel(
                    F_idx=2,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # BUGFIX: kernel_filter is Optional[str]; with `!= ""` a None
                # pattern reached fnmatch() and raised TypeError. Treat None
                # the same as "" (no filtering).
                if kernel_filter:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= mode == "batch"
                    cond &= pipeline.F_logits == "f"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue

                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Emit (or refresh) the generated .cpp file for one kernel instance."""
    target = autogen_dir / kernel.filename
    update_file(target, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Emit (or refresh) the generated dispatch-API .cpp file."""
    target = autogen_dir / FMHA_FWD_API_FILENAME
    update_file(target, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel source plus the dispatch API into output_dir."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the path of every file the generator would produce to file_path
    (one path per line), including the dispatch-API file."""
    with file_path.open("a") as f:
        _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
        gen_root = file_path.parent / GEN_DIR
        for k in kernels:
            f.write(f"{gen_root / k.filename}\n")
        f.write(f"{gen_root / FMHA_FWD_API_FILENAME}\n")
|
||||
867
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py
Normal file
867
example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py
Normal file
@@ -0,0 +1,867 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
import os.path as path
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cpp_symbol_map import (
|
||||
BOOL_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
MODE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
PIPELINE_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
)
|
||||
|
||||
GEN_DIR = ""
|
||||
|
||||
|
||||
def update_file(file_path, content):
    """Write `content` to `file_path` only when it differs from what is
    already on disk.

    Skipping identical writes keeps the file's mtime stable, which avoids
    triggering unnecessary rebuilds.
    """
    # Treat a missing file as holding "" — so writing "" to a missing file
    # is also a no-op, matching the comparison below.
    existing = ""
    if path.exists(file_path):
        with open(file_path, "r") as fh:
            existing = fh.read()
    if existing == content:
        return
    with open(file_path, "w") as fh:
        fh.write(content)
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
|
||||
#include "kernel/fmha_fwd_vsa_kernel.hpp"
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
|
||||
# - Group mode: NOT supported (batch mode only)
|
||||
# - Bias: NOT supported (NO_BIAS only)
|
||||
# - LSE output: NOT supported (false only)
|
||||
# - Dropout: NOT supported (false only)
|
||||
# - Logits soft-cap: NOT supported (false only)
|
||||
# - FP8 static quantization: NOT supported (NO_SCALE only)
|
||||
# The template below hardcodes these unsupported features accordingly.
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
|
||||
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false, // has_logits_soft_cap - NOT supported
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
|
||||
false, // store_lse - NOT supported
|
||||
false, // has_dropout - NOT supported
|
||||
false, // has_randval - NOT supported
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
|
||||
{F_occupancy},
|
||||
false>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << "{F_kernel_name}" << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
|
||||
return fmha_vsa_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class CppConstraint:
    """A C++ boolean guard expression attached to a kernel variant.

    A ``None`` expression means "unconstrained" and renders as ``true``.
    """

    bool_expr: str = None

    def __str__(self):
        return "true" if self.bool_expr is None else f"{self.bool_expr}"

    def __and__(self, other):
        # conjunction of two constraints; parenthesize both sides so the
        # rendered C++ keeps the intended precedence
        return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdApiTrait:
    """Per-kernel record consumed by the dispatch-API generator.

    Kept in sync with the C++ ``fmha_fwd_traits<>`` so fallback dispatch
    calls can be emitted.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # BUGFIX: the 4th tile field previously formatted self.bn0 twice and
        # never included bn1; the ordering must match bm0/bn0/bk0/bn1/bk1/bk0max.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        """C++ predicate guarding seqlen_q padding for this variant."""
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generates spad/skpad == true
        if self.spad == "t":
            return "true"  # always support
        return "true"

    @property
    def seqtune(self) -> str:
        """C++ predicate selecting this tile by runtime seqlen_q."""
        if self.bm0 == 128:
            # the 128-row tile is the unconditional fallback for any seqlen
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ predicate guarding seqlen_k padding for this variant."""
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generates spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        """C++ predicate on hdim_q alignment; only dpad == 't' variants are
        generated, hence the assert on the other path."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False

    @property
    def dvcheck(self) -> str:
        """C++ predicate on hdim_v alignment; only dvpad == 't' variants are
        generated, hence the assert on the other path."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdPipeline:
    """One pipeline configuration (pad flags, mask, layout, etc.)."""

    tag: str

    F_vlayout: str  # row/col
    F_spad: str  # t/f: pad along q seqlen
    F_skpad: str  # t/f: pad along k seqlen
    F_dpad: str  # t/f: pad along hdim_q
    F_dvpad: str  # t/f: pad along hdim_v
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # t/f
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Human-readable encoding of every pad/feature flag."""
        pad_bits = "".join(
            code
            for flag, code in (
                (self.F_spad, "s"),
                (self.F_skpad, "sk"),
                (self.F_dpad, "d"),
                (self.F_dvpad, "dv"),
            )
            if flag == "t"
        )
        pieces = [f"{self.tag}_v{self.F_vlayout[0]}"]
        pieces.append(f"_p{pad_bits}" if pad_bits else "_npad")
        pieces.append("_logits" if self.F_logits == "t" else "_nlogits")
        pieces.append("_nbias")  # bias is not supported by this generator
        if self.F_mask[0:2] == "s_":
            pieces.append("_mask" if self.F_mask == "s_mask" else "_nmask")
        elif self.F_mask != "no":
            pieces.append(f"_m{self.F_mask[0]}")
        else:
            pieces.append("_nmask")
        pieces.append("_nskip")
        pieces.append("_nsquant")
        pieces.append("_trload" if self.F_trload == "t" else "_ntrload")
        return "".join(pieces)
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
    """Collects registered kernel traits and renders the dispatch .cpp body."""

    def __init__(self, mask_impl):
        # pool layout: dtype -> (hdim, hdim_v) -> [FmhaFwdApiTrait, ...]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        dtype_bucket = self.pool.setdefault(trait.dtype, dict())
        hdim_key = trait.hdim, trait.bn1
        dtype_bucket.setdefault(hdim_key, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the complete dispatch translation unit."""
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}

        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        inners += FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if="if" if k == 0 else "else if",
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if="if" if j == 0 else "else if",
                        F_hdim=hdim,
                        F_hdim_v=hdim_v,
                        F_inner_dispatch=inners,
                    )
                per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
                    F_if="if" if i == 0 else "else if",
                    F_dtype=dtype,
                    F_hdim_case=per_hdim_case,
                )
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdTileSize:
    """Block/warp tiling parameters for one kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0; used by pipelines that load Q at once (or repeatedly load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # -1 lets the pipeline decide; any other value overrides occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Compact encoding of the tile config used inside kernel symbols."""
        block = f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
        warps = f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
        wsize = f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
        occ = "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
        return block + warps + wsize + occ
|
||||
|
||||
|
||||
@dataclass
class FmhaFwdKernel:
    """One VSA fwd kernel instance: renders its .cpp source and its dispatch trait."""

    F_idx: int  # not a tunable; counter to differentiate generated symbols
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Render the complete .cpp translation unit for this instance."""
        tile, pipe = self.F_tile, self.F_pipeline
        body = FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=tile.F_bm0,
            F_bn0=tile.F_bn0,
            F_bk0=tile.F_bk0,
            F_bn1=tile.F_bn1,
            F_bk1=tile.F_bk1,
            F_bk0max=tile.F_bk0max,
            F_rm0=tile.F_rm0,
            F_rn0=tile.F_rn0,
            F_rk0=tile.F_rk0,
            F_rm1=tile.F_rm1,
            F_rn1=tile.F_rn1,
            F_rk1=tile.F_rk1,
            F_wm0=tile.F_wm0,
            F_wn0=tile.F_wn0,
            F_wk0=tile.F_wk0,
            F_wm1=tile.F_wm1,
            F_wn1=tile.F_wn1,
            F_wk1=tile.F_wk1,
            F_vlayout=LAYOUT_MAP[pipe.F_vlayout],
            F_spad=BOOL_MAP[pipe.F_spad],
            F_skpad=BOOL_MAP[pipe.F_skpad],
            F_dpad=BOOL_MAP[pipe.F_dpad],
            F_dvpad=BOOL_MAP[pipe.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[pipe.tag],
            F_mask=get_mask_map(self.mask_impl)[pipe.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[pipe.tag],
            F_trload=BOOL_MAP[pipe.F_trload],
            F_kernel_name=self.name,
        )
        return FMHA_FWD_KERNEL_HEADER + body

    @property
    def name(self) -> str:
        """Unique kernel symbol (F_idx is intentionally not encoded)."""
        return "_".join(
            [
                f"fmha_vsa_fwd_d{self.F_hdim}",
                self.F_dtype,
                self.F_mode,
                self.F_tile.name,
                self.F_pipeline.name,
            ]
        )

    @property
    def filename(self) -> str:
        """File name of the generated translation unit."""
        return "{}.cpp".format(self.name)

    def api_trait(self) -> FmhaFwdApiTrait:
        """Project this kernel's tunables into the dispatch-API trait record."""
        tile, pipe = self.F_tile, self.F_pipeline
        return FmhaFwdApiTrait(
            pipeline_tag=pipe.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bk0=tile.F_bk0,
            bn1=tile.F_bn1,
            bk1=tile.F_bk1,
            bk0max=tile.F_bk0max,
            vlayout=pipe.F_vlayout,
            mask=pipe.F_mask,
            logits=pipe.F_logits,
            spad=pipe.F_spad,
            skpad=pipe.F_skpad,
            dpad=pipe.F_dpad,
            dvpad=pipe.F_dvpad,
            tr_load=pipe.F_trload,
            # dispatch-time guard: tile- and pipeline-level constraints combined
            constraint=tile.F_constraint & pipe.F_constraint,
        )
|
||||
|
||||
|
||||
class KernelComponentFactory:
    """Default provider of tile-size tables and pipeline lists for the VSA
    fwd kernel generator."""

    # TODO: design a more practical way to do it
    # Currently supported tile sizes, keyed by (hdim_q, hdim_v).
    # (Other hdim entries such as (32,32), (64,64), (96,128), (160,160),
    # (192,128), (192,192) and (256,256) existed historically and are
    # intentionally disabled for now.)
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        if dtype not in ("fp16", "bf16"):
            return None
        # FmhaFwdTileSize field order:
        #   bm0, bn0, bk0, bn1, bk1, bk0max,
        #   rm0, rn0, rk0, rm1, rn1, rk1,
        #   wm0, wn0, wk0, wm1, wn1, wk1, occupancy
        return {
            (128, 128): [
                FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            ],
        }

    # TODO: we don't support tuning yet, so pick up one value for
    # vlayout/pipeline/pad; support this in future.
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Populate the list of candidate pipelines.

        The order of the returned list matters: later entries are also
        checked later by the generated dispatch code.
        NOTE: logits soft-cap is NOT supported by VSA sparse attention
        (enforced by static_assert in the C++ side).
        """
        if dtype not in ("fp16", "bf16"):
            assert False
        pipelines = []
        for logits, mask in itertools.product(
            ["f"],  # logits soft-cap NOT supported, always false
            get_mask_map(mask_impl).keys(),
        ):
            if hdim == 256 and hdim_v == 256:
                # vsa fmha only supports dim <= 192 for now.
                continue
            # Two variants per mask: skpad disabled first, then enabled.
            # FmhaFwdPipeline arg order: tag, vlayout, spad, skpad, dpad,
            # dvpad, logits, mask, trload.
            for skpad in ("f", "t"):
                pipelines.append(
                    FmhaFwdPipeline("qr_async_vsa", "row", "t", skpad, "t", "t", logits, mask, "f")
                )
        return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = (
|
||||
CustomFactory
|
||||
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
|
||||
else KernelComponentFactory
|
||||
)
|
||||
|
||||
# Only generate fp16/bf16 kernels for now.
|
||||
# NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=1,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
328
example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
Normal file
328
example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
Normal file
@@ -0,0 +1,328 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
#include "01_fmha/mask.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
namespace ck_tile {
|
||||
inline bool is_load_tr_supported() { return is_gfx95_supported(); }
|
||||
} // namespace ck_tile
|
||||
|
||||
struct FmhaSparseFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaSparseFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaSparseFwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
// Note: The following types are required by BlockFmhaPipelineProblem but not used
|
||||
// by sparse attention (bias, dropout, LSE are not supported).
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
// Note: The following types are required by BlockFmhaPipelineProblem but not used
|
||||
// by sparse attention (bias, dropout, LSE are not supported).
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// jenga
|
||||
struct fmha_jenga_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* block_relation_onehot_ptr; // one-hot block map [B,H,Q_blk,K_blk], 1=active
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
// Dropout is not supported for sparse attention; keep args minimal.
|
||||
};
|
||||
|
||||
// vsa
|
||||
struct fmha_vsa_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk]
|
||||
const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk]
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
// Dropout is not supported for sparse attention; keep args minimal.
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.block_relation_onehot_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
|
||||
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.lut_ptr,
|
||||
args.valid_block_num_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
|
||||
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_>
|
||||
struct fmha_jenga_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
};
|
||||
|
||||
struct fmha_jenga_fwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// VSA uses the same traits structure as Jenga; aliases for clarity
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_>
|
||||
using fmha_vsa_fwd_traits_ = fmha_jenga_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kM0_,
|
||||
kN0_,
|
||||
kK0_,
|
||||
kN1_,
|
||||
kK1_,
|
||||
kK0BlockLength_,
|
||||
kIsVLayoutRowMajor_,
|
||||
FmhaPipelineEnum_,
|
||||
kHasLogitsSoftCap_,
|
||||
FmhaMask_,
|
||||
kPadS_,
|
||||
kPadSK_,
|
||||
kPadD_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_>;
|
||||
|
||||
using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;
|
||||
|
||||
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
|
||||
|
||||
float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&);
|
||||
166
example/ck_tile/50_sparse_attn/generate.py
Normal file
166
example/ck_tile/50_sparse_attn/generate.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
unwanted_prefix = "fmha_"
|
||||
handlers = dict(
|
||||
[
|
||||
(
|
||||
op.__name__[len(unwanted_prefix) :]
|
||||
if op.__name__.startswith(unwanted_prefix)
|
||||
else op.__name__,
|
||||
(op.list_blobs, op.write_blobs),
|
||||
)
|
||||
for op in ops
|
||||
]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(
|
||||
output_file: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
# create an empty file / drop its contents if it exists
|
||||
open(file_path, "w").close()
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK fmha kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction", # we keep 'direction' option for backward compatibility
|
||||
"-a",
|
||||
"--api",
|
||||
default="fwd_jenga",
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="write all the blobs into a directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
|
||||
)
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default="",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mask",
|
||||
default="simplified",
|
||||
required=False,
|
||||
help="mask implementation, simplified/generic",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n"
|
||||
+ " 1: generate more instance to cover all hdim\n"
|
||||
+ " 2: Only generate instance for Flash attention integration\n"
|
||||
+ " 4: Only generate instance for PyTorch integration\n"
|
||||
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
|
||||
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
|
||||
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
|
||||
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optdim",
|
||||
default="-1",
|
||||
required=False,
|
||||
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
|
||||
+ "eg. --optdim=32,64,128,256",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.direction.split(",")
|
||||
filter_list = args.filter.split(",")
|
||||
filter_list.extend([""] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(
|
||||
args.list_blobs,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
else:
|
||||
write_blobs(
|
||||
args.output_dir,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
199
example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp
Normal file
199
example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"Jenga sparse attention supports fp16/bf16 only.");
|
||||
// Determine data type string based on template parameter
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
if(max_seqlen_k == 0)
|
||||
max_seqlen_k = seqlen_k;
|
||||
bool is_v_rowmajor = true;
|
||||
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
std::string msk_str = "0";
|
||||
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
|
||||
|
||||
const ck_tile::index_t shape_seqlen_q = seqlen_q;
|
||||
const ck_tile::index_t shape_seqlen_k = seqlen_k;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
false, // time_kernel
|
||||
log_level,
|
||||
0,
|
||||
1,
|
||||
false};
|
||||
|
||||
// Create device memory and copy data to device
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
v_buf.ToDevice(TV.data());
|
||||
block_relation_buf.ToDevice(Tblock_relation_onehot.data());
|
||||
|
||||
const auto init_args = [&](auto& args) {
|
||||
assert(nhead % nhead_k == 0);
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
|
||||
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
// Use device buffer pointers instead of host tensor data pointers
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();
|
||||
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q; // batch mode only
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k = shape_seqlen_k; // batch mode only
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
// Dropout not supported for sparse attention.
|
||||
};
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
|
||||
traits.mask_type = mask.type;
|
||||
};
|
||||
|
||||
fmha_jenga_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
fmha_jenga_fwd_args args;
|
||||
init_args(args);
|
||||
|
||||
fmha_jenga_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
// Copy output back to host without changing tensor shape
|
||||
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
// Explicit template instantiations
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
jenga_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<uint8_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
jenga_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<uint8_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
48
example/ck_tile/50_sparse_attn/jenga_sparse_attention.h
Normal file
48
example/ck_tile/50_sparse_attn/jenga_sparse_attention.h
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level = 0);
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_> vsa_sparse_attention(
|
||||
const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
|
||||
const ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level = 0);
|
||||
423
example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp
Normal file
423
example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp
Normal file
@@ -0,0 +1,423 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Test for jenga_sparse_attention function
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
{
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
}
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
{
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
{
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Get error tolerance based on data type
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
// bf16 accumulation/rounding can be noisier in sparse patterns
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float to_float_for_compare(T value)
|
||||
{
|
||||
return static_cast<float>(value);
|
||||
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
|
||||
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
ck_tile::index_t BLKQ = block_size;
|
||||
ck_tile::index_t BLKK = block_size;
|
||||
|
||||
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "Jenga kernel instances are generated for block_size=128 and hdim=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculate number of Q and K blocks
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << "[Jenga Sparse Attention Test]" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
|
||||
<< std::endl;
|
||||
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
|
||||
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
|
||||
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
|
||||
<< std::endl;
|
||||
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
|
||||
<< std::endl;
|
||||
std::cout << " sparsity: " << sparsity << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors (using BHSD layout when i_perm=true)
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
// Block relation onehot: [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
|
||||
// Initialize tensors with random values
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Initialize block_relation_onehot with sparse pattern
|
||||
std::mt19937 rng(seed + 100);
|
||||
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
||||
ck_tile::index_t total_blocks = 0;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
bool is_diagonal = (qb == kb && qb < num_k_blocks);
|
||||
bool random_active = (dist(rng) > sparsity);
|
||||
|
||||
if(is_diagonal || random_active)
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
|
||||
active_blocks++;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
// Run kernel
|
||||
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
|
||||
std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validation
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
// Compare results
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
max_rel_diff = std::max(max_rel_diff, rel_diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
{
|
||||
num_errors++;
|
||||
if(num_errors <= 5)
|
||||
{
|
||||
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
|
||||
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
<< output_host_bhsd.mData.size() << std::endl;
|
||||
|
||||
if(num_errors == 0)
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
486
example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp
Normal file
486
example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp
Normal file
@@ -0,0 +1,486 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Test for vsa_sparse_attention function
|
||||
// Based on the Python test: test_jenga_attention.py
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
{
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
}
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
{
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
{
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Convert block_relation_onehot to LUT format (similar to triton_block_map_to_lut_kernel)
|
||||
template <typename T>
|
||||
void block_map_to_lut(
|
||||
const ck_tile::HostTensor<T>& block_map, // [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<int32_t>& lut, // [B, H, Q_blocks, K_blocks] - int32_t for kernel
|
||||
ck_tile::HostTensor<int32_t>& valid_block_num, // [B, H, Q_blocks] - int32_t for kernel
|
||||
ck_tile::index_t num_block_k)
|
||||
{
|
||||
auto lengths = block_map.get_lengths();
|
||||
ck_tile::index_t B = lengths[0];
|
||||
ck_tile::index_t H = lengths[1];
|
||||
ck_tile::index_t Q = lengths[2];
|
||||
|
||||
for(ck_tile::index_t b = 0; b < B; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < H; ++h)
|
||||
{
|
||||
for(ck_tile::index_t q = 0; q < Q; ++q)
|
||||
{
|
||||
int32_t valid_count = 0;
|
||||
int32_t prev_block = 0;
|
||||
|
||||
for(ck_tile::index_t k = 0; k < num_block_k; ++k)
|
||||
{
|
||||
T cur_block = block_map(b, h, q, k);
|
||||
if(static_cast<float>(cur_block) > 0.5f)
|
||||
{ // Check if block is active
|
||||
lut(b, h, q, valid_count) = static_cast<int32_t>(k - prev_block);
|
||||
valid_count++;
|
||||
prev_block = static_cast<int32_t>(k);
|
||||
}
|
||||
}
|
||||
valid_block_num(b, h, q) = valid_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get error tolerance based on data type
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
// bf16 accumulation/rounding can be noisier in sparse patterns
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float to_float_for_compare(T value)
|
||||
{
|
||||
return static_cast<float>(value);
|
||||
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
|
||||
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
ck_tile::index_t BLKQ = block_size;
|
||||
ck_tile::index_t BLKK = block_size;
|
||||
|
||||
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "VSA kernel instances are generated for block_size=128 and hdim=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculate number of Q and K blocks
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << "[VSA Sparse Attention Test]" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
|
||||
<< std::endl;
|
||||
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
|
||||
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
|
||||
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
|
||||
<< std::endl;
|
||||
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
|
||||
<< std::endl;
|
||||
std::cout << " sparsity: " << sparsity << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors (using BHSD layout when i_perm=true)
|
||||
// Q: [B, H, S_q, D]
|
||||
// K: [B, H_k, S_k, D]
|
||||
// V: [B, H_k, S_k, D_v]
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
// Block relation onehot: [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
|
||||
// LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel
|
||||
ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
|
||||
|
||||
// Initialize tensors with random values
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Initialize block_relation_onehot with sparse pattern
|
||||
std::mt19937 rng(seed + 100);
|
||||
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
||||
ck_tile::index_t total_blocks = 0;
|
||||
ck_tile::index_t active_blocks = 0;
|
||||
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
// Each Q block always attends to its diagonal K block (if exists)
|
||||
// Plus random blocks based on sparsity
|
||||
bool is_diagonal = (qb == kb && qb < num_k_blocks);
|
||||
bool random_active = (dist(rng) > sparsity);
|
||||
|
||||
if(is_diagonal || random_active)
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
|
||||
active_blocks++;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
// Convert block_relation_onehot to LUT format
|
||||
std::cout << "Converting block map to LUT format..." << std::endl;
|
||||
block_map_to_lut(block_relation_onehot, lut_host, valid_block_num_host, num_k_blocks);
|
||||
|
||||
// vsa_sparse_attention handles device memory internally
|
||||
|
||||
// Run kernel
|
||||
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
// Print kernel name once by invoking with log_level=1.
|
||||
// This is separate from warmup/benchmark to avoid polluting timing.
|
||||
if(kname)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
|
||||
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: vsa_sparse_attention already returns output in output_host
|
||||
|
||||
// Validation
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
|
||||
// Compute scale factor
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
// Run reference implementation
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
// Compare results
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
max_rel_diff = std::max(max_rel_diff, rel_diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
{
|
||||
num_errors++;
|
||||
if(num_errors <= 5)
|
||||
{
|
||||
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
|
||||
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
<< output_host_bhsd.mData.size() << std::endl;
|
||||
|
||||
if(num_errors == 0)
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
205
example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp
Normal file
205
example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp
Normal file
@@ -0,0 +1,205 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
/// Runs the VSA sparse-attention forward kernel on device for one batch of
/// host Q/K/V tensors and writes the result into (and returns) Y.
///
/// Supports fp16/bf16 element types only (enforced by the static_assert).
/// TKV_block_idx / TKV_blocks describe the sparsity pattern: they are uploaded
/// verbatim as the kernel's block-index LUT (`lut_ptr`) and per-row valid
/// block counts (`valid_block_num_ptr`).
/// `i_perm` / `o_perm` switch between head-major and sequence-major stride
/// layouts — see the stride computations in `init_args` below.
/// A `max_seqlen_q` / `max_seqlen_k` of 0 defaults to seqlen_q / seqlen_k.
/// `log_level` is forwarded to the ck_tile::stream_config.
///
/// NOTE(review): max_seqlen_k is normalized here but never forwarded to the
/// kernel arguments — confirm whether that is intentional.
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                     const ck_tile::HostTensor<DataType_>& TK,
                     const ck_tile::HostTensor<DataType_>& TV,
                     const ck_tile::HostTensor<int32_t>& TKV_block_idx,
                     const ck_tile::HostTensor<int32_t>& TKV_blocks,
                     ck_tile::HostTensor<DataType_>& Y,
                     int batch,
                     int nhead,
                     int nhead_k,
                     int seqlen_q,
                     int seqlen_k,
                     int hdim_q,
                     int hdim_v,
                     bool i_perm,
                     bool o_perm,
                     int max_seqlen_q,
                     int max_seqlen_k,
                     int log_level)
{
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "VSA sparse attention supports fp16/bf16 only.");
    // Determine data type string based on template parameter
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }

    // 0 means "derive from the actual sequence lengths".
    if(max_seqlen_q == 0)
        max_seqlen_q = seqlen_q;
    if(max_seqlen_k == 0)
        max_seqlen_k = seqlen_k;
    bool is_v_rowmajor = true;
    // Standard 1/sqrt(hdim) softmax scaling.
    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
    // "0" decodes to the no-mask case; window sizes below come from this mask.
    std::string msk_str = "0";
    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);

    const ck_tile::index_t shape_seqlen_q = seqlen_q;
    const ck_tile::index_t shape_seqlen_k = seqlen_k;

    ck_tile::stream_config stream_config{nullptr,
                                         false, // time_kernel
                                         log_level,
                                         0,
                                         1,
                                         false};

    // Create device memory and copy data to device
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
    ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());

    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    v_buf.ToDevice(TV.data());
    lut_buf.ToDevice(TKV_block_idx.data());
    valid_block_num_buf.ToDevice(TKV_blocks.data());

    // Populate the kernel argument struct. Captures the device buffers and
    // problem sizes by reference, so must be invoked before they go away.
    const auto init_args = [&](auto& args) {
        // GQA/MQA constraint: query heads must be a multiple of KV heads.
        assert(nhead % nhead_k == 0);
        // Element strides between consecutive sequence positions:
        // i_perm==true => head dim is innermost per head (stride is hdim);
        // i_perm==false => heads are interleaved (stride is nhead * hdim).
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
        // setup nhead_stride_* arguments
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
        }();
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        // setup batch_stride_* arguments (full per-batch element counts)
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);

        // Use device buffer pointers instead of host tensor data pointers
        args.q_ptr = q_buf.GetDeviceBuffer();
        args.k_ptr = k_buf.GetDeviceBuffer();
        args.v_ptr = v_buf.GetDeviceBuffer();
        args.lut_ptr = lut_buf.GetDeviceBuffer();
        args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();

        args.batch = batch;
        args.seqlen_q = shape_seqlen_q; // batch mode only
        args.hdim_q = hdim_q;
        args.hdim_v = hdim_v;
        args.nhead_q = nhead;
        args.nhead_k = nhead_k;

        args.stride_q = stride_q;
        args.stride_k = stride_k;
        args.stride_v = stride_v;
        args.nhead_stride_q = nhead_stride_q;
        args.nhead_stride_k = nhead_stride_k;
        args.nhead_stride_v = nhead_stride_v;
        args.batch_stride_q = batch_stride_q;
        args.batch_stride_k = batch_stride_k;
        args.batch_stride_v = batch_stride_v;

        args.o_ptr = o_buf.GetDeviceBuffer();

        args.seqlen_k = shape_seqlen_k; // batch mode only
        args.max_seqlen_q = max_seqlen_q;

        args.scale_s = scale_s;

        args.stride_o = stride_o;
        args.nhead_stride_o = nhead_stride_o;
        args.batch_stride_o = batch_stride_o;

        // Mask window comes from the decoded no-mask default above.
        args.window_size_left = mask.left;
        args.window_size_right = mask.right;
        args.mask_type = static_cast<ck_tile::index_t>(mask.type);

        // Dropout not supported for sparse attention.
    };

    // Populate the kernel-dispatch traits (type/layout selection).
    const auto init_traits = [&](auto& traits) {
        traits.hdim_q = hdim_q;
        traits.hdim_v = hdim_v;
        traits.data_type = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;

        traits.mask_type = mask.type;
    };

    fmha_vsa_fwd_traits fmha_traits;
    init_traits(fmha_traits);

    fmha_vsa_fwd_args args;
    init_args(args);

    // Launch the forward kernel (not timed; see stream_config above).
    fmha_vsa_fwd(fmha_traits, args, stream_config);

    // Copy output back to host without changing tensor shape
    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());

    return Y;
}
|
||||
|
||||
// Explicit template instantiations
// Only fp16 and bf16 are instantiated — the same two types the function's
// static_assert accepts — so the template definition can stay in this
// translation unit while callers link against these symbols.
template ck_tile::HostTensor<ck_tile::half_t>
vsa_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      ck_tile::HostTensor<ck_tile::half_t>&,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      bool,
                                      bool,
                                      int,
                                      int,
                                      int);

template ck_tile::HostTensor<ck_tile::bf16_t>
vsa_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      bool,
                                      bool,
                                      int,
                                      int,
                                      int);
|
||||
@@ -32,3 +32,5 @@ add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
add_subdirectory(99_toy_example)
|
||||
add_subdirectory(99_toy_tutorial)
|
||||
add_subdirectory(50_sparse_attn)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user