Merge branch 'develop' into lwpck-3984

This commit is contained in:
khuagarw
2025-11-14 21:36:36 +00:00
186 changed files with 1127 additions and 542 deletions

View File

@@ -2,7 +2,7 @@
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
## Composable Kernel 1.1.0 for ROCm 7.2.0
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM

View File

@@ -105,7 +105,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
list(APPEND gpu_list_tf32 gfx942)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)

View File

@@ -21,7 +21,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
list(APPEND gpu_list_tf32 gfx942)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)

View File

@@ -77,7 +77,7 @@ inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 1e-2;
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{

View File

@@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
# GPU architectures for which the fp32/tf32 grouped GEMM example is built.
list(APPEND gpu_list_tf32 gfx942 gfx950)
# `target` is a once-only flag: it is set to 1 as soon as the example has been added,
# so the executable is registered a single time even if several GPU_TARGETS match.
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32)
set(target 1)
endif()
endforeach()

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#define EXAMPLE_WITH_COMPUTE_DATATYPE
// Shorthand for a compile-time integer sequence, used for the tuning parameters below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Data types: every tensor, the accumulator and the C-shuffle stage use fp32.
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
// No extra D tensors in this example (empty tuple).
using DsDataType = ck::Tuple<>;
using EDataType = F32;
// tf32 is selected as the compute type and passed as the last template argument of
// DeviceGemmInstance; this is what distinguishes this example from a plain fp32 one.
using ComputeDataType = ck::tf32_t;
// Layouts: A is row-major, B is column-major, E (output) is row-major.
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
// All element-wise operations are pass-through.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Concrete grouped-GEMM device operation used by this example.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// The two trailing arguments are not listed in the table above: the loop scheduler
// (ck::LoopScheduler::Default) and the compute data type (ComputeDataType).
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>;
// clang-format on
#include "run_grouped_gemm_example.inc"
// Entry point: runs the grouped GEMM example and maps its result to a process exit code
// (0 when the example reports success, 1 otherwise).
int main(int argc, char* argv[])
{
    const bool example_passed = run_grouped_gemm_example(argc, argv);
    return example_passed ? 0 : 1;
}
#undef EXAMPLE_WITH_COMPUTE_DATATYPE

View File

@@ -3,6 +3,11 @@
#pragma once
// Fallback that keeps existing examples unchanged: an example that defines
// EXAMPLE_WITH_COMPUTE_DATATYPE before including this file is expected to provide its own
// ComputeDataType alias; otherwise ComputeDataType defaults to AccDataType.
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
@@ -231,7 +236,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
CDEElementOp,
ComputeDataType>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
@@ -253,7 +259,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]);
#else
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
pass &= ck::utils::check_err<decltype(c_device_tensors[i]),
decltype(c_host_tensors[i]),
ComputeDataType>(c_device_tensors[i], c_host_tensors[i]);
#endif
}
}

View File

@@ -47,7 +47,7 @@ static constexpr inline auto is_row_major(Layout layout_)
// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b_v0(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];

View File

@@ -103,7 +103,7 @@ int run_flatmm_example_with_layouts(int argc,
}
else
{
return shuffle_b<FlatmmConfig>(b_origin_host);
return shuffle_b_v0<FlatmmConfig>(b_origin_host);
}
}();
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());

View File

@@ -129,7 +129,10 @@ inline bool is_wmma_supported()
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
}
inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
inline bool is_tf32_supported()
{
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}
} // namespace ck
#endif

View File

