mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '0f8e33f81120e5734ef47a6a169ad85c6560cbd8' into develop
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
@@ -76,28 +76,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
|
||||
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
|
||||
// | | | | | | | | | | | Wave| Wave| Wave| |
|
||||
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1},
|
||||
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
|
||||
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
|
||||
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
|
||||
{ 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
|
||||
// Padded fallback kernel
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
|
||||
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1},
|
||||
// Irregular k
|
||||
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1},
|
||||
{ 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1},
|
||||
{ 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1},
|
||||
{ 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1},
|
||||
{ 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -200,28 +200,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1,16>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1,16>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// Padded fallback kernel
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// Irregular k
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
@@ -81,16 +81,16 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
|
||||
// | | | | | | | | Wave| Wave| Stage|
|
||||
// | | | | | | | | | | |
|
||||
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 16, 16, 4, 8, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 128, 128, 64, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 128, 64, 128, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, 1},
|
||||
{ 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, 1},
|
||||
// Irregular tile
|
||||
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
|
||||
{ 64, 32, 32, 32, 8, 8, 16, 16, 2, 2, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -194,14 +194,14 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 4>, 8},
|
||||
{ S<1, 16, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 4>, 4},
|
||||
{ S<1, 16, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// Irregular tile
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
// clang-format on
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include <iostream>
|
||||
@@ -55,12 +55,12 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
|
||||
// | | | | | | | | Wave| Wave| Stage|
|
||||
// | | | | | | | | | | |
|
||||
{ 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
|
||||
{ 64, 64, 32, 32, 8, 8, 16, 16, 4, 2, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 64, 64, 64, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -116,11 +116,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 16>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 8>, 8}
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1, 8>, 4}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -223,8 +223,9 @@ extern "C" __global__ void run_${name}(
|
||||
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = DeviceConv::GridwiseGemm;
|
||||
|
||||
using GridwiseGemm = ck::conditional_t<ck::get_warp_size() == 64,
|
||||
typename DeviceConv::GridwiseGemm64,
|
||||
typename DeviceConv::GridwiseGemm32>;
|
||||
static constexpr auto I0 = ck::Number<0>{};
|
||||
|
||||
ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/utils.hpp"
|
||||
|
||||
@@ -13,7 +13,8 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
|
||||
|
||||
const std::unordered_set<std::string>& get_xdlop_archs()
|
||||
{
|
||||
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx942"};
|
||||
static std::unordered_set<std::string> supported_archs{
|
||||
"gfx90a", "gfx908", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"};
|
||||
return supported_archs;
|
||||
}
|
||||
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
Reference in New Issue
Block a user