Files
composable_kernel/example/01_gemm/gemm_wmma_fp16.cpp
Aviral Goel d85f065b15 chore(copyright): update copyright header for example directory (#3273)
* chore(copyright): update copyright header for codegen directory

* chore(copyright): update copyright header for example directory
2025-11-24 18:02:41 -08:00

85 lines
2.9 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
// Operand memory layouts (Row / Col are not defined here; presumably provided by common.hpp).
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Element types: fp16 operands and fp16 output, fp32 for accumulation and for the
// CShuffle intermediate.
using ADataType        = ck::half_t;
using BDataType        = ck::half_t;
using CDataType        = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = float;

// PassThrough element-wise operations on A, B and C (no fused prologue/epilogue op).
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// NOTE: despite its name, GemmDefault selects MNKPadding (padding along M, N and K),
// presumably so problem sizes need not be multiples of the block tile — confirm.
static constexpr ck::tensor_operation::device::GemmSpecialization GemmDefault =
    ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
// Device GEMM instance: DeviceGemmWmma_CShuffle specialised for this example's layouts, types,
// element ops and tile configuration. All template arguments are positional, so their order must
// match the template parameter list of DeviceGemmWmma_CShuffle in
// ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp. The parameter names used in the
// labels below follow that header — NOTE(review): re-check them if the header changes.
//
// Tile-geometry consistency (arithmetic on the values below):
//   M waves = MPerBlock / (M-Repeat * MPerWmma) =  64 / (2 * 16) = 2
//   N waves = NPerBlock / (N-Repeat * NPerWmma) = 128 / (4 * 16) = 2
//   M waves * N waves * 32 lanes = 2 * 2 * 32 = 128 = BlockSize
//   (assumes wave32 execution for WMMA — confirm for the target architecture)
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-Wave = MPerBlock / (M-Repeat * MPerWmma) = 64 / (2 * 16) = 2
4, // N-Repeat // N-Wave = NPerBlock / (N-Repeat * NPerWmma) = 128 / (4 * 16) = 2
// --- A-tile block transfer ---
S<4, 32, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 (4 * 32 * 1 = 128 = BlockSize)
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim (index 2 of K0_M_K1, i.e. K1)
2, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
// --- B-tile block transfer ---
S<4, 32, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 (4 * 32 * 1 = 128 = BlockSize)
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim (index 2 of K0_N_K1, i.e. K1)
2, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
// --- C-shuffle epilogue ---
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock (1 * 32 * 1 * 4 = 128 = BlockSize)
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
// Host (CPU) reference GEMM. It takes no layout parameters, only element types and element-wise
// ops. Presumably consumed by run_gemm_example.inc to verify the device result.
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemm<ADataType,
                                              BDataType,
                                              CDataType,
                                              AccDataType,
                                              AElementOp,
                                              BElementOp,
                                              CElementOp>;

// Device-side reference GEMM. Unlike the host version it is also parameterised on the A/B/C
// layouts. Presumably an alternative verification path in run_gemm_example.inc.
using ReferenceGemmInstanceGPU =
    ck::tensor_operation::device::ReferenceGemm<ALayout, BLayout, CLayout,
                                                ADataType, BDataType, CDataType, AccDataType,
                                                AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
// Entry point: forwards the command line to run_gemm_example (from run_gemm_example.inc) and
// maps its result to the process exit status: a truthy result exits 0, anything else exits 1.
int main(int argc, char* argv[])
{
    if(run_gemm_example(argc, argv))
    {
        return 0;
    }
    return 1;
}