@@ -168,8 +168,8 @@ typename std::enable_if<
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-5)
double rtol = 5e-4,
double atol = 5e-4)
{
if(out.size() != ref.size())
{

View File

@@ -94,7 +94,8 @@ template <typename ALayout,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();

View File

@@ -795,7 +795,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
str << "DeviceGemmMultiD_Xdl_CShuffle_V3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
@@ -817,7 +817,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
<< GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages << ", "
<< "AK1: "
<< AK1 << ", "
<< "BK1: "
<< BK1;
// clang-format on
return str.str();

View File

@@ -134,7 +134,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
@@ -145,7 +146,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceGroupedGemm_Xdl;
GET_NXDL_PER_WAVE_IMPL
@@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = ADataType;
// GridwiseGemm
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<

View File

@@ -1145,6 +1145,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
__device__ static bool constexpr IsValidCompilationParameter()
{
enum struct Arch : bool
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 on pre-gfx950
if constexpr((static_cast<bool>(Arch::is_gfx950_build) == false) &&
(AK1Number >= 32 || BK1Number >= 32))
{
return false;
}
constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter<
BlockSize,
MPerBlock,

View File

@@ -80,8 +80,10 @@ enum struct MfmaInstr
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4,
mfma_f32_16x16x8xf32, // tf32
mfma_f32_32x32x4xf32,
mfma_f32_16x16x8xf32, // tf32 on gfx942
mfma_f32_32x32x4xf32, // tf32 on gfx942
mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
// gfx11
wmma_f32_16x16x16_f16,
wmma_f32_16x16x16_bf16,
@@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
}
};
// Trait block for MfmaInstr::mfma_f32_32x32x16xf32 (32x32 output tile).
// The constants below are read by name by the code that consumes mfma_type; their exact
// interpretation is defined there (not visible in this chunk).
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
{
// gfx950 specific: use bf16x3 simulate tf32
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// Forwards the A/B fragments and the accumulator to the matching intrinsic wrapper.
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
// Trait block for MfmaInstr::mfma_f32_16x16x32xf32 (16x16 output tile).
// The constants below are read by name by the code that consumes mfma_type; their exact
// interpretation is defined there (not visible in this chunk).
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
{
// gfx950 specific: use bf16x3 simulate tf32
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// Forwards the A/B fragments and the accumulator to the matching intrinsic wrapper.
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
// gfx11
struct mfma_type_gfx11_base
{
@@ -1275,12 +1322,14 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<tf32_t, 32, 32>()
constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_32x32x4xf32;
#else
@@ -1289,12 +1338,14 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<tf32_t, 16, 16>()
constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_16x16x8xf32;
#else
@@ -2185,6 +2236,10 @@ struct XdlopsGemm
(is_same<base_type, int8_t>::value && KPack <= 8) ||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
is_same<additional_type, pk_i4_t>::value)
#if defined(__gfx950__)
// tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
|| (is_same<base_type, tf32_t>::value && KPack <= 4)
#endif
? true
: false;
static constexpr auto mfma = MfmaSelector<base_type,

View File

@@ -10,6 +10,25 @@ namespace ck {
#define __gfx94__
#endif
// Splits every float lane of `reg_f32` into two bf16 values:
//   reg_bf16_big[k]   = bf16(reg_f32[k])
//   reg_bf16_small[k] = bf16(reg_f32[k] - float(reg_bf16_big[k]))
// Shared by the tf32 and xf32 emulation paths.
template <index_t VecSize>
__device__ __forceinline__ void
convert_float_to_bf16_pairs(const vector_type<float, VecSize>& reg_f32,
                            vector_type<bhalf_t, VecSize>& reg_bf16_big,
                            vector_type<bhalf_t, VecSize>& reg_bf16_small)
{
    static_for<0, VecSize, 1>{}([&](auto lane) {
        using Lane = Number<lane>;

        const float value = reg_f32.template AsType<float>()[Lane{}];

        // Leading part: the float converted to bf16.
        const bhalf_t leading = type_convert<bhalf_t, float>(value);
        reg_bf16_big.template AsType<bhalf_t>()(lane) = leading;

        // Residual part: what the leading bf16 failed to represent, again as bf16.
        const float leading_as_float = type_convert<float, bhalf_t>(leading);
        reg_bf16_small.template AsType<bhalf_t>()(lane) =
            type_convert<bhalf_t, float>(value - leading_as_float);
    });
}
/* */
// fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;
@@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
}
};
/******************* tf32 *************************************/
/******************* tf32 on gfx942 *************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8xf32;
@@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16>
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
#if defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
@@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
#if defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
@@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
}
};
/******************* tf32/xf32 on gfx950 ********************************/
/* bf16x3 simulation of tf32/xf32: input/output/accumulator are all float. */
/* Steps: */
/* 1. Split each input into 2 bf16 registers: */
/* in_bf16_big = f32_to_bf16(in_f32) */
/* in_bf16_small = f32_to_bf16(in_f32 - in_bf16_big) */
/* 2. Run 3 xdlops gemms that all accumulate into the same accumulator: */
/* out_f32 += A_bf16_small * B_bf16_big */
/* out_f32 += A_bf16_big * B_bf16_small */
/* out_f32 += A_bf16_big * B_bf16_big */
/************************************************************************/
// 16x16x32 "xf32" wrapper built from three bf16 MFMA calls on gfx950.
// Only the <16, 16> specialization is defined here.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32xf32;
template <>
struct intrin_mfma_f32_16x16x32xf32<16, 16>
{
// reg_a / reg_b: 8 floats each; reg_c: accumulator, updated in place by each bf16 call.
// When __gfx950__ is not defined the body only marks its arguments as used.
template <class FloatC>
__device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
using I0 = Number<0>;
vector_type<float, 8> reg_a_v(reg_a);
vector_type<float, 8> reg_b_v(reg_b);
vector_type<bhalf_t, 8> v_reg_a_bf16_big;
vector_type<bhalf_t, 8> v_reg_a_bf16_small;
vector_type<bhalf_t, 8> v_reg_b_bf16_big;
vector_type<bhalf_t, 8> v_reg_b_bf16_small;
// Split A and B into (big, small) bf16 pairs.
convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
// Run 3 times, in this order: small*big, big*small, big*big.
// NOTE(review): the order of accumulation affects rounding; keep it unless verified.
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
// 32x32x16 "xf32" wrapper built from three bf16 MFMA calls on gfx950.
// Only the <32, 32> specialization is defined here.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16xf32;
template <>
struct intrin_mfma_f32_32x32x16xf32<32, 32>
{
// reg_a / reg_b: 8 floats each; reg_c: accumulator, updated in place by each bf16 call.
// When __gfx950__ is not defined the body only marks its arguments as used.
template <class FloatC>
__device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
using I0 = Number<0>;
vector_type<float, 8> reg_a_v(reg_a);
vector_type<float, 8> reg_b_v(reg_b);
vector_type<bhalf_t, 8> v_reg_a_bf16_big;
vector_type<bhalf_t, 8> v_reg_a_bf16_small;
vector_type<bhalf_t, 8> v_reg_b_bf16_big;
vector_type<bhalf_t, 8> v_reg_b_bf16_small;
// Split A and B into (big, small) bf16 pairs.
convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
// Run 3 times, in this order: small*big, big*small, big*big.
// NOTE(review): the order of accumulation affects rounding; keep it unless verified.
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
reg_c);
intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
reg_c);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
/******************* tf32/xf32 on gfx950 end ************************************/
} // namespace ck

View File

@@ -12,6 +12,7 @@
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -254,4 +255,115 @@ CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& ad
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}
namespace detail {
// Printer functor that builds a printf format string at compile time.
// PREFIX and SUFFIX are str_literal types spliced around the generated format.
template <typename PREFIX = str_literal<>, typename SUFFIX = str_literal<>>
struct CK_PRINT_X_;
template <char... PREFIXChars, char... SUFFIXChars>
struct CK_PRINT_X_<str_literal<PREFIXChars...>, str_literal<SUFFIXChars...>>
{
// Per-type formatting traits. Only declared here; each printable type gets a partial
// specialization that provides get_format(), get_num_values() and get_values().
template <typename T>
struct detail;
// Formatting traits for tensor_adaptor_coordinate.
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct detail<
tensor_adaptor_coordinate<NDimHidden, BottomDimensionHiddenIds, TopDimensionHiddenIds>>
{
using coord_t =
tensor_adaptor_coordinate<NDimHidden, BottomDimensionHiddenIds, TopDimensionHiddenIds>;
// Format for hidden dimension I: "%d", prefixed with "_" when I is listed in
// BottomDimensionHiddenIds, "^" when listed in TopDimensionHiddenIds, "_^" when in both.
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto get_hidden_format_i()
{
constexpr bool is_bottom =
sequence_any_of(BottomDimensionHiddenIds{}, [](auto b) { return b == I; });
constexpr bool is_top =
sequence_any_of(TopDimensionHiddenIds{}, [](auto t) { return t == I; });
constexpr auto d = make_str_literal("%d");
if constexpr(is_bottom && is_top)
return make_str_literal("_^") + d;
else if constexpr(is_bottom)
return make_str_literal("_") + d;
else if constexpr(is_top)
return make_str_literal("^") + d;
else
return d;
}
// Recursively concatenates the formats of hidden dimensions 0..N-1.
// Every entry (including the first) is preceded by a single space separator.
template <index_t N = NDimHidden>
CK_TILE_HOST_DEVICE static constexpr auto get_hidden_format()
{
constexpr auto sep = make_str_literal(" ");
if constexpr(N == 0)
return str_literal<>{};
else
return get_hidden_format<N - 1>() + sep + get_hidden_format_i<N - 1>();
}
// Full format: "[ __<bottom %d...>__ | ^^<top %d...>^^ | <hidden...> ]".
CK_TILE_HOST_DEVICE static constexpr auto get_format()
{
constexpr auto d = make_str_literal("%d");
constexpr auto sep = make_str_literal(" ");
constexpr auto bottom_fmt =
d.template duplicate_n<BottomDimensionHiddenIds::size()>(sep);
constexpr auto top_fmt = d.template duplicate_n<TopDimensionHiddenIds::size()>(sep);
constexpr auto hidden_fmt = get_hidden_format();
return make_str_literal("[ __") + bottom_fmt + make_str_literal("__ | ^^") + top_fmt +
make_str_literal("^^ | ") + hidden_fmt + make_str_literal(" ]");
}
// Number of values consumed by the format returned from get_format().
CK_TILE_HOST_DEVICE static constexpr index_t get_num_values()
{
return BottomDimensionHiddenIds::size() + TopDimensionHiddenIds::size() + NDimHidden;
}
// Values in the same order as the format: bottom index, top index, hidden index.
CK_TILE_HOST_DEVICE static constexpr auto get_values(const coord_t& coord)
{
return container_concat(
coord.get_bottom_index(), coord.get_top_index(), coord.get_hidden_index());
}
};
// Leading part of every line: "tid %03d: ", followed by a space and PREFIX when PREFIX
// is non-empty.
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
{
constexpr auto fmt_tid = make_str_literal("tid %03d: ");
if constexpr(sizeof...(PREFIXChars) == 0)
return fmt_tid;
else
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
}
// Trailing part of every line: SUFFIX (if any) followed by a newline.
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
{
constexpr auto lf = make_str_literal("\n");
if constexpr(sizeof...(SUFFIXChars) == 0)
return lf;
else
return str_literal<SUFFIXChars...>{} + lf;
}
// Issues the printf call. Argument order matches the assembled format: thread id,
// then the caller-supplied `args` (consumed by conversions in PREFIX, if it has any),
// then the values extracted from the printed object.
template <char... FMTChars, typename TArgs, index_t... Is, typename... Args>
CK_TILE_HOST_DEVICE void impl(str_literal<FMTChars...>,
const TArgs& targs,
std::integer_sequence<index_t, Is...>,
Args&&... args) const
{
constexpr auto fmt_wrap_v = get_prefix() + str_literal<FMTChars...>{} + get_suffix();
// The format is a compile-time constant but not a string literal, so the
// -Wformat-nonliteral diagnostic is silenced around the call.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
printf(fmt_wrap_v.data, get_thread_id(), args..., targs.at(number<Is>())...);
#pragma clang diagnostic pop
}
// Prints `x` using the traits in detail<remove_cvref_t<T>>; extra `args` are forwarded
// to impl().
template <typename T, typename... Args>
CK_TILE_HOST_DEVICE void operator()(T&& x, Args&&... args) const
{
using detail_t = detail<remove_cvref_t<T>>;
impl(detail_t::get_format(),
detail_t::get_values(std::forward<T>(x)),
std::make_integer_sequence<index_t, (detail_t::get_num_values())>{},
std::forward<Args>(args)...);
}
};
} // namespace detail
/// Prints a tensor_adaptor_coordinate with the default (empty) prefix and suffix.
template <index_t N, typename B, typename T>
CK_TILE_HOST_DEVICE void print(const tensor_adaptor_coordinate<N, B, T>& coord)
{
    using printer_t = detail::CK_PRINT_X_<>;
    printer_t{}(coord);
}
} // namespace ck_tile

