mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Merge commit '8f75869408210cb85e9eb7ff639c4c9dad1331cb' into develop
This commit is contained in:
@@ -38,16 +38,28 @@ add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contracti
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
|
||||
|
||||
# FP16
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp16 contraction_bilinear_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16)
|
||||
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp16 contraction_scale_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
|
||||
|
||||
# BF16
|
||||
add_example_executable(example_contraction_bilinear_xdl_bf16 contraction_bilinear_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16)
|
||||
|
||||
add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_bf16 contraction_scale_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
|
||||
|
||||
|
||||
86
example/26_contraction/contraction_bilinear_xdl_bf16.cpp
Normal file
86
example/26_contraction/contraction_bilinear_xdl_bf16.cpp
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "common_instances.hpp"
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using DDataType = BF16;
|
||||
using DsDataType = ck::Tuple<DDataType>;
|
||||
using EDataType = BF16;
|
||||
using ComputeDataType = BF16;
|
||||
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
#include "run_contraction_bilinear_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); }
|
||||
86
example/26_contraction/contraction_bilinear_xdl_fp16.cpp
Normal file
86
example/26_contraction/contraction_bilinear_xdl_fp16.cpp
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "common_instances.hpp"
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType>;
|
||||
using EDataType = F16;
|
||||
using ComputeDataType = F16;
|
||||
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
#include "run_contraction_bilinear_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); }
|
||||
85
example/26_contraction/contraction_scale_xdl_bf16.cpp
Normal file
85
example/26_contraction/contraction_scale_xdl_bf16.cpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "common_instances.hpp"
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
using ComputeDataType = BF16;
|
||||
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
|
||||
#include "run_contraction_scale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); }
|
||||
85
example/26_contraction/contraction_scale_xdl_fp16.cpp
Normal file
85
example/26_contraction/contraction_scale_xdl_fp16.cpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "common_instances.hpp"
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
using ComputeDataType = F16;
|
||||
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
|
||||
#include "run_contraction_scale_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); }
|
||||
@@ -235,13 +235,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
if constexpr(std::is_same_v<EDataType, F32>)
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -218,13 +218,20 @@ int run_contraction_scale_example(int argc, char* argv[])
|
||||
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
if constexpr(std::is_same_v<EDataType, F32>)
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -315,7 +315,7 @@ class FmhaFwdApiTrait:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
if self.bm0 == max_bm0:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
@@ -847,11 +847,6 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
|
||||
and kernel_ctx.tile.F_bm0 != 128
|
||||
)
|
||||
or (
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
|
||||
and kernel_ctx.pipeline.tag != "qr_async"
|
||||
and kernel_ctx.tile.F_bk0 == 64
|
||||
)
|
||||
):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
@@ -947,7 +942,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
|
||||
Reference in New Issue
Block a user