mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* Fix example * fix build error * update pk_i4 & moe test case * fix all instance build (examples) * fix batched_gemm_gemm (example) * disable example_gemm_bias_softmax_gemm_permute on gfx11 * remove unnecessary disable gfx11 * update tests * update tests2
83 lines
4.1 KiB
C++
83 lines
4.1 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "common.hpp"
|
|
|
|
template <ck::index_t... Is>
|
|
using S = ck::Sequence<Is...>;
|
|
|
|
// Scalar element types referenced by the problem configuration below.
using F16 = ck::half_t; // CK's 16-bit floating-point type
using F32 = float;

// GEMM tensor layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// CK's pass-through element-wise operation (used for the A and B inputs below).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
|
|
|
using ADataType = F16;
|
|
using BDataType = F16;
|
|
using AccDataType = F32;
|
|
using CShuffleDataType = F32;
|
|
using DDataType = F16;
|
|
using EDataType = F16;
|
|
|
|
using ALayout = Row;
|
|
using BLayout = Col;
|
|
using DLayout = Row;
|
|
using ELayout = Row;
|
|
|
|
using AElementOp = PassThrough;
|
|
using BElementOp = PassThrough;
|
|
using CDEElementOp = Add;
|
|
|
|
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
|
|
|
using DeviceOpInstance =
|
|
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
|
|
BLayout,
|
|
ck::Tuple<DLayout>,
|
|
ELayout,
|
|
ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
CShuffleDataType,
|
|
ck::Tuple<DDataType>,
|
|
EDataType,
|
|
AElementOp,
|
|
BElementOp,
|
|
CDEElementOp,
|
|
GemmSpec,
|
|
1,
|
|
256,
|
|
256,
|
|
128,
|
|
32,
|
|
8,
|
|
8,
|
|
16,
|
|
16,
|
|
8,
|
|
4,
|
|
S<4, 64, 1>,
|
|
S<1, 0, 2>,
|
|
S<1, 0, 2>,
|
|
2,
|
|
8,
|
|
8,
|
|
1,
|
|
S<4, 64, 1>,
|
|
S<1, 0, 2>,
|
|
S<1, 0, 2>,
|
|
2,
|
|
8,
|
|
8,
|
|
1,
|
|
1,
|
|
1,
|
|
S<1, 32, 1, 8>,
|
|
4>;
|
|
|
|
#include "run_gemm_add_example_xdl.inc"
|
|
|
|
// Entry point: forwards the command line to run_gemm_add_example (provided by
// run_gemm_add_example_xdl.inc) and maps its result to the process exit status:
// 0 when it reports success, 1 otherwise.
int main(int argc, char* argv[])
{
    const bool succeeded = run_gemm_add_example(argc, argv);
    return succeeded ? 0 : 1;
}
|