View File

@@ -89,4 +89,9 @@ CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc&
return adaptor_coordinate_is_valid(tensor_desc, coord);
}
/// Prints a tensor_coordinate by delegating to the print() overload for its `Base` type
/// (presumably the tensor_adaptor_coordinate it derives from -- confirm in the class).
template <index_t N, typename T>
CK_TILE_HOST_DEVICE void print(const tensor_coordinate<N, T>& coord)
{
    using base_t = typename tensor_coordinate<N, T>::Base;
    // Bind to the base by const reference: casting to `Base` by value copy-constructs
    // (slices) a temporary just to print it.
    print(static_cast<const base_t&>(coord));
}
} // namespace ck_tile

View File

@@ -7,6 +7,8 @@
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile {
template <auto... val>
@@ -18,48 +20,6 @@ template <typename... type>
{
}
template <char... Xs>
struct str_literal
{
static constexpr const char data[] = {Xs..., '\0'};
static constexpr const size_t size = sizeof...(Xs);
template <char... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
{
return str_literal<Xs..., Ys...>{};
}
template <index_t N, char... Ys>
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
{
if constexpr(N == 0)
return str_literal<>{};
else if constexpr(N == 1)
return str_literal<Xs...>{};
else
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
}
};
#define make_str_literal(lit_) \
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
template <size_t... Idx>
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
makeTuple(std::index_sequence<Idx...>) noexcept
{
return {};
}
constexpr size_t constexpr_strlen(const char* c)
{
size_t t = 0;
while(*c++)
++t;
return t;
}
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor;
@@ -79,17 +39,29 @@ struct CK_PRINTF<ConvertTo,
str_literal<SUFFIXChars...>>
{
template <typename T>
CK_TILE_HOST_DEVICE static constexpr auto default_format()
CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type()
{
if constexpr(std::is_same_v<T, float>)
return make_str_literal("%8.3f");
return std::make_tuple(make_str_literal("%8.3f"), T{});
else if constexpr(std::is_same_v<T, int>)
return make_str_literal("%5d");
return std::make_tuple(make_str_literal("%5d"), T{});
else if constexpr(std::is_same_v<T, unsigned int>)
return make_str_literal("%5u");
return std::make_tuple(make_str_literal("%5u"), T{});
else if constexpr(sizeof(T) == 1)
return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{});
else if constexpr(sizeof(T) == 2)
return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{});
else if constexpr(sizeof(T) == 4)
return std::make_tuple(make_str_literal("0x%08x"), uint32_t{});
else
return make_str_literal("0x%08x");
static_assert(false, "Unsupported type");
}
template <typename T>
using default_format_t =
std::remove_reference_t<decltype(std::get<0>(default_format_and_type<T>()))>;
template <typename T>
using default_type_t =
std::remove_reference_t<decltype(std::get<1>(default_format_and_type<T>()))>;
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
{
@@ -108,49 +80,58 @@ struct CK_PRINTF<ConvertTo,
return str_literal<SUFFIXChars...>{} + lf;
}
template <typename T, index_t N, typename Y, index_t... Is>
template <typename T, index_t N, typename Y, index_t... Is, typename... Args>
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
std::integer_sequence<index_t, Is...>) const
std::integer_sequence<index_t, Is...>,
Args&&... args) const
{
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
decltype(default_format<Y>()),
str_literal<FMTChars...>>;
using FMT1 = std::
conditional_t<sizeof...(FMTChars) == 0, default_format_t<Y>, str_literal<FMTChars...>>;
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
printf(fmt_wrap_v.data,
get_thread_id(),
N,
args...,
bit_cast<default_type_t<Y>>(type_convert<Y>(buf[Is]))...);
#pragma clang diagnostic pop
}
template <typename T, index_t N>
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
template <typename T, index_t N, typename... Args>
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf, Args&&... args) const
{
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
impl<T, N, ConvertTo_>(
buf, std::make_integer_sequence<index_t, N>{}, std::forward<Args>(args)...);
}
template <typename... TS>
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
template <typename... TS, typename... Args>
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor,
Args&&... args) const
{
return operator()(tensor.get_thread_buffer());
return operator()(tensor.get_thread_buffer(), std::forward<Args>(args)...);
}
};
template <typename ConvertTo = void,
typename FMT = str_literal<>,
typename PREFIX = str_literal<>,
typename SUFFIX = str_literal<>>
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
template <typename T>
CK_TILE_HOST_DEVICE void print_warp0(T&& x)
{
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
if(get_thread_id() < get_warp_size())
print(std::forward<T>(x));
}
template <typename... Ts>
struct CK_PRINTF_WARP0 : public CK_PRINTF<Ts...>
{
using base_t = CK_PRINTF<Ts...>;
template <typename T>
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
template <typename T, typename... Args>
CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const
{
if(get_thread_id() < get_warp_size())
base_t::operator()(buf);
base_t::operator()(buf, std::forward<Args>(args)...);
}
};

View File

@@ -7,6 +7,51 @@
namespace ck_tile {
namespace str_literal_detail {
/// Turns an index sequence <I0, I1, ...> into a tuple of
/// std::integral_constant<size_t, I0>, std::integral_constant<size_t, I1>, ...
template <size_t... Idx>
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
makeTuple(std::index_sequence<Idx...>) noexcept
{
    using result_t = std::tuple<std::integral_constant<size_t, Idx>...>;
    return result_t{};
}
/// constexpr replacement for strlen: number of characters before the terminating NUL.
constexpr size_t constexpr_strlen(const char* c)
{
    size_t length = 0;
    for(; c[length] != '\0'; ++length)
    {
    }
    return length;
}
} // namespace str_literal_detail
/// Compile-time string whose characters are carried as template arguments.
template <char... Xs>
struct str_literal
{
    // NUL-terminated character array, usable as a C string (e.g. as a printf format).
    static constexpr const char data[sizeof...(Xs) + 1] = {Xs..., '\0'};
    // Number of characters, not counting the terminating NUL.
    static constexpr const size_t size = sizeof...(Xs);

    /// Concatenation: the result type holds this literal's characters followed by rhs's.
    template <char... Ys>
    CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
    {
        using joined_t = str_literal<Xs..., Ys...>;
        return joined_t{};
    }

    /// Repeats this literal N times with `sep` between consecutive copies.
    /// N == 0 yields the empty literal; N == 1 yields this literal unchanged.
    template <size_t N, char... Ys>
    CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
    {
        if constexpr(N > 1)
            // N-1 copies, then "<sep><this>".
            return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
        else if constexpr(N == 1)
            return str_literal<Xs...>{};
        else
            return str_literal<>{};
    }
};
// Builds a ck_tile::str_literal<...> object from the string literal `lit_` by indexing
// each character of `lit_` at compile time.
// All ck_tile names are fully qualified so the macro also expands correctly when it is
// used outside namespace ck_tile (macros ignore namespaces at the point of use).
#define make_str_literal(lit_)                                                          \
    std::apply(                                                                         \
        [](auto... indices) {                                                           \
            return ::ck_tile::str_literal<(lit_)[decltype(indices)::value]...>{};       \
        },                                                                              \
        ::ck_tile::str_literal_detail::makeTuple(                                       \
            std::make_index_sequence<                                                   \
                ::ck_tile::str_literal_detail::constexpr_strlen(lit_)>()))
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
template <typename T>

View File

@@ -662,17 +662,21 @@ struct FlatmmKernel
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(
kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: splitk_batch_offset.splitted_k /
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(ScaleGranularityKB == 0
? 1
: (splitk_batch_offset.splitted_k /
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});

View File

@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
@@ -92,7 +94,8 @@ struct ReferenceConvFwd : public device::BaseOperator
in_right_pads_{input_right_pads},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
out_element_op_{out_element_op},
device_name_{ck::get_device_name()}
{
}
@@ -112,6 +115,7 @@ struct ReferenceConvFwd : public device::BaseOperator
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
::std::string device_name_; // the device which this conv is compared with
};
struct Invoker : public device::BaseInvoker
@@ -251,10 +255,39 @@ struct ReferenceConvFwd : public device::BaseOperator
x);
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
v_acc += ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_wei));
if(arg.device_name_ == "gfx942")
{
v_acc += ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_wei));
}
else if(arg.device_name_ == "gfx950")
{
ck::bhalf_t v_in_bf16_big =
ck::type_convert<ck::bhalf_t>(v_in);
ck::bhalf_t v_in_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_in - type_convert<float>(v_in_bf16_big));
ck::bhalf_t v_wei_bf16_big =
ck::type_convert<ck::bhalf_t>(v_wei);
ck::bhalf_t v_wei_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_wei - type_convert<float>(v_wei_bf16_big));
v_acc += ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_small) +
ck::type_convert<float>(v_in_bf16_small) *
ck::type_convert<float>(v_wei_bf16_big) +
ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_big);
}
else
{
throw std::runtime_error(
"Unsupported device: " + arg.device_name_ +
" for tf32 computation");
}
}
else
{
@@ -350,10 +383,41 @@ struct ReferenceConvFwd : public device::BaseOperator
x);
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
v_acc += ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_wei));
if(arg.device_name_ == "gfx942")
{
v_acc += ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_wei));
}
else if(arg.device_name_ == "gfx950")
{
ck::bhalf_t v_in_bf16_big =
ck::type_convert<ck::bhalf_t>(v_in);
ck::bhalf_t v_in_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_in - type_convert<float>(v_in_bf16_big));
ck::bhalf_t v_wei_bf16_big =
ck::type_convert<ck::bhalf_t>(v_wei);
ck::bhalf_t v_wei_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_wei -
type_convert<float>(v_wei_bf16_big));
v_acc +=
ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_small) +
ck::type_convert<float>(v_in_bf16_small) *
ck::type_convert<float>(v_wei_bf16_big) +
ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_big);
}
else
{
throw std::runtime_error(
"Unsupported device: " + arg.device_name_ +
" for tf32 computation");
}
}
else
{

View File

@@ -6,6 +6,7 @@
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
@@ -45,7 +46,8 @@ struct ReferenceGemm : public device::BaseOperator
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
device_name_{ck::get_device_name()}
{
}
@@ -56,6 +58,7 @@ struct ReferenceGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
::std::string device_name_; // the device which this gemm is compared with
};
// Invoker
@@ -142,12 +145,37 @@ struct ReferenceGemm : public device::BaseOperator
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
is_same_v<ComputeTypeA, ck::tf32_t>)
{ // only for tf32 now
v_acc +=
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float> && is_same_v<AccDataType, float> &&
is_same_v<ComputeTypeA, ck::tf32_t> &&
is_same_v<ComputeTypeB, ck::tf32_t>)
{
if(arg.device_name_ == "gfx942")
{
v_acc +=
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
}
else if(arg.device_name_ == "gfx950")
{
ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
ck::bhalf_t v_a_bf16_small = ck::type_convert<ck::bhalf_t>(
v_a - type_convert<float>(v_a_bf16_big));
ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
ck::bhalf_t v_b_bf16_small = ck::type_convert<ck::bhalf_t>(
v_b - type_convert<float>(v_b_bf16_big));
v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_small) +
ck::type_convert<AccDataType>(v_a_bf16_small) *
ck::type_convert<AccDataType>(v_b_bf16_big) +
ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_big);
}
else
{
throw std::runtime_error("Unsupported device: " + arg.device_name_);
}
}
else
{

View File

@@ -82,9 +82,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// multiply and accumulate
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
is_same_v<ComputeTypeA, ck::tf32_t>)
{ // only for tf32 now
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
{
#if defined(__gfx942__)
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
#elif defined(__gfx950__)
ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
ck::bhalf_t v_a_bf16_small =
ck::type_convert<ck::bhalf_t>(v_a - type_convert<float>(v_a_bf16_big));
ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
ck::bhalf_t v_b_bf16_small =
ck::type_convert<ck::bhalf_t>(v_b - type_convert<float>(v_b_bf16_big));
v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_small) +
ck::type_convert<AccDataType>(v_a_bf16_small) *
ck::type_convert<AccDataType>(v_b_bf16_big) +
ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_big);
#else
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
#endif
}
else
{

View File

@@ -34,6 +34,30 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// Instances for double-rate MFMA on gfx950.
// Compile-time list (std::tuple) of DeviceGemmMultiD_Xdl_CShuffle_V3 configurations, templated
// on the GEMM specialization (GemmSpec). Every entry uses Row/Col layouts for A/B, F8 data for
// A and B, two F32 D tensors, BF16 E data, the PassThrough / PassThrough / MultiplyMultiply
// element-wise operations, a block size of 256 and the Intrawave scheduler. Entries come in
// pairs that differ only in the pipeline version (v4 first, then v3). The column headers below
// name each template argument.
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_dr = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 32, 32, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 32, 32, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
// clang-format on
>;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1 = std::tuple<
// clang-format off

View File

@@ -24,6 +24,13 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmDefault>{});
if(ck::get_device_name() == "gfx950")
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_dr<GemmDefault>{});
}
}
} // namespace instance

View File

@@ -24,6 +24,14 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmKPadding>{});
if(ck::get_device_name() == "gfx950")
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_dr<
GemmKPadding>{});
}
}
} // namespace instance

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -97,7 +97,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
break;
default:

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,3 +1,6 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

Some files were not shown because too many files have changed in this diff Show More